@@ -524,6 +524,17 @@ private Struct getRangeType(RangeExpr re) {
524
524
result instanceof RangeToInclusiveStruct
525
525
}
526
526
527
+ private predicate bodyReturns ( Expr body , Expr e ) {
528
+ exists ( ReturnExpr re , Callable c |
529
+ e = re .getExpr ( ) and
530
+ c = re .getEnclosingCallable ( )
531
+ |
532
+ body = c .( Function ) .getBody ( )
533
+ or
534
+ body = c .( ClosureExpr ) .getBody ( )
535
+ )
536
+ }
537
+
527
538
/**
528
539
* Holds if the type tree of `n1` at `prefix1` should be equal to the type tree
529
540
* of `n2` at `prefix2` and type information should propagate in both directions
@@ -540,9 +551,11 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
540
551
let .getInitializer ( ) = n2
541
552
)
542
553
or
543
- n1 = n2 .( IfExpr ) .getABranch ( )
544
- or
545
- n1 = n2 .( MatchExpr ) .getAnArm ( ) .getExpr ( )
554
+ n2 =
555
+ any ( MatchExpr me |
556
+ n1 = me .getAnArm ( ) .getExpr ( ) and
557
+ me .getNumberOfArms ( ) = 1
558
+ )
546
559
or
547
560
exists ( LetExpr let |
548
561
n1 = let .getScrutinee ( ) and
@@ -573,6 +586,9 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
573
586
n1 = n2 .( MacroExpr ) .getMacroCall ( ) .getMacroCallExpansion ( )
574
587
or
575
588
n1 = n2 .( MacroPat ) .getMacroCall ( ) .getMacroCallExpansion ( )
589
+ or
590
+ bodyReturns ( n1 , n2 ) and
591
+ strictcount ( Expr e | bodyReturns ( n1 , e ) ) = 1
576
592
)
577
593
or
578
594
(
@@ -606,8 +622,12 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
606
622
)
607
623
)
608
624
or
609
- // an array list expression (`[1, 2, 3]`) has the type of the first (any) element
610
- n1 .( ArrayListExpr ) .getExpr ( _) = n2 and
625
+ // an array list expression (`[1, 2, 3]`) has the type of the element
626
+ n1 =
627
+ any ( ArrayListExpr ale |
628
+ ale .getAnExpr ( ) = n2 and
629
+ ale .getNumberOfExprs ( ) = 1
630
+ ) and
611
631
prefix1 = TypePath:: singleton ( TArrayTypeParameter ( ) ) and
612
632
prefix2 .isEmpty ( )
613
633
or
@@ -635,6 +655,61 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
635
655
prefix2 .isEmpty ( )
636
656
}
637
657
658
+ /**
659
+ * Holds if `child` is a child of `parent`, and the Rust compiler applies [least
660
+ * upper bound (LUB) coercion](1) to infer the type of `parent` from the type of
661
+ * `child`.
662
+ *
663
+ * In this case, we want type information to only flow from `child` to `parent`,
664
+ * to avoid (a) either having to model LUB coercions, or (b) risk combinatorial
665
+ * explosion in inferred types.
666
+ *
667
+ * [1]: https://doc.rust-lang.org/reference/type-coercions.html#r-coerce.least-upper-bound
668
+ */
669
+ private predicate lubCoercion ( AstNode parent , AstNode child , TypePath prefix ) {
670
+ child = parent .( IfExpr ) .getABranch ( ) and
671
+ prefix .isEmpty ( )
672
+ or
673
+ parent =
674
+ any ( MatchExpr me |
675
+ child = me .getAnArm ( ) .getExpr ( ) and
676
+ me .getNumberOfArms ( ) > 1
677
+ ) and
678
+ prefix .isEmpty ( )
679
+ or
680
+ parent =
681
+ any ( ArrayListExpr ale |
682
+ child = ale .getAnExpr ( ) and
683
+ ale .getNumberOfExprs ( ) > 1
684
+ ) and
685
+ prefix = TypePath:: singleton ( TArrayTypeParameter ( ) )
686
+ or
687
+ bodyReturns ( parent , child ) and
688
+ strictcount ( Expr e | bodyReturns ( parent , e ) ) > 1 and
689
+ prefix .isEmpty ( )
690
+ }
691
+
692
+ /**
693
+ * Holds if the type tree of `n1` at `prefix1` should be equal to the type tree
694
+ * of `n2` at `prefix2`, but type information should only propagate from `n1` to
695
+ * `n2`.
696
+ */
697
+ private predicate typeEqualityNonSymmetric (
698
+ AstNode n1 , TypePath prefix1 , AstNode n2 , TypePath prefix2
699
+ ) {
700
+ lubCoercion ( n2 , n1 , prefix2 ) and
701
+ prefix1 .isEmpty ( )
702
+ or
703
+ exists ( AstNode mid , TypePath prefixMid , TypePath suffix |
704
+ typeEquality ( n1 , prefixMid , mid , prefix2 ) or
705
+ typeEquality ( mid , prefix2 , n1 , prefixMid )
706
+ |
707
+ lubCoercion ( mid , n2 , suffix ) and
708
+ not lubCoercion ( mid , n1 , _) and
709
+ prefix1 = prefixMid .append ( suffix )
710
+ )
711
+ }
712
+
638
713
pragma [ nomagic]
639
714
private Type inferTypeEquality ( AstNode n , TypePath path ) {
640
715
exists ( TypePath prefix1 , AstNode n2 , TypePath prefix2 , TypePath suffix |
@@ -644,6 +719,8 @@ private Type inferTypeEquality(AstNode n, TypePath path) {
644
719
typeEquality ( n , prefix1 , n2 , prefix2 )
645
720
or
646
721
typeEquality ( n2 , prefix2 , n , prefix1 )
722
+ or
723
+ typeEqualityNonSymmetric ( n2 , prefix2 , n , prefix1 )
647
724
)
648
725
}
649
726
0 commit comments