Skip to content

Rust: Improve type inference for for loops and range expressions #19971

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions rust/ql/lib/codeql/rust/elements/RangeExprExt.qll
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/**
* This module provides sub classes of the `RangeExpr` class.
*/

private import rust

/**
* A range-from expression. For example:
* ```rust
* let x = 10..;
* ```
*/
final class RangeFromExpr extends RangeExpr {
RangeFromExpr() {
this.getOperatorName() = ".." and
not this.hasEnd()
}
}

/**
* A range-to expression. For example:
* ```rust
* let x = ..10;
* ```
*/
final class RangeToExpr extends RangeExpr {
RangeToExpr() {
this.getOperatorName() = ".." and
not this.hasStart()
}
}

/**
* A range-from-to expression. For example:
* ```rust
* let x = 10..20;
* ```
*/
final class RangeFromToExpr extends RangeExpr {
RangeFromToExpr() {
this.getOperatorName() = ".." and
this.hasStart() and
this.hasEnd()
}
}

/**
* A range-inclusive expression. For example:
* ```rust
* let x = 1..=10;
* ```
*/
final class RangeInclusiveExpr extends RangeExpr {
RangeInclusiveExpr() {
this.getOperatorName() = "..=" and
this.hasStart() and
this.hasEnd()
}
}

/**
* A range-to-inclusive expression. For example:
* ```rust
* let x = ..=10;
* ```
*/
final class RangeToInclusiveExpr extends RangeExpr {
RangeToInclusiveExpr() {
this.getOperatorName() = "..=" and
not this.hasStart() and
this.hasEnd()
}
}
98 changes: 98 additions & 0 deletions rust/ql/lib/codeql/rust/frameworks/stdlib/Stdlib.qll
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,72 @@ class ResultEnum extends Enum {
Variant getErr() { result = this.getVariant("Err") }
}

/**
* The [`Range` struct][1].
*
* [1]: https://doc.rust-lang.org/core/ops/struct.Range.html
*/
class RangeStruct extends Struct {
RangeStruct() { this.getCanonicalPath() = "core::ops::range::Range" }

/** Gets the `start` field. */
StructField getStart() { result = this.getStructField("start") }

/** Gets the `end` field. */
StructField getEnd() { result = this.getStructField("end") }
}

/**
* The [`RangeFrom` struct][1].
*
* [1]: https://doc.rust-lang.org/core/ops/struct.RangeFrom.html
*/
class RangeFromStruct extends Struct {
RangeFromStruct() { this.getCanonicalPath() = "core::ops::range::RangeFrom" }

/** Gets the `start` field. */
StructField getStart() { result = this.getStructField("start") }
}

/**
* The [`RangeTo` struct][1].
*
* [1]: https://doc.rust-lang.org/core/ops/struct.RangeTo.html
*/
class RangeToStruct extends Struct {
RangeToStruct() { this.getCanonicalPath() = "core::ops::range::RangeTo" }

/** Gets the `end` field. */
StructField getEnd() { result = this.getStructField("end") }
}

/**
* The [`RangeInclusive` struct][1].
*
* [1]: https://doc.rust-lang.org/core/ops/struct.RangeInclusive.html
*/
class RangeInclusiveStruct extends Struct {
RangeInclusiveStruct() { this.getCanonicalPath() = "core::ops::range::RangeInclusive" }

/** Gets the `start` field. */
StructField getStart() { result = this.getStructField("start") }

/** Gets the `end` field. */
StructField getEnd() { result = this.getStructField("end") }
}

/**
* The [`RangeToInclusive` struct][1].
*
* [1]: https://doc.rust-lang.org/core/ops/struct.RangeToInclusive.html
*/
class RangeToInclusiveStruct extends Struct {
RangeToInclusiveStruct() { this.getCanonicalPath() = "core::ops::range::RangeToInclusive" }

/** Gets the `end` field. */
StructField getEnd() { result = this.getStructField("end") }
}

/**
* The [`Future` trait][1].
*
Expand All @@ -66,6 +132,38 @@ class FutureTrait extends Trait {
}
}

/**
* The [`Iterator` trait][1].
*
* [1]: https://doc.rust-lang.org/std/iter/trait.Iterator.html
*/
class IteratorTrait extends Trait {
IteratorTrait() { this.getCanonicalPath() = "core::iter::traits::iterator::Iterator" }

/** Gets the `Item` associated type. */
pragma[nomagic]
TypeAlias getItemType() {
result = this.getAssocItemList().getAnAssocItem() and
result.getName().getText() = "Item"
}
}

/**
* The [`IntoIterator` trait][1].
*
* [1]: https://doc.rust-lang.org/std/iter/trait.IntoIterator.html
*/
class IntoIteratorTrait extends Trait {
IntoIteratorTrait() { this.getCanonicalPath() = "core::iter::traits::collect::IntoIterator" }

/** Gets the `Item` associated type. */
pragma[nomagic]
TypeAlias getItemType() {
result = this.getAssocItemList().getAnAssocItem() and
result.getName().getText() = "Item"
}
}

/**
* The [`String` struct][1].
*
Expand Down
16 changes: 10 additions & 6 deletions rust/ql/lib/codeql/rust/internal/Type.qll
Original file line number Diff line number Diff line change
Expand Up @@ -421,21 +421,25 @@ final class ImplTypeAbstraction extends TypeAbstraction, Impl {
}

final class TraitTypeAbstraction extends TypeAbstraction, Trait {
override TypeParamTypeParameter getATypeParameter() {
result.getTypeParam() = this.getGenericParamList().getATypeParam()
override TypeParameter getATypeParameter() {
result.(TypeParamTypeParameter).getTypeParam() = this.getGenericParamList().getATypeParam()
or
result.(AssociatedTypeTypeParameter).getTrait() = this
}
}

final class TypeBoundTypeAbstraction extends TypeAbstraction, TypeBound {
override TypeParamTypeParameter getATypeParameter() { none() }
override TypeParameter getATypeParameter() { none() }
}

final class SelfTypeBoundTypeAbstraction extends TypeAbstraction, Name {
SelfTypeBoundTypeAbstraction() { any(Trait trait).getName() = this }
private TraitTypeAbstraction trait;

SelfTypeBoundTypeAbstraction() { trait.getName() = this }

override TypeParamTypeParameter getATypeParameter() { none() }
override TypeParameter getATypeParameter() { none() }
}

final class ImplTraitTypeReprAbstraction extends TypeAbstraction, ImplTraitTypeRepr {
override TypeParamTypeParameter getATypeParameter() { none() }
override TypeParameter getATypeParameter() { none() }
}
101 changes: 87 additions & 14 deletions rust/ql/lib/codeql/rust/internal/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,27 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
n1.(ArrayRepeatExpr).getRepeatOperand() = n2 and
prefix1 = TypePath::singleton(TArrayTypeParameter()) and
prefix2.isEmpty()
or
exists(Struct s |
n2 = [n1.(RangeExpr).getStart(), n1.(RangeExpr).getEnd()] and
prefix1 = TypePath::singleton(TTypeParamTypeParameter(s.getGenericParamList().getATypeParam())) and
prefix2.isEmpty()
|
n1 instanceof RangeFromExpr and
s instanceof RangeFromStruct
or
n1 instanceof RangeToExpr and
s instanceof RangeToStruct
or
n1 instanceof RangeFromToExpr and
s instanceof RangeStruct
or
n1 instanceof RangeInclusiveExpr and
s instanceof RangeInclusiveStruct
or
n1 instanceof RangeToInclusiveExpr and
s instanceof RangeToInclusiveStruct
)
}

pragma[nomagic]
Expand Down Expand Up @@ -1062,7 +1083,7 @@ private TraitType inferAsyncBlockExprRootType(AsyncBlockExpr abe) {
result = getFutureTraitType()
}

final class AwaitTarget extends Expr {
final private class AwaitTarget extends Expr {
AwaitTarget() { this = any(AwaitExpr ae).getExpr() }

Type getTypeAt(TypePath path) { result = inferType(this, path) }
Expand Down Expand Up @@ -1098,6 +1119,29 @@ private class Vec extends Struct {
pragma[nomagic]
private Type inferArrayExprType(ArrayExpr ae) { exists(ae) and result = TArrayType() }

/**
* Gets the root type of the range expression `re`.
*/
pragma[nomagic]
private Type inferRangeExprType(RangeExpr re) {
exists(Struct s | result = TStruct(s) |
re instanceof RangeFromExpr and
s instanceof RangeFromStruct
or
re instanceof RangeToExpr and
s instanceof RangeToStruct
or
re instanceof RangeFromToExpr and
s instanceof RangeStruct
or
re instanceof RangeInclusiveExpr and
s instanceof RangeInclusiveStruct
or
re instanceof RangeToInclusiveExpr and
s instanceof RangeToInclusiveStruct
)
}

/**
* According to [the Rust reference][1]: _"array and slice-typed expressions
* can be indexed with a `usize` index ... For other types an index expression
Expand Down Expand Up @@ -1134,23 +1178,49 @@ private Type inferIndexExprType(IndexExpr ie, TypePath path) {
)
}

final private class ForIterableExpr extends Expr {
ForIterableExpr() { this = any(ForExpr fe).getIterable() }

Type getTypeAt(TypePath path) { result = inferType(this, path) }
}

private module ForIterableSatisfiesConstraintInput implements
SatisfiesConstraintInputSig<ForIterableExpr>
{
predicate relevantConstraint(ForIterableExpr term, Type constraint) {
exists(term) and
exists(Trait t | t = constraint.(TraitType).getTrait() |
// TODO: Remove the line below once we can handle the `impl<I: Iterator> IntoIterator for I` implementation
t instanceof IteratorTrait or
t instanceof IntoIteratorTrait
)
}
}

pragma[nomagic]
private AssociatedTypeTypeParameter getIteratorItemTypeParameter() {
result.getTypeAlias() = any(IteratorTrait t).getItemType()
}

pragma[nomagic]
private AssociatedTypeTypeParameter getIntoIteratorItemTypeParameter() {
result.getTypeAlias() = any(IntoIteratorTrait t).getItemType()
}

pragma[nomagic]
private Type inferForLoopExprType(AstNode n, TypePath path) {
// type of iterable -> type of pattern (loop variable)
exists(ForExpr fe, Type iterableType, TypePath iterablePath |
exists(ForExpr fe, TypePath exprPath, AssociatedTypeTypeParameter tp |
n = fe.getPat() and
iterableType = inferType(fe.getIterable(), iterablePath) and
result = iterableType and
(
iterablePath.isCons(any(Vec v).getElementTypeParameter(), path)
or
iterablePath.isCons(any(ArrayTypeParameter tp), path)
or
iterablePath
.stripPrefix(TypePath::cons(TRefTypeParameter(),
TypePath::singleton(any(SliceTypeParameter tp)))) = path
// TODO: iterables (general case for containers, ranges etc)
)
SatisfiesConstraint<ForIterableExpr, ForIterableSatisfiesConstraintInput>::satisfiesConstraintType(fe.getIterable(),
_, exprPath, result) and
exprPath.isCons(tp, path)
|
tp = getIntoIteratorItemTypeParameter()
or
// TODO: Remove once we can handle the `impl<I: Iterator> IntoIterator for I` implementation
tp = getIteratorItemTypeParameter() and
inferType(fe.getIterable()) != TArrayType()
)
}

Expand Down Expand Up @@ -1589,6 +1659,9 @@ private module Cached {
result = inferArrayExprType(n) and
path.isEmpty()
or
result = inferRangeExprType(n) and
path.isEmpty()
or
result = inferIndexExprType(n, path)
or
result = inferForLoopExprType(n, path)
Expand Down
1 change: 1 addition & 0 deletions rust/ql/lib/rust.qll
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ import codeql.rust.elements.AsyncBlockExpr
import codeql.rust.elements.Variable
import codeql.rust.elements.NamedFormatArgument
import codeql.rust.elements.PositionalFormatArgument
import codeql.rust.elements.RangeExprExt
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,16 @@
| test.rs:412:31:412:38 | ...::read | Flow source 'FileSource' of type file (DEFAULT). |
| test.rs:417:22:417:39 | ...::read_to_string | Flow source 'FileSource' of type file (DEFAULT). |
| test.rs:417:22:417:39 | ...::read_to_string | Flow source 'FileSource' of type file (DEFAULT). |
| test.rs:423:22:423:25 | path | Flow source 'FileSource' of type file (DEFAULT). |
| test.rs:424:27:424:35 | file_name | Flow source 'FileSource' of type file (DEFAULT). |
| test.rs:430:22:430:34 | ...::read_link | Flow source 'FileSource' of type file (DEFAULT). |
| test.rs:439:31:439:45 | ...::read | Flow source 'FileSource' of type file (DEFAULT). |
| test.rs:444:31:444:45 | ...::read | Flow source 'FileSource' of type file (DEFAULT). |
| test.rs:449:22:449:46 | ...::read_to_string | Flow source 'FileSource' of type file (DEFAULT). |
| test.rs:455:26:455:29 | path | Flow source 'FileSource' of type file (DEFAULT). |
| test.rs:455:26:455:29 | path | Flow source 'FileSource' of type file (DEFAULT). |
| test.rs:456:31:456:39 | file_name | Flow source 'FileSource' of type file (DEFAULT). |
| test.rs:456:31:456:39 | file_name | Flow source 'FileSource' of type file (DEFAULT). |
| test.rs:462:22:462:41 | ...::read_link | Flow source 'FileSource' of type file (DEFAULT). |
| test.rs:472:20:472:38 | ...::open | Flow source 'FileSource' of type file (DEFAULT). |
| test.rs:506:21:506:39 | ...::open | Flow source 'FileSource' of type file (DEFAULT). |
Expand Down
Loading