Skip to content

Commit 4914a3f

Browse files
committed
Rust: Non-symmetric type propagation for lub coercions
1 parent 48ff309 commit 4914a3f

File tree

2 files changed

+82
-235
lines changed

2 files changed

+82
-235
lines changed

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 82 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,17 @@ private Struct getRangeType(RangeExpr re) {
524524
result instanceof RangeToInclusiveStruct
525525
}
526526

527+
private predicate bodyReturns(Expr body, Expr e) {
528+
exists(ReturnExpr re, Callable c |
529+
e = re.getExpr() and
530+
c = re.getEnclosingCallable()
531+
|
532+
body = c.(Function).getBody()
533+
or
534+
body = c.(ClosureExpr).getBody()
535+
)
536+
}
537+
527538
/**
528539
* Holds if the type tree of `n1` at `prefix1` should be equal to the type tree
529540
* of `n2` at `prefix2` and type information should propagate in both directions
@@ -540,9 +551,11 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
540551
let.getInitializer() = n2
541552
)
542553
or
543-
n1 = n2.(IfExpr).getABranch()
544-
or
545-
n1 = n2.(MatchExpr).getAnArm().getExpr()
554+
n2 =
555+
any(MatchExpr me |
556+
n1 = me.getAnArm().getExpr() and
557+
me.getNumberOfArms() = 1
558+
)
546559
or
547560
exists(LetExpr let |
548561
n1 = let.getScrutinee() and
@@ -573,6 +586,9 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
573586
n1 = n2.(MacroExpr).getMacroCall().getMacroCallExpansion()
574587
or
575588
n1 = n2.(MacroPat).getMacroCall().getMacroCallExpansion()
589+
or
590+
bodyReturns(n1, n2) and
591+
strictcount(Expr e | bodyReturns(n1, e)) = 1
576592
)
577593
or
578594
(
@@ -606,8 +622,12 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
606622
)
607623
)
608624
or
609-
// an array list expression (`[1, 2, 3]`) has the type of the first (any) element
610-
n1.(ArrayListExpr).getExpr(_) = n2 and
625+
// an array list expression (`[1, 2, 3]`) has the type of the element
626+
n1 =
627+
any(ArrayListExpr ale |
628+
ale.getAnExpr() = n2 and
629+
ale.getNumberOfExprs() = 1
630+
) and
611631
prefix1 = TypePath::singleton(TArrayTypeParameter()) and
612632
prefix2.isEmpty()
613633
or
@@ -635,6 +655,61 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
635655
prefix2.isEmpty()
636656
}
637657

658+
/**
659+
* Holds if `child` is a child of `parent`, and the Rust compiler applies [least
660+
* upper bound (LUB) coercion](1) to infer the type of `parent` from the type of
661+
* `child`.
662+
*
663+
* In this case, we want type information to only flow from `child` to `parent`,
664+
* to avoid (a) either having to model LUB coercions, or (b) risk combinatorial
665+
* explosion in inferred types.
666+
*
667+
* [1]: https://doc.rust-lang.org/reference/type-coercions.html#r-coerce.least-upper-bound
668+
*/
669+
private predicate lubCoercion(AstNode parent, AstNode child, TypePath prefix) {
670+
child = parent.(IfExpr).getABranch() and
671+
prefix.isEmpty()
672+
or
673+
parent =
674+
any(MatchExpr me |
675+
child = me.getAnArm().getExpr() and
676+
me.getNumberOfArms() > 1
677+
) and
678+
prefix.isEmpty()
679+
or
680+
parent =
681+
any(ArrayListExpr ale |
682+
child = ale.getAnExpr() and
683+
ale.getNumberOfExprs() > 1
684+
) and
685+
prefix = TypePath::singleton(TArrayTypeParameter())
686+
or
687+
bodyReturns(parent, child) and
688+
strictcount(Expr e | bodyReturns(parent, e)) > 1 and
689+
prefix.isEmpty()
690+
}
691+
692+
/**
693+
* Holds if the type tree of `n1` at `prefix1` should be equal to the type tree
694+
* of `n2` at `prefix2`, but type information should only propagate from `n1` to
695+
* `n2`.
696+
*/
697+
private predicate typeEqualityNonSymmetric(
698+
AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2
699+
) {
700+
lubCoercion(n2, n1, prefix2) and
701+
prefix1.isEmpty()
702+
or
703+
exists(AstNode mid, TypePath prefixMid, TypePath suffix |
704+
typeEquality(n1, prefixMid, mid, prefix2) or
705+
typeEquality(mid, prefix2, n1, prefixMid)
706+
|
707+
lubCoercion(mid, n2, suffix) and
708+
not lubCoercion(mid, n1, _) and
709+
prefix1 = prefixMid.append(suffix)
710+
)
711+
}
712+
638713
pragma[nomagic]
639714
private Type inferTypeEquality(AstNode n, TypePath path) {
640715
exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix |
@@ -644,6 +719,8 @@ private Type inferTypeEquality(AstNode n, TypePath path) {
644719
typeEquality(n, prefix1, n2, prefix2)
645720
or
646721
typeEquality(n2, prefix2, n, prefix1)
722+
or
723+
typeEqualityNonSymmetric(n2, prefix2, n, prefix1)
647724
)
648725
}
649726

0 commit comments

Comments
 (0)