diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/match.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/match.md index 27b01efe7b..8fd2f7cfdd 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/match.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/match.md @@ -39,8 +39,7 @@ match x: case A(): reveal_type(x) # revealed: A case B(): - # TODO could be `B & ~A` - reveal_type(x) # revealed: B + reveal_type(x) # revealed: B & ~A reveal_type(x) # revealed: object ``` @@ -88,7 +87,7 @@ match x: case 6.0: reveal_type(x) # revealed: float case 1j: - reveal_type(x) # revealed: complex + reveal_type(x) # revealed: complex & ~float case b"foo": reveal_type(x) # revealed: Literal[b"foo"] @@ -134,11 +133,11 @@ match x: case "foo" | 42 | None: reveal_type(x) # revealed: Literal["foo", 42] | None case "foo" | tuple(): - reveal_type(x) # revealed: Literal["foo"] | tuple + reveal_type(x) # revealed: tuple case True | False: reveal_type(x) # revealed: bool case 3.14 | 2.718 | 1.414: - reveal_type(x) # revealed: float + reveal_type(x) # revealed: float & ~tuple reveal_type(x) # revealed: object ``` @@ -165,3 +164,49 @@ match x: reveal_type(x) # revealed: object ``` + +## Narrowing due to guard + +```py +def get_object() -> object: + return object() + +x = get_object() + +reveal_type(x) # revealed: object + +match x: + case str() | float() if type(x) is str: + reveal_type(x) # revealed: str + case "foo" | 42 | None if isinstance(x, int): + reveal_type(x) # revealed: Literal[42] + case False if x: + reveal_type(x) # revealed: Never + case "foo" if x := "bar": + reveal_type(x) # revealed: Literal["bar"] + +reveal_type(x) # revealed: object +``` + +## Guard and reveal_type in guard + +```py +def get_object() -> object: + return object() + +x = get_object() + +reveal_type(x) # revealed: object + +match x: + case str() | float() if type(x) is str and reveal_type(x): # revealed: str + pass + case "foo" | 42 | None if isinstance(x, int) and reveal_type(x): # revealed: Literal[42] + pass + case False if x and reveal_type(x): # revealed: Never + pass + case "foo" if (x := "bar") and reveal_type(x): # revealed: Literal["bar"] + pass + +reveal_type(x) # revealed: object +``` diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index f2d2a30224..e4c25f4840 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -1572,54 +1572,76 @@ where return; } - let after_subject = self.flow_snapshot(); - let mut vis_constraints = vec![]; - let mut post_case_snapshots = vec![]; - for (i, case) in cases.iter().enumerate() { - if i != 0 { - post_case_snapshots.push(self.flow_snapshot()); - self.flow_restore(after_subject.clone()); - } + let mut no_case_matched = self.flow_snapshot(); + let has_catchall = cases + .last() + .is_some_and(|case| case.guard.is_none() && case.pattern.is_wildcard()); + + let mut post_case_snapshots = vec![]; + let mut match_predicate; + + for (i, case) in cases.iter().enumerate() { self.current_match_case = Some(CurrentMatchCase::new(&case.pattern)); self.visit_pattern(&case.pattern); self.current_match_case = None; - let predicate = self.add_pattern_narrowing_constraint( + // unlike in [Stmt::If], we don't reset [no_case_matched] + // here because the effects of visiting a pattern is binding + // symbols, and this doesn't occur unless the pattern + // actually matches + match_predicate = self.add_pattern_narrowing_constraint( subject_expr, &case.pattern, case.guard.as_deref(), ); - self.record_reachability_constraint(predicate); - if let Some(expr) = &case.guard { - self.visit_expr(expr); - } + let vis_constraint_id = self.record_reachability_constraint(match_predicate); + + let match_success_guard_failure = case.guard.as_ref().map(|guard| { + let guard_expr = self.add_standalone_expression(guard); + self.visit_expr(guard); + let post_guard_eval = self.flow_snapshot(); + let predicate = Predicate { + node: PredicateNode::Expression(guard_expr), + is_positive: true, + }; + self.record_negated_narrowing_constraint(predicate); + let match_success_guard_failure = self.flow_snapshot(); + self.flow_restore(post_guard_eval); + self.record_narrowing_constraint(predicate); + match_success_guard_failure + }); + + self.record_visibility_constraint_id(vis_constraint_id); + self.visit_body(&case.body); - for id in &vis_constraints { - self.record_negated_visibility_constraint(*id); - } - let vis_constraint_id = self.record_visibility_constraint(predicate); - vis_constraints.push(vis_constraint_id); - } - // If there is no final wildcard match case, pretend there is one. This is similar to how - // we add an implicit `else` block in if-elif chains, in case it's not present. - if !cases - .last() - .is_some_and(|case| case.guard.is_none() && case.pattern.is_wildcard()) - { post_case_snapshots.push(self.flow_snapshot()); - self.flow_restore(after_subject.clone()); - for id in &vis_constraints { - self.record_negated_visibility_constraint(*id); + if i != cases.len() - 1 || !has_catchall { + // We need to restore the state after each case, but not after the last + // one. The last one will just become the state that we merge the other + // snapshots into. + self.flow_restore(no_case_matched.clone()); + self.record_negated_narrowing_constraint(match_predicate); + if let Some(match_success_guard_failure) = match_success_guard_failure { + self.flow_merge(match_success_guard_failure); + } else { + assert!(case.guard.is_none()); + } + } else { + debug_assert!(match_success_guard_failure.is_none()); + debug_assert!(case.guard.is_none()); } + + self.record_negated_visibility_constraint(vis_constraint_id); + no_case_matched = self.flow_snapshot(); } for post_clause_state in post_case_snapshots { self.flow_merge(post_clause_state); } - self.simplify_visibility_constraints(after_subject); + self.simplify_visibility_constraints(no_case_matched); } ast::Stmt::Try(ast::StmtTry { body, diff --git a/crates/red_knot_python_semantic/src/types/narrow.rs b/crates/red_knot_python_semantic/src/types/narrow.rs index 04ca2ead84..50cfb9b931 100644 --- a/crates/red_knot_python_semantic/src/types/narrow.rs +++ b/crates/red_knot_python_semantic/src/types/narrow.rs @@ -50,7 +50,13 @@ pub(crate) fn infer_narrowing_constraint<'db>( all_negative_narrowing_constraints_for_expression(db, expression) } } - PredicateNode::Pattern(pattern) => all_narrowing_constraints_for_pattern(db, pattern), + PredicateNode::Pattern(pattern) => { + if predicate.is_positive { + all_narrowing_constraints_for_pattern(db, pattern) + } else { + all_negative_narrowing_constraints_for_pattern(db, pattern) + } + } PredicateNode::StarImportPlaceholder(_) => return None, }; if let Some(constraints) = constraints { @@ -95,6 +101,15 @@ fn all_negative_narrowing_constraints_for_expression<'db>( NarrowingConstraintsBuilder::new(db, PredicateNode::Expression(expression), false).finish() } +#[allow(clippy::ref_option)] +#[salsa::tracked(return_ref)] +fn all_negative_narrowing_constraints_for_pattern<'db>( + db: &'db dyn Db, + pattern: PatternPredicate<'db>, +) -> Option> { + NarrowingConstraintsBuilder::new(db, PredicateNode::Pattern(pattern), false).finish() +} + #[allow(clippy::ref_option)] fn constraints_for_expression_cycle_recover<'db>( _db: &'db dyn Db, @@ -217,6 +232,12 @@ fn merge_constraints_or<'db>( } } +fn negate_if<'db>(constraints: &mut NarrowingConstraints<'db>, db: &'db dyn Db, yes: bool) { + for (_symbol, ty) in constraints.iter_mut() { + *ty = ty.negate_if(db, yes); + } +} + struct NarrowingConstraintsBuilder<'db> { db: &'db dyn Db, predicate: PredicateNode<'db>, @@ -237,7 +258,9 @@ impl<'db> NarrowingConstraintsBuilder<'db> { PredicateNode::Expression(expression) => { self.evaluate_expression_predicate(expression, self.is_positive) } - PredicateNode::Pattern(pattern) => self.evaluate_pattern_predicate(pattern), + PredicateNode::Pattern(pattern) => { + self.evaluate_pattern_predicate(pattern, self.is_positive) + } PredicateNode::StarImportPlaceholder(_) => return None, }; if let Some(mut constraints) = constraints { @@ -301,10 +324,14 @@ impl<'db> NarrowingConstraintsBuilder<'db> { fn evaluate_pattern_predicate( &mut self, pattern: PatternPredicate<'db>, + is_positive: bool, ) -> Option> { let subject = pattern.subject(self.db); - self.evaluate_pattern_predicate_kind(pattern.kind(self.db), subject) + .map(|mut constraints| { + negate_if(&mut constraints, self.db, !is_positive); + constraints + }) } fn symbols(&self) -> Arc {