mirror of https://github.com/astral-sh/ruff
[red-knot] Fix control flow for `assert` statements (#17702)
## Summary
@sharkdp and I realised in our 1:1 this morning that our control flow
for `assert` statements isn't quite accurate at the moment. Namely, for
something like this:
```py
def _(x: int | None):
assert x is None, reveal_type(x)
```
we currently reveal `None` for `x` here, but this is incorrect. In
actual fact, the `msg` expression of an `assert` statement (the
expression after the comma) will only be evaluated if the test (`x is
None`) evaluates to `False`. As such, we should be adding a constraint
of `~None` to `x` in the `msg` expression, which should simplify the
inferred type of `x` to `int` in that context (`(int | None) & ~None` ->
`int`).
## Test Plan
Mdtests added.
---------
Co-authored-by: David Peter <mail@david-peter.de>
This commit is contained in:
parent
4a621c2c12
commit
8a6787b39e
|
|
@ -51,3 +51,64 @@ def _(x: Literal[1, 2, 3], y: Literal[1, 2, 3]):
|
|||
assert y not in (1, 2)
|
||||
reveal_type(y) # revealed: Literal[3]
|
||||
```
|
||||
|
||||
## Assertions with messages
|
||||
|
||||
```py
|
||||
def _(x: int | None, y: int | None):
|
||||
reveal_type(x) # revealed: int | None
|
||||
assert x is None, reveal_type(x) # revealed: int
|
||||
reveal_type(x) # revealed: None
|
||||
|
||||
reveal_type(y) # revealed: int | None
|
||||
assert isinstance(y, int), reveal_type(y) # revealed: None
|
||||
reveal_type(y) # revealed: int
|
||||
```
|
||||
|
||||
## Assertions with definitions inside the message
|
||||
|
||||
```py
|
||||
def one(x: int | None):
|
||||
assert x is None, (y := x * 42) * reveal_type(y) # revealed: int
|
||||
|
||||
# error: [unresolved-reference]
|
||||
reveal_type(y) # revealed: Unknown
|
||||
|
||||
def two(x: int | None, y: int | None):
|
||||
assert x is None, (y := 42) * reveal_type(y) # revealed: Literal[42]
|
||||
reveal_type(y) # revealed: int | None
|
||||
```
|
||||
|
||||
## Assertions with `test` predicates that are statically known to always be `True`
|
||||
|
||||
```py
|
||||
assert True, (x := 1)
|
||||
|
||||
# error: [unresolved-reference]
|
||||
reveal_type(x) # revealed: Unknown
|
||||
|
||||
assert False, (y := 1)
|
||||
|
||||
# The `assert` statement is terminal if `test` resolves to `False`,
|
||||
# so even though we know the `msg` branch will have been taken here
|
||||
# (we know what the truthiness of `False is!), we also know that the
|
||||
# `y` definition is not visible from this point in control flow
|
||||
# (because this point in control flow is unreachable).
|
||||
# We make sure that this does not emit an `[unresolved-reference]`
|
||||
# diagnostic by adding a reachability constraint,
|
||||
# but the inferred type is `Unknown`.
|
||||
#
|
||||
reveal_type(y) # revealed: Unknown
|
||||
```
|
||||
|
||||
## Assertions with messages that reference definitions from the `test`
|
||||
|
||||
```py
|
||||
def one(x: int | None):
|
||||
assert (y := x), reveal_type(y) # revealed: (int & ~AlwaysTruthy) | None
|
||||
reveal_type(y) # revealed: int & ~AlwaysFalsy
|
||||
|
||||
def two(x: int | None):
|
||||
assert isinstance((y := x), int), reveal_type(y) # revealed: None
|
||||
reveal_type(y) # revealed: int
|
||||
```
|
||||
|
|
|
|||
|
|
@ -362,6 +362,15 @@ def f():
|
|||
ExceptionGroup
|
||||
```
|
||||
|
||||
Similarly, assertions with statically-known falsy conditions can lead to unreachable code:
|
||||
|
||||
```py
|
||||
def f():
|
||||
assert sys.version_info > (3, 11)
|
||||
|
||||
ExceptionGroup
|
||||
```
|
||||
|
||||
Finally, not that anyone would ever use it, but it also works for `while` loops:
|
||||
|
||||
```py
|
||||
|
|
|
|||
|
|
@ -532,11 +532,8 @@ impl<'db> SemanticIndexBuilder<'db> {
|
|||
|
||||
/// Negates a predicate and adds it to the list of all predicates, does not record it.
|
||||
fn add_negated_predicate(&mut self, predicate: Predicate<'db>) -> ScopedPredicateId {
|
||||
let negated = Predicate {
|
||||
node: predicate.node,
|
||||
is_positive: false,
|
||||
};
|
||||
self.current_use_def_map_mut().add_predicate(negated)
|
||||
self.current_use_def_map_mut()
|
||||
.add_predicate(predicate.negated())
|
||||
}
|
||||
|
||||
/// Records a previously added narrowing constraint by adding it to all live bindings.
|
||||
|
|
@ -1383,14 +1380,46 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
ast::Stmt::Assert(node) => {
|
||||
self.visit_expr(&node.test);
|
||||
let predicate = self.record_expression_narrowing_constraint(&node.test);
|
||||
self.record_visibility_constraint(predicate);
|
||||
ast::Stmt::Assert(ast::StmtAssert {
|
||||
test,
|
||||
msg,
|
||||
range: _,
|
||||
}) => {
|
||||
// We model an `assert test, msg` statement here. Conceptually, we can think of
|
||||
// this as being equivalent to the following:
|
||||
//
|
||||
// ```py
|
||||
// if not test:
|
||||
// msg
|
||||
// <halt>
|
||||
//
|
||||
// <whatever code comes after>
|
||||
// ```
|
||||
//
|
||||
// Importantly, the `msg` expression is only evaluated if the `test` expression is
|
||||
// falsy. This is why we apply the negated `test` predicate as a narrowing and
|
||||
// reachability constraint on the `msg` expression.
|
||||
//
|
||||
// The other important part is the `<halt>`. This lets us skip the usual merging of
|
||||
// flow states and simplification of visibility constraints, since there is no way
|
||||
// of getting out of that `msg` branch. We simply restore to the post-test state.
|
||||
|
||||
if let Some(msg) = &node.msg {
|
||||
self.visit_expr(test);
|
||||
let predicate = self.build_predicate(test);
|
||||
|
||||
if let Some(msg) = msg {
|
||||
let post_test = self.flow_snapshot();
|
||||
let negated_predicate = predicate.negated();
|
||||
self.record_narrowing_constraint(negated_predicate);
|
||||
self.record_reachability_constraint(negated_predicate);
|
||||
self.visit_expr(msg);
|
||||
self.record_visibility_constraint(negated_predicate);
|
||||
self.flow_restore(post_test);
|
||||
}
|
||||
|
||||
self.record_narrowing_constraint(predicate);
|
||||
self.record_visibility_constraint(predicate);
|
||||
self.record_reachability_constraint(predicate);
|
||||
}
|
||||
|
||||
ast::Stmt::Assign(node) => {
|
||||
|
|
|
|||
|
|
@ -49,6 +49,15 @@ pub(crate) struct Predicate<'db> {
|
|||
pub(crate) is_positive: bool,
|
||||
}
|
||||
|
||||
impl Predicate<'_> {
|
||||
pub(crate) fn negated(self) -> Self {
|
||||
Self {
|
||||
node: self.node,
|
||||
is_positive: !self.is_positive,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, salsa::Update)]
|
||||
pub(crate) enum PredicateNode<'db> {
|
||||
Expression(Expression<'db>),
|
||||
|
|
|
|||
Loading…
Reference in New Issue