Skip to content

Rust: Account for borrows in operators in type inference #19789

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions rust/ql/lib/codeql/rust/controlflow/CfgNodes.qll
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,8 @@ final class CallCfgNode extends ExprCfgNode {
}

/** Gets the `i`th argument of this call, if any. */
ExprCfgNode getArgument(int i) {
any(ChildMapping mapping).hasCfgChild(node, node.getArgument(i), this, result)
ExprCfgNode getPositionalArgument(int i) {
any(ChildMapping mapping).hasCfgChild(node, node.getPositionalArgument(i), this, result)
}
}

Expand Down
4 changes: 2 additions & 2 deletions rust/ql/lib/codeql/rust/dataflow/internal/DataFlowImpl.qll
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ final class ParameterPosition extends TParameterPosition {
final class ArgumentPosition extends ParameterPosition {
/** Gets the argument of `call` at this position, if any. */
Expr getArgument(Call call) {
result = call.getArgument(this.getPosition())
result = call.getPositionalArgument(this.getPosition())
or
result = call.getReceiver() and this.isSelf()
}
Expand All @@ -146,7 +146,7 @@ final class ArgumentPosition extends ParameterPosition {
* as the synthetic `ReceiverNode` is the argument for the `self` parameter.
*/
predicate isArgumentForCall(ExprCfgNode arg, CallCfgNode call, ParameterPosition pos) {
call.getArgument(pos.getPosition()) = arg
call.getPositionalArgument(pos.getPosition()) = arg
or
call.getReceiver() = arg and pos.isSelf() and not call.getCall().receiverImplicitlyBorrowed()
}
Expand Down
2 changes: 2 additions & 0 deletions rust/ql/lib/codeql/rust/elements/Call.qll
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@

private import internal.CallImpl

final class ArgumentPosition = Impl::ArgumentPosition;

final class Call = Impl::Call;
102 changes: 67 additions & 35 deletions rust/ql/lib/codeql/rust/elements/internal/CallImpl.qll
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,58 @@ private import codeql.rust.elements.internal.ExprImpl::Impl as ExprImpl
private import codeql.rust.elements.Operation

module Impl {
newtype TArgumentPosition =
Copy link
Preview

Copilot AI Jun 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Consider extracting TArgumentPosition and the ArgumentPosition class into their own file (e.g., ArgumentPosition.qll) to improve modularity and reduce the size of CallImpl.qll.

Copilot uses AI. Check for mistakes.

TPositionalArgumentPosition(int i) {
i in [0 .. max([any(ParamList l).getNumberOfParams(), any(ArgList l).getNumberOfArgs()]) - 1]
} or
TSelfArgumentPosition()

/** An argument position in a call. */
class ArgumentPosition extends TArgumentPosition {
/** Gets the index of the argument in the call, if this is a positional argument. */
int asPosition() { this = TPositionalArgumentPosition(result) }

/** Holds if this call position is a self argument. */
predicate isSelf() { this instanceof TSelfArgumentPosition }

/** Gets a string representation of this argument position. */
string toString() {
result = this.asPosition().toString()
or
this.isSelf() and result = "self"
}
}

/**
* An expression that calls a function.
*
* This class abstracts over the different ways in which a function can be called in Rust.
* This class abstracts over the different ways in which a function can be
* called in Rust.
*/
abstract class Call extends ExprImpl::Expr {
/** Gets the number of arguments _excluding_ any `self` argument. */
abstract int getNumberOfArguments();

/** Gets the receiver of this call if it is a method call. */
abstract Expr getReceiver();

/** Holds if the call has a receiver that might be implicitly borrowed. */
abstract predicate receiverImplicitlyBorrowed();
/** Holds if the receiver of this call is implicitly borrowed. */
predicate receiverImplicitlyBorrowed() { this.implicitBorrowAt(TSelfArgumentPosition()) }

/** Gets the trait targeted by this call, if any. */
abstract Trait getTrait();

/** Gets the name of the method called if this call is a method call. */
abstract string getMethodName();

/** Gets the argument at the given position, if any. */
abstract Expr getArgument(ArgumentPosition pos);

/** Holds if the argument at `pos` might be implicitly borrowed. */
abstract predicate implicitBorrowAt(ArgumentPosition pos);

/** Gets the number of arguments _excluding_ any `self` argument. */
int getNumberOfArguments() { result = count(this.getArgument(TPositionalArgumentPosition(_))) }

/** Gets the `i`th argument of this call, if any. */
abstract Expr getArgument(int i);
Expr getPositionalArgument(int i) { result = this.getArgument(TPositionalArgumentPosition(i)) }

/** Gets the receiver of this call if it is a method call. */
Expr getReceiver() { result = this.getArgument(TSelfArgumentPosition()) }

/** Gets the static target of this call, if any. */
Function getStaticTarget() {
Expand All @@ -54,15 +83,13 @@ module Impl {

override string getMethodName() { none() }

override Expr getReceiver() { none() }

override Trait getTrait() { none() }

override predicate receiverImplicitlyBorrowed() { none() }

override int getNumberOfArguments() { result = super.getArgList().getNumberOfArgs() }
override predicate implicitBorrowAt(ArgumentPosition pos) { none() }

override Expr getArgument(int i) { result = super.getArgList().getArg(i) }
override Expr getArgument(ArgumentPosition pos) {
result = super.getArgList().getArg(pos.asPosition())
}
}

private class CallExprMethodCall extends Call instanceof CallExpr {
Expand All @@ -73,8 +100,6 @@ module Impl {

override string getMethodName() { result = methodName }

override Expr getReceiver() { result = super.getArgList().getArg(0) }

override Trait getTrait() {
result = resolvePath(qualifier) and
// When the qualifier is `Self` and resolves to a trait, it's inside a
Expand All @@ -84,43 +109,50 @@ module Impl {
qualifier.toString() != "Self"
}

override predicate receiverImplicitlyBorrowed() { none() }
override predicate implicitBorrowAt(ArgumentPosition pos) { none() }

override int getNumberOfArguments() { result = super.getArgList().getNumberOfArgs() - 1 }

override Expr getArgument(int i) { result = super.getArgList().getArg(i + 1) }
override Expr getArgument(ArgumentPosition pos) {
pos.isSelf() and result = super.getArgList().getArg(0)
or
result = super.getArgList().getArg(pos.asPosition() + 1)
}
}

private class MethodCallExprCall extends Call instanceof MethodCallExpr {
override string getMethodName() { result = super.getIdentifier().getText() }

override Expr getReceiver() { result = this.(MethodCallExpr).getReceiver() }

override Trait getTrait() { none() }

override predicate receiverImplicitlyBorrowed() { any() }
override predicate implicitBorrowAt(ArgumentPosition pos) { pos.isSelf() }

override int getNumberOfArguments() { result = super.getArgList().getNumberOfArgs() }

override Expr getArgument(int i) { result = super.getArgList().getArg(i) }
override Expr getArgument(ArgumentPosition pos) {
pos.isSelf() and result = this.(MethodCallExpr).getReceiver()
or
result = super.getArgList().getArg(pos.asPosition())
}
}

private class OperatorCall extends Call instanceof Operation {
Trait trait;
string methodName;
int borrows;

OperatorCall() { super.isOverloaded(trait, methodName) }
OperatorCall() { super.isOverloaded(trait, methodName, borrows) }

override string getMethodName() { result = methodName }

override Expr getReceiver() { result = super.getOperand(0) }

override Trait getTrait() { result = trait }

override predicate receiverImplicitlyBorrowed() { none() }

override int getNumberOfArguments() { result = super.getNumberOfOperands() - 1 }
override predicate implicitBorrowAt(ArgumentPosition pos) {
pos.isSelf() and borrows >= 1
or
pos.asPosition() = 0 and borrows = 2
}

override Expr getArgument(int i) { result = super.getOperand(1) and i = 0 }
override Expr getArgument(ArgumentPosition pos) {
pos.isSelf() and result = super.getOperand(0)
or
pos.asPosition() = 0 and result = super.getOperand(1)
}
}
}
67 changes: 34 additions & 33 deletions rust/ql/lib/codeql/rust/elements/internal/OperationImpl.qll
Original file line number Diff line number Diff line change
Expand Up @@ -9,79 +9,80 @@ private import codeql.rust.elements.internal.ExprImpl::Impl as ExprImpl

/**
* Holds if the operator `op` with arity `arity` is overloaded to a trait with
* the canonical path `path` and the method name `method`.
* the canonical path `path` and the method name `method`, and if it borrows its
* first `borrows` arguments.
*/
private predicate isOverloaded(string op, int arity, string path, string method) {
private predicate isOverloaded(string op, int arity, string path, string method, int borrows) {
arity = 1 and
(
// Negation
op = "-" and path = "core::ops::arith::Neg" and method = "neg"
op = "-" and path = "core::ops::arith::Neg" and method = "neg" and borrows = 0
or
// Not
op = "!" and path = "core::ops::bit::Not" and method = "not"
op = "!" and path = "core::ops::bit::Not" and method = "not" and borrows = 0
or
// Dereference
op = "*" and path = "core::ops::deref::Deref" and method = "deref"
op = "*" and path = "core::ops::deref::Deref" and method = "deref" and borrows = 0
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note, * borrows it's first argument. But setting borrows = 1 causes problems, so we are not ready for that yet.

)
or
arity = 2 and
(
// Comparison operators
op = "==" and path = "core::cmp::PartialEq" and method = "eq"
op = "==" and path = "core::cmp::PartialEq" and method = "eq" and borrows = 2
or
op = "!=" and path = "core::cmp::PartialEq" and method = "ne"
op = "!=" and path = "core::cmp::PartialEq" and method = "ne" and borrows = 2
or
op = "<" and path = "core::cmp::PartialOrd" and method = "lt"
op = "<" and path = "core::cmp::PartialOrd" and method = "lt" and borrows = 2
or
op = "<=" and path = "core::cmp::PartialOrd" and method = "le"
op = "<=" and path = "core::cmp::PartialOrd" and method = "le" and borrows = 2
or
op = ">" and path = "core::cmp::PartialOrd" and method = "gt"
op = ">" and path = "core::cmp::PartialOrd" and method = "gt" and borrows = 2
or
op = ">=" and path = "core::cmp::PartialOrd" and method = "ge"
op = ">=" and path = "core::cmp::PartialOrd" and method = "ge" and borrows = 2
or
// Arithmetic operators
op = "+" and path = "core::ops::arith::Add" and method = "add"
op = "+" and path = "core::ops::arith::Add" and method = "add" and borrows = 0
or
op = "-" and path = "core::ops::arith::Sub" and method = "sub"
op = "-" and path = "core::ops::arith::Sub" and method = "sub" and borrows = 0
or
op = "*" and path = "core::ops::arith::Mul" and method = "mul"
op = "*" and path = "core::ops::arith::Mul" and method = "mul" and borrows = 0
or
op = "/" and path = "core::ops::arith::Div" and method = "div"
op = "/" and path = "core::ops::arith::Div" and method = "div" and borrows = 0
or
op = "%" and path = "core::ops::arith::Rem" and method = "rem"
op = "%" and path = "core::ops::arith::Rem" and method = "rem" and borrows = 0
or
// Arithmetic assignment expressions
op = "+=" and path = "core::ops::arith::AddAssign" and method = "add_assign"
op = "+=" and path = "core::ops::arith::AddAssign" and method = "add_assign" and borrows = 1
or
op = "-=" and path = "core::ops::arith::SubAssign" and method = "sub_assign"
op = "-=" and path = "core::ops::arith::SubAssign" and method = "sub_assign" and borrows = 1
or
op = "*=" and path = "core::ops::arith::MulAssign" and method = "mul_assign"
op = "*=" and path = "core::ops::arith::MulAssign" and method = "mul_assign" and borrows = 1
or
op = "/=" and path = "core::ops::arith::DivAssign" and method = "div_assign"
op = "/=" and path = "core::ops::arith::DivAssign" and method = "div_assign" and borrows = 1
or
op = "%=" and path = "core::ops::arith::RemAssign" and method = "rem_assign"
op = "%=" and path = "core::ops::arith::RemAssign" and method = "rem_assign" and borrows = 1
or
// Bitwise operators
op = "&" and path = "core::ops::bit::BitAnd" and method = "bitand"
op = "&" and path = "core::ops::bit::BitAnd" and method = "bitand" and borrows = 0
or
op = "|" and path = "core::ops::bit::BitOr" and method = "bitor"
op = "|" and path = "core::ops::bit::BitOr" and method = "bitor" and borrows = 0
or
op = "^" and path = "core::ops::bit::BitXor" and method = "bitxor"
op = "^" and path = "core::ops::bit::BitXor" and method = "bitxor" and borrows = 0
or
op = "<<" and path = "core::ops::bit::Shl" and method = "shl"
op = "<<" and path = "core::ops::bit::Shl" and method = "shl" and borrows = 0
or
op = ">>" and path = "core::ops::bit::Shr" and method = "shr"
op = ">>" and path = "core::ops::bit::Shr" and method = "shr" and borrows = 0
or
// Bitwise assignment operators
op = "&=" and path = "core::ops::bit::BitAndAssign" and method = "bitand_assign"
op = "&=" and path = "core::ops::bit::BitAndAssign" and method = "bitand_assign" and borrows = 1
or
op = "|=" and path = "core::ops::bit::BitOrAssign" and method = "bitor_assign"
op = "|=" and path = "core::ops::bit::BitOrAssign" and method = "bitor_assign" and borrows = 1
or
op = "^=" and path = "core::ops::bit::BitXorAssign" and method = "bitxor_assign"
op = "^=" and path = "core::ops::bit::BitXorAssign" and method = "bitxor_assign" and borrows = 1
or
op = "<<=" and path = "core::ops::bit::ShlAssign" and method = "shl_assign"
op = "<<=" and path = "core::ops::bit::ShlAssign" and method = "shl_assign" and borrows = 1
or
op = ">>=" and path = "core::ops::bit::ShrAssign" and method = "shr_assign"
op = ">>=" and path = "core::ops::bit::ShrAssign" and method = "shr_assign" and borrows = 1
)
}

Expand Down Expand Up @@ -114,9 +115,9 @@ module Impl {
* Holds if this operation is overloaded to the method `methodName` of the
* trait `trait`.
*/
predicate isOverloaded(Trait trait, string methodName) {
predicate isOverloaded(Trait trait, string methodName, int borrows) {
isOverloaded(this.getOperatorName(), this.getNumberOfOperands(), trait.getCanonicalPath(),
methodName)
methodName, borrows)
}
}
}
Loading