Skip to content

Commit 4b0f585

Browse files
committed
Rust: Non-symmetrict type propagation for if and match
1 parent b32c5bd commit 4b0f585

File tree

2 files changed

+45
-239
lines changed

2 files changed

+45
-239
lines changed

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -540,10 +540,6 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
540540
let.getInitializer() = n2
541541
)
542542
or
543-
n1 = n2.(IfExpr).getABranch()
544-
or
545-
n1 = n2.(MatchExpr).getAnArm().getExpr()
546-
or
547543
exists(LetExpr let |
548544
n1 = let.getScrutinee() and
549545
n2 = let.getPat()
@@ -606,11 +602,6 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
606602
)
607603
)
608604
or
609-
// an array list expression (`[1, 2, 3]`) has the type of the first (any) element
610-
n1.(ArrayListExpr).getExpr(_) = n2 and
611-
prefix1 = TypePath::singleton(TArrayTypeParameter()) and
612-
prefix2.isEmpty()
613-
or
614605
// an array repeat expression (`[1; 3]`) has the type of the repeat operand
615606
n1.(ArrayRepeatExpr).getRepeatOperand() = n2 and
616607
prefix1 = TypePath::singleton(TArrayTypeParameter()) and
@@ -635,6 +626,49 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
635626
prefix2.isEmpty()
636627
}
637628

629+
/**
630+
* Holds if `child` is a child of `parent`, and the Rust compiler applies [least
631+
* upper bound (LUB) coercion](1) to infer the type of `parent` from the type of
632+
* `child`.
633+
*
634+
* In this case, we want type information to only flow from `child` to `parent`,
635+
* to avoid (a) either having to model LUB coercions, or (b) risk combinatorial
636+
* explosion in inferred types.
637+
*
638+
* [1]: https://doc.rust-lang.org/reference/type-coercions.html#r-coerce.least-upper-bound
639+
*/
640+
private predicate lubCoercion(AstNode parent, AstNode child, TypePath prefix) {
641+
child = parent.(IfExpr).getABranch() and
642+
prefix.isEmpty()
643+
or
644+
child = parent.(MatchExpr).getAnArm().getExpr() and
645+
prefix.isEmpty()
646+
or
647+
child = parent.(ArrayListExpr).getAnExpr() and
648+
prefix = TypePath::singleton(TArrayTypeParameter())
649+
}
650+
651+
/**
652+
* Holds if the type tree of `n1` at `prefix1` should be equal to the type tree
653+
* of `n2` at `prefix2`, but type information should only propagate from `n1` to
654+
* `n2`.
655+
*/
656+
private predicate typeEqualityNonSymmetric(
657+
AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2
658+
) {
659+
lubCoercion(n2, n1, prefix2) and
660+
prefix1.isEmpty()
661+
or
662+
exists(AstNode mid, TypePath prefixMid, TypePath suffix |
663+
typeEquality(n1, prefixMid, mid, prefix2) or
664+
typeEquality(mid, prefix2, n1, prefixMid)
665+
|
666+
lubCoercion(mid, n2, suffix) and
667+
not lubCoercion(mid, n1, _) and
668+
prefix1 = prefixMid.append(suffix)
669+
)
670+
}
671+
638672
pragma[nomagic]
639673
private Type inferTypeEquality(AstNode n, TypePath path) {
640674
exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix |
@@ -644,6 +678,8 @@ private Type inferTypeEquality(AstNode n, TypePath path) {
644678
typeEquality(n, prefix1, n2, prefix2)
645679
or
646680
typeEquality(n2, prefix2, n, prefix1)
681+
or
682+
typeEqualityNonSymmetric(n2, prefix2, n, prefix1)
647683
)
648684
}
649685

0 commit comments

Comments
 (0)