Skip to content

Commit 1f8f5d3

Browse files
committed
Rust: Non-symmetrict type propagation for if and match
1 parent df3c851 commit 1f8f5d3

File tree

2 files changed

+36
-226
lines changed

2 files changed

+36
-226
lines changed

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -540,10 +540,6 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
540540
let.getInitializer() = n2
541541
)
542542
or
543-
n1 = n2.(IfExpr).getABranch()
544-
or
545-
n1 = n2.(MatchExpr).getAnArm().getExpr()
546-
or
547543
exists(LetExpr let |
548544
n1 = let.getScrutinee() and
549545
n2 = let.getPat()
@@ -635,6 +631,40 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
635631
prefix2.isEmpty()
636632
}
637633

634+
/**
635+
* Holds if the type tree of `n1` at `prefix1` should be equal to the type tree
636+
* of `n2` at `prefix2`, but type information should only propagate from `n1` to
637+
* `n2`.
638+
*/
639+
private predicate typeEqualityNonSymmetric(
640+
AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2
641+
) {
642+
prefix1.isEmpty() and
643+
prefix2.isEmpty() and
644+
(
645+
n1 = n2.(IfExpr).getABranch()
646+
or
647+
n1 = n2.(MatchExpr).getAnArm().getExpr()
648+
)
649+
or
650+
exists(AstNode mid |
651+
typeEquality(n1, prefix1, mid, prefix2) or
652+
typeEquality(mid, prefix2, n1, prefix1)
653+
|
654+
mid =
655+
any(IfExpr ie |
656+
n2 = ie.getABranch() and
657+
not n1 = ie.getABranch()
658+
)
659+
or
660+
mid =
661+
any(MatchExpr me |
662+
n2 = me.getAnArm().getExpr() and
663+
not n1 = me.getAnArm().getExpr()
664+
)
665+
)
666+
}
667+
638668
pragma[nomagic]
639669
private Type inferTypeEquality(AstNode n, TypePath path) {
640670
exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix |
@@ -644,6 +674,8 @@ private Type inferTypeEquality(AstNode n, TypePath path) {
644674
typeEquality(n, prefix1, n2, prefix2)
645675
or
646676
typeEquality(n2, prefix2, n, prefix1)
677+
or
678+
typeEqualityNonSymmetric(n2, prefix2, n, prefix1)
647679
)
648680
}
649681

0 commit comments

Comments
 (0)