[red-knot] Fix control flow for `assert` statements (#17702)

## Summary

@sharkdp and I realised in our 1:1 this morning that our control flow
for `assert` statements isn't quite accurate at the moment. Namely, for
something like this:

```py
def _(x: int | None):
    assert x is None, reveal_type(x)
```

we currently reveal `None` for `x` here, but this is incorrect. The
`msg` expression of an `assert` statement (the expression after the
comma) is only evaluated if the test (`x is None`) evaluates to `False`.
As such, we should add a constraint of `~None` to `x` in the `msg`
expression, which simplifies the inferred type of `x` to `int` in that
context (`(int | None) & ~None` -> `int`).
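
The runtime behaviour that motivates this can be seen in a small sketch. The
helper names `checked_assert` and `msg` are made up for illustration and are
not part of red-knot; the sketch also assumes Python is not run with `-O`,
which strips `assert` statements entirely.

```python
def checked_assert(x):
    """Return (values seen by the message expression, error text or None)."""
    evaluated = []

    def msg():
        # Only reached when the `assert` test is falsy, so `x` cannot be
        # `None` here: this is the `~None` constraint described above.
        evaluated.append(x)
        return f"unexpected value: {x!r}"

    try:
        assert x is None, msg()
    except AssertionError as exc:
        return evaluated, str(exc)
    # The test passed, so the message expression was never evaluated.
    return evaluated, None
```

Calling `checked_assert(None)` leaves `evaluated` empty, while
`checked_assert(3)` records `3` inside the message expression before the
`AssertionError` is raised.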

## Test Plan

Mdtests added.

---------

Co-authored-by: David Peter <mail@david-peter.de>
Alex Waygood 2025-04-30 08:57:49 +01:00 committed by GitHub
parent 4a621c2c12
commit 8a6787b39e
4 changed files with 118 additions and 10 deletions


@@ -51,3 +51,64 @@ def _(x: Literal[1, 2, 3], y: Literal[1, 2, 3]):
    assert y not in (1, 2)
    reveal_type(y)  # revealed: Literal[3]
```
## Assertions with messages
```py
def _(x: int | None, y: int | None):
    reveal_type(x)  # revealed: int | None
    assert x is None, reveal_type(x)  # revealed: int
    reveal_type(x)  # revealed: None

    reveal_type(y)  # revealed: int | None
    assert isinstance(y, int), reveal_type(y)  # revealed: None
    reveal_type(y)  # revealed: int
```
## Assertions with definitions inside the message
```py
def one(x: int | None):
    assert x is None, (y := x * 42) * reveal_type(y)  # revealed: int

    # error: [unresolved-reference]
    reveal_type(y)  # revealed: Unknown

def two(x: int | None, y: int | None):
    assert x is None, (y := 42) * reveal_type(y)  # revealed: Literal[42]
    reveal_type(y)  # revealed: int | None
```
## Assertions with `test` predicates that are statically known to always be `True`
```py
assert True, (x := 1)
# error: [unresolved-reference]
reveal_type(x) # revealed: Unknown
assert False, (y := 1)
# The `assert` statement is terminal if `test` resolves to `False`,
# so even though we know the `msg` branch will have been taken here
# (we know what the truthiness of `False` is!), we also know that the
# `y` definition is not visible from this point in control flow
# (because this point in control flow is unreachable).
# We make sure that this does not emit an `[unresolved-reference]`
# diagnostic by adding a reachability constraint,
# but the inferred type is `Unknown`.
#
reveal_type(y) # revealed: Unknown
```
## Assertions with messages that reference definitions from the `test`
```py
def one(x: int | None):
    assert (y := x), reveal_type(y)  # revealed: (int & ~AlwaysTruthy) | None
    reveal_type(y)  # revealed: int & ~AlwaysFalsy

def two(x: int | None):
    assert isinstance((y := x), int), reveal_type(y)  # revealed: None
    reveal_type(y)  # revealed: int
```


@@ -362,6 +362,15 @@ def f():
    ExceptionGroup
```
Similarly, assertions with statically-known falsy conditions can lead to unreachable code:
```py
def f():
    assert sys.version_info > (3, 11)

    ExceptionGroup
```
Finally, not that anyone would ever use it, but it also works for `while` loops:
```py


@@ -532,11 +532,8 @@ impl<'db> SemanticIndexBuilder<'db> {
     /// Negates a predicate and adds it to the list of all predicates, does not record it.
     fn add_negated_predicate(&mut self, predicate: Predicate<'db>) -> ScopedPredicateId {
-        let negated = Predicate {
-            node: predicate.node,
-            is_positive: false,
-        };
-        self.current_use_def_map_mut().add_predicate(negated)
+        self.current_use_def_map_mut()
+            .add_predicate(predicate.negated())
     }
     /// Records a previously added narrowing constraint by adding it to all live bindings.
@@ -1383,14 +1380,46 @@ where
                 }
             }
-            ast::Stmt::Assert(node) => {
-                self.visit_expr(&node.test);
-                let predicate = self.record_expression_narrowing_constraint(&node.test);
-                self.record_visibility_constraint(predicate);
+            ast::Stmt::Assert(ast::StmtAssert {
+                test,
+                msg,
+                range: _,
+            }) => {
+                // We model an `assert test, msg` statement here. Conceptually, we can think of
+                // this as being equivalent to the following:
+                //
+                // ```py
+                // if not test:
+                //     msg
+                //     <halt>
+                //
+                // <whatever code comes after>
+                // ```
+                //
+                // Importantly, the `msg` expression is only evaluated if the `test` expression is
+                // falsy. This is why we apply the negated `test` predicate as a narrowing and
+                // reachability constraint on the `msg` expression.
+                //
+                // The other important part is the `<halt>`. This lets us skip the usual merging of
+                // flow states and simplification of visibility constraints, since there is no way
+                // of getting out of that `msg` branch. We simply restore to the post-test state.
-                if let Some(msg) = &node.msg {
+                self.visit_expr(test);
+                let predicate = self.build_predicate(test);
+                if let Some(msg) = msg {
+                    let post_test = self.flow_snapshot();
+                    let negated_predicate = predicate.negated();
+                    self.record_narrowing_constraint(negated_predicate);
+                    self.record_reachability_constraint(negated_predicate);
                     self.visit_expr(msg);
+                    self.record_visibility_constraint(negated_predicate);
+                    self.flow_restore(post_test);
                 }
+                self.record_narrowing_constraint(predicate);
+                self.record_visibility_constraint(predicate);
+                self.record_reachability_constraint(predicate);
             }
             ast::Stmt::Assign(node) => {


@@ -49,6 +49,15 @@ pub(crate) struct Predicate<'db> {
    pub(crate) is_positive: bool,
}
impl Predicate<'_> {
    pub(crate) fn negated(self) -> Self {
        Self {
            node: self.node,
            is_positive: !self.is_positive,
        }
    }
}
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, salsa::Update)]
pub(crate) enum PredicateNode<'db> {
    Expression(Expression<'db>),