From de8f4e62e274dfaf674fc981045a9abbe15eab44 Mon Sep 17 00:00:00 2001 From: Eric Mark Martin Date: Thu, 17 Apr 2025 21:18:34 -0400 Subject: [PATCH] [red-knot] more type-narrowing in match statements (#17302) ## Summary Add more narrowing analysis for match statements: * add narrowing constraints from guard expressions * add negated constraints from previous predicates and guards to subsequent cases This PR doesn't address that guards can mutate your subject, and so theoretically invalidate some of these narrowing constraints that you've previously accumulated. Some prior art on this issue [here][mutable guards]. [mutable guards]: https://www.irif.fr/~scherer/research/mutable-patterns/mutable-patterns-mlworkshop2024-abstract.pdf ## Test Plan Add some new tests, and update some existing ones --------- Co-authored-by: Carl Meyer --- .../resources/mdtest/narrow/match.md | 55 +++++++++++-- .../src/semantic_index/builder.rs | 80 ++++++++++++------- .../src/types/narrow.rs | 33 +++++++- 3 files changed, 131 insertions(+), 37 deletions(-) diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/match.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/match.md index 27b01efe7b..8fd2f7cfdd 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/match.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/match.md @@ -39,8 +39,7 @@ match x: case A(): reveal_type(x) # revealed: A case B(): - # TODO could be `B & ~A` - reveal_type(x) # revealed: B + reveal_type(x) # revealed: B & ~A reveal_type(x) # revealed: object ``` @@ -88,7 +87,7 @@ match x: case 6.0: reveal_type(x) # revealed: float case 1j: - reveal_type(x) # revealed: complex + reveal_type(x) # revealed: complex & ~float case b"foo": reveal_type(x) # revealed: Literal[b"foo"] @@ -134,11 +133,11 @@ match x: case "foo" | 42 | None: reveal_type(x) # revealed: Literal["foo", 42] | None case "foo" | tuple(): - reveal_type(x) # revealed: Literal["foo"] | tuple + reveal_type(x) # revealed: tuple case True | False: reveal_type(x) # revealed: bool case 3.14 | 2.718 | 1.414: - reveal_type(x) # revealed: float + reveal_type(x) # revealed: float & ~tuple reveal_type(x) # revealed: object ``` @@ -165,3 +164,49 @@ match x: reveal_type(x) # revealed: object ``` + +## Narrowing due to guard + +```py +def get_object() -> object: + return object() + +x = get_object() + +reveal_type(x) # revealed: object + +match x: + case str() | float() if type(x) is str: + reveal_type(x) # revealed: str + case "foo" | 42 | None if isinstance(x, int): + reveal_type(x) # revealed: Literal[42] + case False if x: + reveal_type(x) # revealed: Never + case "foo" if x := "bar": + reveal_type(x) # revealed: Literal["bar"] + +reveal_type(x) # revealed: object +``` + +## Guard and reveal_type in guard + +```py +def get_object() -> object: + return object() + +x = get_object() + +reveal_type(x) # revealed: object + +match x: + case str() | float() if type(x) is str and reveal_type(x): # revealed: str + pass + case "foo" | 42 | None if isinstance(x, int) and reveal_type(x): # revealed: Literal[42] + pass + case False if x and reveal_type(x): # revealed: Never + pass + case "foo" if (x := "bar") and reveal_type(x): # revealed: Literal["bar"] + pass + +reveal_type(x) # revealed: object +``` diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index f2d2a30224..e4c25f4840 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -1572,54 +1572,76 @@ where return; } - let after_subject = self.flow_snapshot(); - let mut vis_constraints = vec![]; - let mut post_case_snapshots = vec![]; - for (i, case) in cases.iter().enumerate() { - if i != 0 { - post_case_snapshots.push(self.flow_snapshot()); - self.flow_restore(after_subject.clone()); - } + let mut no_case_matched = self.flow_snapshot(); + let has_catchall = cases + .last() + .is_some_and(|case| case.guard.is_none() && case.pattern.is_wildcard()); + + let mut post_case_snapshots = vec![]; + let mut match_predicate; + + for (i, case) in cases.iter().enumerate() { self.current_match_case = Some(CurrentMatchCase::new(&case.pattern)); self.visit_pattern(&case.pattern); self.current_match_case = None; - let predicate = self.add_pattern_narrowing_constraint( + // unlike in [Stmt::If], we don't reset [no_case_matched] + // here because the effects of visiting a pattern is binding + // symbols, and this doesn't occur unless the pattern + // actually matches + match_predicate = self.add_pattern_narrowing_constraint( subject_expr, &case.pattern, case.guard.as_deref(), ); - self.record_reachability_constraint(predicate); - if let Some(expr) = &case.guard { - self.visit_expr(expr); - } + let vis_constraint_id = self.record_reachability_constraint(match_predicate); + + let match_success_guard_failure = case.guard.as_ref().map(|guard| { + let guard_expr = self.add_standalone_expression(guard); + self.visit_expr(guard); + let post_guard_eval = self.flow_snapshot(); + let predicate = Predicate { + node: PredicateNode::Expression(guard_expr), + is_positive: true, + }; + self.record_negated_narrowing_constraint(predicate); + let match_success_guard_failure = self.flow_snapshot(); + self.flow_restore(post_guard_eval); + self.record_narrowing_constraint(predicate); + match_success_guard_failure + }); + + self.record_visibility_constraint_id(vis_constraint_id); + self.visit_body(&case.body); - for id in &vis_constraints { - self.record_negated_visibility_constraint(*id); - } - let vis_constraint_id = self.record_visibility_constraint(predicate); - vis_constraints.push(vis_constraint_id); - } - // If there is no final wildcard match case, pretend there is one. This is similar to how - // we add an implicit `else` block in if-elif chains, in case it's not present. - if !cases - .last() - .is_some_and(|case| case.guard.is_none() && case.pattern.is_wildcard()) - { post_case_snapshots.push(self.flow_snapshot()); - self.flow_restore(after_subject.clone()); - for id in &vis_constraints { - self.record_negated_visibility_constraint(*id); + if i != cases.len() - 1 || !has_catchall { + // We need to restore the state after each case, but not after the last + // one. The last one will just become the state that we merge the other + // snapshots into. + self.flow_restore(no_case_matched.clone()); + self.record_negated_narrowing_constraint(match_predicate); + if let Some(match_success_guard_failure) = match_success_guard_failure { + self.flow_merge(match_success_guard_failure); + } else { + assert!(case.guard.is_none()); + } + } else { + debug_assert!(match_success_guard_failure.is_none()); + debug_assert!(case.guard.is_none()); } + + self.record_negated_visibility_constraint(vis_constraint_id); + no_case_matched = self.flow_snapshot(); } for post_clause_state in post_case_snapshots { self.flow_merge(post_clause_state); } - self.simplify_visibility_constraints(after_subject); + self.simplify_visibility_constraints(no_case_matched); } ast::Stmt::Try(ast::StmtTry { body, diff --git a/crates/red_knot_python_semantic/src/types/narrow.rs b/crates/red_knot_python_semantic/src/types/narrow.rs index 04ca2ead84..50cfb9b931 100644 --- a/crates/red_knot_python_semantic/src/types/narrow.rs +++ b/crates/red_knot_python_semantic/src/types/narrow.rs @@ -50,7 +50,13 @@ pub(crate) fn infer_narrowing_constraint<'db>( all_negative_narrowing_constraints_for_expression(db, expression) } } - PredicateNode::Pattern(pattern) => all_narrowing_constraints_for_pattern(db, pattern), + PredicateNode::Pattern(pattern) => { + if predicate.is_positive { + all_narrowing_constraints_for_pattern(db, pattern) + } else { + all_negative_narrowing_constraints_for_pattern(db, pattern) + } + } PredicateNode::StarImportPlaceholder(_) => return None, }; if let Some(constraints) = constraints { @@ -95,6 +101,15 @@ fn all_negative_narrowing_constraints_for_expression<'db>( NarrowingConstraintsBuilder::new(db, PredicateNode::Expression(expression), false).finish() } +#[allow(clippy::ref_option)] +#[salsa::tracked(return_ref)] +fn all_negative_narrowing_constraints_for_pattern<'db>( + db: &'db dyn Db, + pattern: PatternPredicate<'db>, +) -> Option> { + NarrowingConstraintsBuilder::new(db, PredicateNode::Pattern(pattern), false).finish() +} + #[allow(clippy::ref_option)] fn constraints_for_expression_cycle_recover<'db>( _db: &'db dyn Db, @@ -217,6 +232,12 @@ fn merge_constraints_or<'db>( } } +fn negate_if<'db>(constraints: &mut NarrowingConstraints<'db>, db: &'db dyn Db, yes: bool) { + for (_symbol, ty) in constraints.iter_mut() { + *ty = ty.negate_if(db, yes); + } +} + struct NarrowingConstraintsBuilder<'db> { db: &'db dyn Db, predicate: PredicateNode<'db>, @@ -237,7 +258,9 @@ impl<'db> NarrowingConstraintsBuilder<'db> { PredicateNode::Expression(expression) => { self.evaluate_expression_predicate(expression, self.is_positive) } - PredicateNode::Pattern(pattern) => self.evaluate_pattern_predicate(pattern), + PredicateNode::Pattern(pattern) => { + self.evaluate_pattern_predicate(pattern, self.is_positive) + } PredicateNode::StarImportPlaceholder(_) => return None, }; if let Some(mut constraints) = constraints { @@ -301,10 +324,14 @@ impl<'db> NarrowingConstraintsBuilder<'db> { fn evaluate_pattern_predicate( &mut self, pattern: PatternPredicate<'db>, + is_positive: bool, ) -> Option> { let subject = pattern.subject(self.db); - self.evaluate_pattern_predicate_kind(pattern.kind(self.db), subject) + .map(|mut constraints| { + negate_if(&mut constraints, self.db, !is_positive); + constraints + }) } fn symbols(&self) -> Arc {