Skip to content

[red-knot] support narrowing on constants in matches #16974

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions crates/red_knot_python_semantic/resources/mdtest/narrow/match.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,53 @@ match x:

reveal_type(x) # revealed: object
```

## Value patterns

```py
def get_object() -> object:
return object()

x = get_object()

reveal_type(x) # revealed: object

match x:
case "foo":
reveal_type(x) # revealed: Literal["foo"]
case 42:
reveal_type(x) # revealed: Literal[42]
case 6.0:
reveal_type(x) # revealed: float
case 1j:
reveal_type(x) # revealed: complex
case b"foo":
reveal_type(x) # revealed: Literal[b"foo"]

reveal_type(x) # revealed: object
```

## Value patterns with guard

```py
def get_object() -> object:
return object()

x = get_object()

reveal_type(x) # revealed: object

match x:
case "foo" if reveal_type(x): # revealed: Literal["foo"]
pass
case 42 if reveal_type(x): # revealed: Literal[42]
pass
case 6.0 if reveal_type(x): # revealed: float
pass
case 1j if reveal_type(x): # revealed: complex
pass
case b"foo" if reveal_type(x): # revealed: Literal[b"foo"]
pass

reveal_type(x) # revealed: object
```
95 changes: 45 additions & 50 deletions crates/red_knot_python_semantic/src/types/narrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,11 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
PatternPredicateKind::Class(cls, _guard) => {
self.evaluate_match_pattern_class(subject, *cls)
}
PatternPredicateKind::Value(expr, _guard) => {
self.evaluate_match_pattern_value(subject, *expr)
}
// TODO: support more pattern kinds
PatternPredicateKind::Value(..) | PatternPredicateKind::Unsupported => None,
PatternPredicateKind::Unsupported => None,
}
}

Expand All @@ -254,29 +257,29 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
}
}

#[track_caller]
fn expect_expr_name_symbol(&self, symbol: &str) -> ScopedSymbolId {
self.symbols()
.symbol_id_by_name(symbol)
.expect("We should always have a symbol for every `Name` node")
}

fn evaluate_expr_name(
&mut self,
expr_name: &ast::ExprName,
is_positive: bool,
) -> NarrowingConstraints<'db> {
let ast::ExprName { id, .. } = expr_name;

let symbol = self
.symbols()
.symbol_id_by_name(id)
.expect("Should always have a symbol for every Name node");
let mut constraints = NarrowingConstraints::default();
let symbol = self.expect_expr_name_symbol(id);

constraints.insert(
symbol,
if is_positive {
Type::AlwaysFalsy.negate(self.db)
} else {
Type::AlwaysTruthy.negate(self.db)
},
);
let ty = if is_positive {
Type::AlwaysFalsy.negate(self.db)
} else {
Type::AlwaysTruthy.negate(self.db)
};

constraints
NarrowingConstraints::from_iter([(symbol, ty)])
}

fn evaluate_expr_compare(
Expand Down Expand Up @@ -335,10 +338,7 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
id,
ctx: _,
}) => {
let symbol = self
.symbols()
.symbol_id_by_name(id)
.expect("Should always have a symbol for every Name node");
let symbol = self.expect_expr_name_symbol(id);

match if is_positive { *op } else { op.negate() } {
ast::CmpOp::IsNot => {
Expand Down Expand Up @@ -405,10 +405,7 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
.into_class_literal()
.is_some_and(|c| c.class().is_known(self.db, KnownClass::Type))
{
let symbol = self
.symbols()
.symbol_id_by_name(id)
.expect("Should always have a symbol for every Name node");
let symbol = self.expect_expr_name_symbol(id);
constraints.insert(symbol, Type::instance(rhs_class));
}
}
Expand Down Expand Up @@ -442,17 +439,18 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
return None;
};

let symbol = self.symbols().symbol_id_by_name(id).unwrap();
let symbol = self.expect_expr_name_symbol(id);

let class_info_ty =
inference.expression_type(class_info.scoped_expression_id(self.db, scope));

function
.generate_constraint(self.db, class_info_ty)
.map(|constraint| {
let mut constraints = NarrowingConstraints::default();
constraints.insert(symbol, constraint.negate_if(self.db, !is_positive));
constraints
NarrowingConstraints::from_iter([(
symbol,
constraint.negate_if(self.db, !is_positive),
)])
})
}
// for the expression `bool(E)`, we further narrow the type based on `E`
Expand All @@ -476,38 +474,35 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
subject: Expression<'db>,
singleton: ast::Singleton,
) -> Option<NarrowingConstraints<'db>> {
if let Some(ast::ExprName { id, .. }) = subject.node_ref(self.db).as_name_expr() {
// SAFETY: we should always have a symbol for every Name node.
let symbol = self.symbols().symbol_id_by_name(id).unwrap();

let ty = match singleton {
ast::Singleton::None => Type::none(self.db),
ast::Singleton::True => Type::BooleanLiteral(true),
ast::Singleton::False => Type::BooleanLiteral(false),
};
let mut constraints = NarrowingConstraints::default();
constraints.insert(symbol, ty);
Some(constraints)
} else {
None
}
let symbol = self.expect_expr_name_symbol(&subject.node_ref(self.db).as_name_expr()?.id);

let ty = match singleton {
ast::Singleton::None => Type::none(self.db),
ast::Singleton::True => Type::BooleanLiteral(true),
ast::Singleton::False => Type::BooleanLiteral(false),
};
Some(NarrowingConstraints::from_iter([(symbol, ty)]))
}

fn evaluate_match_pattern_class(
&mut self,
subject: Expression<'db>,
cls: Expression<'db>,
) -> Option<NarrowingConstraints<'db>> {
let ast::ExprName { id, .. } = subject.node_ref(self.db).as_name_expr()?;
let symbol = self
.symbols()
.symbol_id_by_name(id)
.expect("We should always have a symbol for every `Name` node");
let symbol = self.expect_expr_name_symbol(&subject.node_ref(self.db).as_name_expr()?.id);
let ty = infer_same_file_expression_type(self.db, cls).to_instance(self.db)?;

let mut constraints = NarrowingConstraints::default();
constraints.insert(symbol, ty);
Some(constraints)
Some(NarrowingConstraints::from_iter([(symbol, ty)]))
}

fn evaluate_match_pattern_value(
&mut self,
subject: Expression<'db>,
value: Expression<'db>,
) -> Option<NarrowingConstraints<'db>> {
let symbol = self.expect_expr_name_symbol(&subject.node_ref(self.db).as_name_expr()?.id);
let ty = infer_same_file_expression_type(self.db, value);
Some(NarrowingConstraints::from_iter([(symbol, ty)]))
}

fn evaluate_bool_op(
Expand Down
Loading