From 8a6787b39e178ff829d58d775b4d65d83b061c53 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Wed, 30 Apr 2025 08:57:49 +0100 Subject: [PATCH] [red-knot] Fix control flow for `assert` statements (#17702) ## Summary @sharkdp and I realised in our 1:1 this morning that our control flow for `assert` statements isn't quite accurate at the moment. Namely, for something like this: ```py def _(x: int | None): assert x is None, reveal_type(x) ``` we currently reveal `None` for `x` here, but this is incorrect. In actual fact, the `msg` expression of an `assert` statement (the expression after the comma) will only be evaluated if the test (`x is None`) evaluates to `False`. As such, we should be adding a constraint of `~None` to `x` in the `msg` expression, which should simplify the inferred type of `x` to `int` in that context (`(int | None) & ~None` -> `int`). ## Test Plan Mdtests added. --------- Co-authored-by: David Peter --- .../resources/mdtest/narrow/assert.md | 61 +++++++++++++++++++ .../resources/mdtest/unreachable.md | 9 +++ .../src/semantic_index/builder.rs | 49 ++++++++++++--- .../src/semantic_index/predicate.rs | 9 +++ 4 files changed, 118 insertions(+), 10 deletions(-) diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/assert.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/assert.md index c452e3c71d..0fab83880e 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/assert.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/assert.md @@ -51,3 +51,64 @@ def _(x: Literal[1, 2, 3], y: Literal[1, 2, 3]): assert y not in (1, 2) reveal_type(y) # revealed: Literal[3] ``` + +## Assertions with messages + +```py +def _(x: int | None, y: int | None): + reveal_type(x) # revealed: int | None + assert x is None, reveal_type(x) # revealed: int + reveal_type(x) # revealed: None + + reveal_type(y) # revealed: int | None + assert isinstance(y, int), reveal_type(y) # revealed: None + reveal_type(y) # revealed: int +``` + +## Assertions with definitions inside the message + +```py +def one(x: int | None): + assert x is None, (y := x * 42) * reveal_type(y) # revealed: int + + # error: [unresolved-reference] + reveal_type(y) # revealed: Unknown + +def two(x: int | None, y: int | None): + assert x is None, (y := 42) * reveal_type(y) # revealed: Literal[42] + reveal_type(y) # revealed: int | None +``` + +## Assertions with `test` predicates that are statically known to always be `True` + +```py +assert True, (x := 1) + +# error: [unresolved-reference] +reveal_type(x) # revealed: Unknown + +assert False, (y := 1) + +# The `assert` statement is terminal if `test` resolves to `False`, +# so even though we know the `msg` branch will have been taken here +# (we know what the truthiness of `False is!), we also know that the +# `y` definition is not visible from this point in control flow +# (because this point in control flow is unreachable). +# We make sure that this does not emit an `[unresolved-reference]` +# diagnostic by adding a reachability constraint, +# but the inferred type is `Unknown`. +# +reveal_type(y) # revealed: Unknown +``` + +## Assertions with messages that reference definitions from the `test` + +```py +def one(x: int | None): + assert (y := x), reveal_type(y) # revealed: (int & ~AlwaysTruthy) | None + reveal_type(y) # revealed: int & ~AlwaysFalsy + +def two(x: int | None): + assert isinstance((y := x), int), reveal_type(y) # revealed: None + reveal_type(y) # revealed: int +``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/unreachable.md b/crates/red_knot_python_semantic/resources/mdtest/unreachable.md index e01ecdd314..0348816b49 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/unreachable.md +++ b/crates/red_knot_python_semantic/resources/mdtest/unreachable.md @@ -362,6 +362,15 @@ def f(): ExceptionGroup ``` +Similarly, assertions with statically-known falsy conditions can lead to unreachable code: + +```py +def f(): + assert sys.version_info > (3, 11) + + ExceptionGroup +``` + Finally, not that anyone would ever use it, but it also works for `while` loops: ```py diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index 17aff301a5..a633c2c5cf 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -532,11 +532,8 @@ impl<'db> SemanticIndexBuilder<'db> { /// Negates a predicate and adds it to the list of all predicates, does not record it. fn add_negated_predicate(&mut self, predicate: Predicate<'db>) -> ScopedPredicateId { - let negated = Predicate { - node: predicate.node, - is_positive: false, - }; - self.current_use_def_map_mut().add_predicate(negated) + self.current_use_def_map_mut() + .add_predicate(predicate.negated()) } /// Records a previously added narrowing constraint by adding it to all live bindings. @@ -1383,14 +1380,46 @@ where } } - ast::Stmt::Assert(node) => { - self.visit_expr(&node.test); - let predicate = self.record_expression_narrowing_constraint(&node.test); - self.record_visibility_constraint(predicate); + ast::Stmt::Assert(ast::StmtAssert { + test, + msg, + range: _, + }) => { + // We model an `assert test, msg` statement here. Conceptually, we can think of + // this as being equivalent to the following: + // + // ```py + // if not test: + // msg + // + // + // + // ``` + // + // Importantly, the `msg` expression is only evaluated if the `test` expression is + // falsy. This is why we apply the negated `test` predicate as a narrowing and + // reachability constraint on the `msg` expression. + // + // The other important part is the ``. This lets us skip the usual merging of + // flow states and simplification of visibility constraints, since there is no way + // of getting out of that `msg` branch. We simply restore to the post-test state. - if let Some(msg) = &node.msg { + self.visit_expr(test); + let predicate = self.build_predicate(test); + + if let Some(msg) = msg { + let post_test = self.flow_snapshot(); + let negated_predicate = predicate.negated(); + self.record_narrowing_constraint(negated_predicate); + self.record_reachability_constraint(negated_predicate); self.visit_expr(msg); + self.record_visibility_constraint(negated_predicate); + self.flow_restore(post_test); } + + self.record_narrowing_constraint(predicate); + self.record_visibility_constraint(predicate); + self.record_reachability_constraint(predicate); } ast::Stmt::Assign(node) => { diff --git a/crates/red_knot_python_semantic/src/semantic_index/predicate.rs b/crates/red_knot_python_semantic/src/semantic_index/predicate.rs index c2885022e8..1639aeaf5a 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/predicate.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/predicate.rs @@ -49,6 +49,15 @@ pub(crate) struct Predicate<'db> { pub(crate) is_positive: bool, } +impl Predicate<'_> { + pub(crate) fn negated(self) -> Self { + Self { + node: self.node, + is_positive: !self.is_positive, + } + } +} + #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, salsa::Update)] pub(crate) enum PredicateNode<'db> { Expression(Expression<'db>),