diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/while.md b/crates/ty_python_semantic/resources/mdtest/narrow/while.md index af8141b646..deae318666 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/while.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/while.md @@ -59,3 +59,17 @@ while x != 1: x = next_item() ``` + +## With `break` statements + +```py +def next_item() -> int | None: + return 1 + +while True: + x = next_item() + if x is not None: + break + +reveal_type(x) # revealed: int +``` diff --git a/crates/ty_python_semantic/src/semantic_index/builder.rs b/crates/ty_python_semantic/src/semantic_index/builder.rs index 6a8e6d2a58..4e678ac8ef 100644 --- a/crates/ty_python_semantic/src/semantic_index/builder.rs +++ b/crates/ty_python_semantic/src/semantic_index/builder.rs @@ -35,8 +35,8 @@ use crate::semantic_index::place::{ PlaceExprWithFlags, PlaceTableBuilder, Scope, ScopeId, ScopeKind, ScopedPlaceId, }; use crate::semantic_index::predicate::{ - PatternPredicate, PatternPredicateKind, Predicate, PredicateNode, ScopedPredicateId, - StarImportPlaceholderPredicate, + PatternPredicate, PatternPredicateKind, Predicate, PredicateNode, PredicateOrLiteral, + ScopedPredicateId, StarImportPlaceholderPredicate, }; use crate::semantic_index::re_exports::exported_names; use crate::semantic_index::reachability_constraints::{ @@ -535,29 +535,34 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> { fn record_expression_narrowing_constraint( &mut self, precide_node: &ast::Expr, - ) -> Predicate<'db> { + ) -> PredicateOrLiteral<'db> { let predicate = self.build_predicate(precide_node); self.record_narrowing_constraint(predicate); predicate } - fn build_predicate(&mut self, predicate_node: &ast::Expr) -> Predicate<'db> { + fn build_predicate(&mut self, predicate_node: &ast::Expr) -> PredicateOrLiteral<'db> { let expression = self.add_standalone_expression(predicate_node); - Predicate { - node: PredicateNode::Expression(expression), - is_positive: true, + + if let Some(boolean_literal) = predicate_node.as_boolean_literal_expr() { + PredicateOrLiteral::Literal(boolean_literal.value) + } else { + PredicateOrLiteral::Predicate(Predicate { + node: PredicateNode::Expression(expression), + is_positive: true, + }) } } /// Adds a new predicate to the list of all predicates, but does not record it. Returns the /// predicate ID for later recording using /// [`SemanticIndexBuilder::record_narrowing_constraint_id`]. - fn add_predicate(&mut self, predicate: Predicate<'db>) -> ScopedPredicateId { + fn add_predicate(&mut self, predicate: PredicateOrLiteral<'db>) -> ScopedPredicateId { self.current_use_def_map_mut().add_predicate(predicate) } /// Negates a predicate and adds it to the list of all predicates, does not record it. - fn add_negated_predicate(&mut self, predicate: Predicate<'db>) -> ScopedPredicateId { + fn add_negated_predicate(&mut self, predicate: PredicateOrLiteral<'db>) -> ScopedPredicateId { self.current_use_def_map_mut() .add_predicate(predicate.negated()) } @@ -569,7 +574,7 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> { } /// Adds and records a narrowing constraint, i.e. adds it to all live bindings. - fn record_narrowing_constraint(&mut self, predicate: Predicate<'db>) { + fn record_narrowing_constraint(&mut self, predicate: PredicateOrLiteral<'db>) { let use_def = self.current_use_def_map_mut(); let predicate_id = use_def.add_predicate(predicate); use_def.record_narrowing_constraint(predicate_id); @@ -579,7 +584,7 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> { /// bindings. fn record_negated_narrowing_constraint( &mut self, - predicate: Predicate<'db>, + predicate: PredicateOrLiteral<'db>, ) -> ScopedPredicateId { let id = self.add_negated_predicate(predicate); self.record_narrowing_constraint_id(id); @@ -603,7 +608,7 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> { /// we know that all statements that follow in this path of control flow will be unreachable. fn record_reachability_constraint( &mut self, - predicate: Predicate<'db>, + predicate: PredicateOrLiteral<'db>, ) -> ScopedReachabilityConstraintId { let predicate_id = self.add_predicate(predicate); self.record_reachability_constraint_id(predicate_id) @@ -617,6 +622,7 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> { let reachability_constraint = self .current_reachability_constraints_mut() .add_atom(predicate_id); + self.current_use_def_map_mut() .record_reachability_constraint(reachability_constraint); reachability_constraint @@ -681,7 +687,7 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> { subject: Expression<'db>, pattern: &ast::Pattern, guard: Option<&ast::Expr>, - ) -> Predicate<'db> { + ) -> PredicateOrLiteral<'db> { // This is called for the top-level pattern of each match arm. We need to create a // standalone expression for each arm of a match statement, since they can introduce // constraints on the match subject. (Or more accurately, for the match arm's pattern, @@ -705,10 +711,10 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> { guard, countme::Count::default(), ); - let predicate = Predicate { + let predicate = PredicateOrLiteral::Predicate(Predicate { node: PredicateNode::Pattern(pattern_predicate), is_positive: true, - }; + }); self.record_narrowing_constraint(predicate); predicate } @@ -1653,10 +1659,10 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { self.record_ambiguous_reachability(); self.visit_expr(guard); let post_guard_eval = self.flow_snapshot(); - let predicate = Predicate { + let predicate = PredicateOrLiteral::Predicate(Predicate { node: PredicateNode::Expression(guard_expr), is_positive: true, - }; + }); self.record_negated_narrowing_constraint(predicate); let match_success_guard_failure = self.flow_snapshot(); self.flow_restore(post_guard_eval); diff --git a/crates/ty_python_semantic/src/semantic_index/predicate.rs b/crates/ty_python_semantic/src/semantic_index/predicate.rs index 6009914ac4..80096841e1 100644 --- a/crates/ty_python_semantic/src/semantic_index/predicate.rs +++ b/crates/ty_python_semantic/src/semantic_index/predicate.rs @@ -8,7 +8,7 @@ //! static reachability of a binding, and the reachability of a statement or expression. use ruff_db::files::File; -use ruff_index::{IndexVec, newtype_index}; +use ruff_index::{Idx, IndexVec}; use ruff_python_ast::Singleton; use crate::db::Db; @@ -17,9 +17,42 @@ use crate::semantic_index::global_scope; use crate::semantic_index::place::{FileScopeId, ScopeId, ScopedPlaceId}; // A scoped identifier for each `Predicate` in a scope. -#[newtype_index] -#[derive(Ord, PartialOrd, get_size2::GetSize)] -pub(crate) struct ScopedPredicateId; +#[derive(Clone, Debug, Copy, PartialOrd, Ord, PartialEq, Eq, Hash, get_size2::GetSize)] +pub(crate) struct ScopedPredicateId(u32); + +impl ScopedPredicateId { + /// A special ID that is used for an "always true" predicate. + pub(crate) const ALWAYS_TRUE: ScopedPredicateId = ScopedPredicateId(0xffff_ffff); + + /// A special ID that is used for an "always false" predicate. + pub(crate) const ALWAYS_FALSE: ScopedPredicateId = ScopedPredicateId(0xffff_fffe); + + const SMALLEST_TERMINAL: ScopedPredicateId = Self::ALWAYS_FALSE; + + fn is_terminal(self) -> bool { + self >= Self::SMALLEST_TERMINAL + } + + #[cfg(test)] + pub(crate) fn as_u32(self) -> u32 { + self.0 + } +} + +impl Idx for ScopedPredicateId { + #[inline] + fn new(value: usize) -> Self { + assert!(value <= (Self::SMALLEST_TERMINAL.0 as usize)); + #[expect(clippy::cast_possible_truncation)] + Self(value as u32) + } + + #[inline] + fn index(self) -> usize { + debug_assert!(!self.is_terminal()); + self.0 as usize + } +} // A collection of predicates for a given scope. pub(crate) type Predicates<'db> = IndexVec>; @@ -49,11 +82,22 @@ pub(crate) struct Predicate<'db> { pub(crate) is_positive: bool, } -impl Predicate<'_> { +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, salsa::Update, get_size2::GetSize)] +pub(crate) enum PredicateOrLiteral<'db> { + Literal(bool), + Predicate(Predicate<'db>), +} + +impl PredicateOrLiteral<'_> { pub(crate) fn negated(self) -> Self { - Self { - node: self.node, - is_positive: !self.is_positive, + match self { + PredicateOrLiteral::Literal(value) => PredicateOrLiteral::Literal(!value), + PredicateOrLiteral::Predicate(Predicate { node, is_positive }) => { + PredicateOrLiteral::Predicate(Predicate { + node, + is_positive: !is_positive, + }) + } } } } @@ -169,11 +213,11 @@ impl<'db> StarImportPlaceholderPredicate<'db> { } } -impl<'db> From> for Predicate<'db> { +impl<'db> From> for PredicateOrLiteral<'db> { fn from(predicate: StarImportPlaceholderPredicate<'db>) -> Self { - Predicate { + PredicateOrLiteral::Predicate(Predicate { node: PredicateNode::StarImportPlaceholder(predicate), is_positive: true, - } + }) } } diff --git a/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs b/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs index e7fd1d6914..e34ba33096 100644 --- a/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs +++ b/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs @@ -388,12 +388,18 @@ impl ReachabilityConstraintsBuilder { &mut self, predicate: ScopedPredicateId, ) -> ScopedReachabilityConstraintId { - self.add_interior(InteriorNode { - atom: predicate, - if_true: ALWAYS_TRUE, - if_ambiguous: AMBIGUOUS, - if_false: ALWAYS_FALSE, - }) + if predicate == ScopedPredicateId::ALWAYS_FALSE { + ScopedReachabilityConstraintId::ALWAYS_FALSE + } else if predicate == ScopedPredicateId::ALWAYS_TRUE { + ScopedReachabilityConstraintId::ALWAYS_TRUE + } else { + self.add_interior(InteriorNode { + atom: predicate, + if_true: ALWAYS_TRUE, + if_ambiguous: AMBIGUOUS, + if_false: ALWAYS_FALSE, + }) + } } /// Adds a new reachability constraint that is the ternary NOT of an existing one. diff --git a/crates/ty_python_semantic/src/semantic_index/use_def.rs b/crates/ty_python_semantic/src/semantic_index/use_def.rs index 3e20b8dbf3..b17459c7aa 100644 --- a/crates/ty_python_semantic/src/semantic_index/use_def.rs +++ b/crates/ty_python_semantic/src/semantic_index/use_def.rs @@ -247,7 +247,8 @@ use crate::semantic_index::place::{ FileScopeId, PlaceExpr, PlaceExprWithFlags, ScopeKind, ScopedPlaceId, }; use crate::semantic_index::predicate::{ - Predicate, Predicates, PredicatesBuilder, ScopedPredicateId, StarImportPlaceholderPredicate, + Predicate, PredicateOrLiteral, Predicates, PredicatesBuilder, ScopedPredicateId, + StarImportPlaceholderPredicate, }; use crate::semantic_index::reachability_constraints::{ ReachabilityConstraints, ReachabilityConstraintsBuilder, ScopedReachabilityConstraintId, @@ -805,11 +806,25 @@ impl<'db> UseDefMapBuilder<'db> { ); } - pub(super) fn add_predicate(&mut self, predicate: Predicate<'db>) -> ScopedPredicateId { - self.predicates.add_predicate(predicate) + pub(super) fn add_predicate( + &mut self, + predicate: PredicateOrLiteral<'db>, + ) -> ScopedPredicateId { + match predicate { + PredicateOrLiteral::Predicate(predicate) => self.predicates.add_predicate(predicate), + PredicateOrLiteral::Literal(true) => ScopedPredicateId::ALWAYS_TRUE, + PredicateOrLiteral::Literal(false) => ScopedPredicateId::ALWAYS_FALSE, + } } pub(super) fn record_narrowing_constraint(&mut self, predicate: ScopedPredicateId) { + if predicate == ScopedPredicateId::ALWAYS_TRUE + || predicate == ScopedPredicateId::ALWAYS_FALSE + { + // No need to record a narrowing constraint for `True` or `False`. + return; + } + let narrowing_constraint = predicate.into(); for state in &mut self.place_states { state diff --git a/crates/ty_python_semantic/src/semantic_index/use_def/place_state.rs b/crates/ty_python_semantic/src/semantic_index/use_def/place_state.rs index b3f34577b9..116dbece85 100644 --- a/crates/ty_python_semantic/src/semantic_index/use_def/place_state.rs +++ b/crates/ty_python_semantic/src/semantic_index/use_def/place_state.rs @@ -431,6 +431,7 @@ impl PlaceState { #[cfg(test)] mod tests { use super::*; + use ruff_index::Idx; use crate::semantic_index::predicate::ScopedPredicateId; @@ -514,7 +515,7 @@ mod tests { false, true, ); - let predicate = ScopedPredicateId::from_u32(0).into(); + let predicate = ScopedPredicateId::new(0).into(); sym.record_narrowing_constraint(&mut narrowing_constraints, predicate); assert_bindings(&narrowing_constraints, &sym, &["1<0>"]); @@ -533,7 +534,7 @@ mod tests { false, true, ); - let predicate = ScopedPredicateId::from_u32(0).into(); + let predicate = ScopedPredicateId::new(0).into(); sym1a.record_narrowing_constraint(&mut narrowing_constraints, predicate); let mut sym1b = PlaceState::undefined(ScopedReachabilityConstraintId::ALWAYS_TRUE); @@ -543,7 +544,7 @@ mod tests { false, true, ); - let predicate = ScopedPredicateId::from_u32(0).into(); + let predicate = ScopedPredicateId::new(0).into(); sym1b.record_narrowing_constraint(&mut narrowing_constraints, predicate); sym1a.merge( @@ -562,7 +563,7 @@ mod tests { false, true, ); - let predicate = ScopedPredicateId::from_u32(1).into(); + let predicate = ScopedPredicateId::new(1).into(); sym2a.record_narrowing_constraint(&mut narrowing_constraints, predicate); let mut sym1b = PlaceState::undefined(ScopedReachabilityConstraintId::ALWAYS_TRUE); @@ -572,7 +573,7 @@ mod tests { false, true, ); - let predicate = ScopedPredicateId::from_u32(2).into(); + let predicate = ScopedPredicateId::new(2).into(); sym1b.record_narrowing_constraint(&mut narrowing_constraints, predicate); sym2a.merge( @@ -591,7 +592,7 @@ mod tests { false, true, ); - let predicate = ScopedPredicateId::from_u32(3).into(); + let predicate = ScopedPredicateId::new(3).into(); sym3a.record_narrowing_constraint(&mut narrowing_constraints, predicate); let sym2b = PlaceState::undefined(ScopedReachabilityConstraintId::ALWAYS_TRUE);