diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/nested.md b/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/nested.md index ab026cef67..6b001ea094 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/nested.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/nested.md @@ -42,3 +42,135 @@ def _(flag1: bool, flag2: bool): else: reveal_type(x) # revealed: Never ``` + +## Cross-scope narrowing + +Narrowing constraints are also valid in eager nested scopes (however, because class variables are +not visible from nested scopes, constraints on those variables are invalid). + +Currently they are assumed to be invalid in lazy nested scopes since there is a possibility that the +constraints may no longer be valid due to a "time lag". However, it may be possible to determine +that some of them are valid by performing a more detailed analysis (e.g. checking that the narrowing +target has not changed in all places where the function is called). + +### Narrowing constraints introduced in eager nested scopes + +```py +g: str | None = "a" + +def f(x: str | None): + def _(): + if x is not None: + reveal_type(x) # revealed: str + + if not isinstance(x, str): + reveal_type(x) # revealed: None + + if g is not None: + reveal_type(g) # revealed: str + + class C: + if x is not None: + reveal_type(x) # revealed: str + + if not isinstance(x, str): + reveal_type(x) # revealed: None + + if g is not None: + reveal_type(g) # revealed: str + + # TODO: should be str + # This could be fixed if we supported narrowing with if clauses in comprehensions. + [reveal_type(x) for _ in range(1) if x is not None] # revealed: str | None +``` + +### Narrowing constraints introduced in the outer scope + +```py +g: str | None = "a" + +def f(x: str | None): + if x is not None: + def _(): + # If there is a possibility that `x` may be rewritten after this function definition, + # the constraint `x is not None` outside the function is no longer be applicable for narrowing. + reveal_type(x) # revealed: str | None + + class C: + reveal_type(x) # revealed: str + + [reveal_type(x) for _ in range(1)] # revealed: str + + if g is not None: + def _(): + reveal_type(g) # revealed: str | None + + class D: + reveal_type(g) # revealed: str + + [reveal_type(g) for _ in range(1)] # revealed: str +``` + +### Narrowing constraints introduced in multiple scopes + +```py +from typing import Literal + +g: str | Literal[1] | None = "a" + +def f(x: str | Literal[1] | None): + class C: + if x is not None: + def _(): + if x != 1: + reveal_type(x) # revealed: str | None + + class D: + if x != 1: + reveal_type(x) # revealed: str + + # TODO: should be str + [reveal_type(x) for _ in range(1) if x != 1] # revealed: str | Literal[1] + + if g is not None: + def _(): + if g != 1: + reveal_type(g) # revealed: str | None + + class D: + if g != 1: + reveal_type(g) # revealed: str +``` + +### Narrowing constraints with bindings in class scope, and nested scopes + +```py +from typing import Literal + +g: str | Literal[1] | None = "a" + +def f(flag: bool): + class C: + (g := None) if flag else (g := None) + # `g` is always bound here, so narrowing checks don't apply to nested scopes + if g is not None: + class F: + reveal_type(g) # revealed: str | Literal[1] | None + + class C: + # this conditional binding leaves "unbound" visible, so following narrowing checks apply + None if flag else (g := None) + + if g is not None: + class F: + reveal_type(g) # revealed: str | Literal[1] + + # This class variable is not visible from the nested class scope. + g = None + + # This additional constraint is not relevant to nested scopes, since it only applies to + # a binding of `g` that they cannot see: + if g is None: + class E: + reveal_type(g) # revealed: str | Literal[1] +``` diff --git a/crates/ty_python_semantic/src/semantic_index.rs b/crates/ty_python_semantic/src/semantic_index.rs index 67f9845326..42956f0bfe 100644 --- a/crates/ty_python_semantic/src/semantic_index.rs +++ b/crates/ty_python_semantic/src/semantic_index.rs @@ -17,18 +17,19 @@ use crate::semantic_index::ast_ids::AstIds; use crate::semantic_index::builder::SemanticIndexBuilder; use crate::semantic_index::definition::{Definition, DefinitionNodeKey, Definitions}; use crate::semantic_index::expression::Expression; +use crate::semantic_index::narrowing_constraints::ScopedNarrowingConstraint; use crate::semantic_index::symbol::{ FileScopeId, NodeWithScopeKey, NodeWithScopeRef, Scope, ScopeId, ScopeKind, ScopedSymbolId, SymbolTable, }; -use crate::semantic_index::use_def::{EagerBindingsKey, ScopedEagerBindingsId, UseDefMap}; +use crate::semantic_index::use_def::{EagerSnapshotKey, ScopedEagerSnapshotId, UseDefMap}; use crate::Db; pub mod ast_ids; mod builder; pub mod definition; pub mod expression; -mod narrowing_constraints; +pub(crate) mod narrowing_constraints; pub(crate) mod predicate; mod re_exports; pub mod symbol; @@ -141,8 +142,9 @@ pub(crate) fn global_scope(db: &dyn Db, file: File) -> ScopeId<'_> { FileScopeId::global().to_scope_id(db, file) } -pub(crate) enum EagerBindingsResult<'map, 'db> { - Found(BindingWithConstraintsIterator<'map, 'db>), +pub(crate) enum EagerSnapshotResult<'map, 'db> { + FoundConstraint(ScopedNarrowingConstraint), + FoundBindings(BindingWithConstraintsIterator<'map, 'db>), NotFound, NoLongerInEagerContext, } @@ -189,8 +191,8 @@ pub(crate) struct SemanticIndex<'db> { /// Flags about the global scope (code usage impacting inference) has_future_annotations: bool, - /// Map of all of the eager bindings that appear in this file. - eager_bindings: FxHashMap, + /// Map of all of the eager snapshots that appear in this file. + eager_snapshots: FxHashMap, /// List of all semantic syntax errors in this file. semantic_syntax_errors: Vec, @@ -390,36 +392,34 @@ impl<'db> SemanticIndex<'db> { /// * `NoLongerInEagerContext` if the nested scope is no longer in an eager context /// (that is, not every scope that will be traversed is eager). /// * an iterator of bindings for a particular nested eager scope reference if the bindings exist. - /// * `NotFound` if the bindings do not exist in the nested eager scope. - pub(crate) fn eager_bindings( + /// * a narrowing constraint if there are no bindings, but there is a narrowing constraint for an outer scope symbol. + /// * `NotFound` if the narrowing constraint / bindings do not exist in the nested eager scope. + pub(crate) fn eager_snapshot( &self, enclosing_scope: FileScopeId, symbol: &str, nested_scope: FileScopeId, - ) -> EagerBindingsResult<'_, 'db> { + ) -> EagerSnapshotResult<'_, 'db> { for (ancestor_scope_id, ancestor_scope) in self.ancestor_scopes(nested_scope) { if ancestor_scope_id == enclosing_scope { break; } if !ancestor_scope.is_eager() { - return EagerBindingsResult::NoLongerInEagerContext; + return EagerSnapshotResult::NoLongerInEagerContext; } } let Some(symbol_id) = self.symbol_tables[enclosing_scope].symbol_id_by_name(symbol) else { - return EagerBindingsResult::NotFound; + return EagerSnapshotResult::NotFound; }; - let key = EagerBindingsKey { + let key = EagerSnapshotKey { enclosing_scope, enclosing_symbol: symbol_id, nested_scope, }; - let Some(id) = self.eager_bindings.get(&key) else { - return EagerBindingsResult::NotFound; + let Some(id) = self.eager_snapshots.get(&key) else { + return EagerSnapshotResult::NotFound; }; - match self.use_def_maps[enclosing_scope].eager_bindings(*id) { - Some(bindings) => EagerBindingsResult::Found(bindings), - None => EagerBindingsResult::NotFound, - } + self.use_def_maps[enclosing_scope].eager_snapshot(*id) } pub(crate) fn semantic_syntax_errors(&self) -> &[SemanticSyntaxError] { diff --git a/crates/ty_python_semantic/src/semantic_index/builder.rs b/crates/ty_python_semantic/src/semantic_index/builder.rs index 215e3fbb42..34fed4fcfe 100644 --- a/crates/ty_python_semantic/src/semantic_index/builder.rs +++ b/crates/ty_python_semantic/src/semantic_index/builder.rs @@ -42,7 +42,7 @@ use crate::semantic_index::symbol::{ ScopedSymbolId, SymbolTableBuilder, }; use crate::semantic_index::use_def::{ - EagerBindingsKey, FlowSnapshot, ScopedEagerBindingsId, UseDefMapBuilder, + EagerSnapshotKey, FlowSnapshot, ScopedEagerSnapshotId, UseDefMapBuilder, }; use crate::semantic_index::visibility_constraints::{ ScopedVisibilityConstraintId, VisibilityConstraintsBuilder, @@ -113,7 +113,7 @@ pub(super) struct SemanticIndexBuilder<'db> { /// /// [generator functions]: https://docs.python.org/3/glossary.html#term-generator generator_functions: FxHashSet, - eager_bindings: FxHashMap, + eager_snapshots: FxHashMap, /// Errors collected by the `semantic_checker`. semantic_syntax_errors: RefCell>, } @@ -148,7 +148,7 @@ impl<'db> SemanticIndexBuilder<'db> { imported_modules: FxHashSet::default(), generator_functions: FxHashSet::default(), - eager_bindings: FxHashMap::default(), + eager_snapshots: FxHashMap::default(), python_version: Program::get(db).python_version(db), source_text: OnceCell::new(), @@ -253,13 +253,15 @@ impl<'db> SemanticIndexBuilder<'db> { children_start..children_start, reachability, ); + let is_class_scope = scope.kind().is_class(); self.try_node_context_stack_manager.enter_nested_scope(); let file_scope_id = self.scopes.push(scope); self.symbol_tables.push(SymbolTableBuilder::default()); self.instance_attribute_tables .push(SymbolTableBuilder::default()); - self.use_def_maps.push(UseDefMapBuilder::default()); + self.use_def_maps + .push(UseDefMapBuilder::new(is_class_scope)); let ast_id_scope = self.ast_ids.push(AstIdsBuilder::default()); let scope_id = ScopeId::new(self.db, self.file, file_scope_id, countme::Count::default()); @@ -303,12 +305,6 @@ impl<'db> SemanticIndexBuilder<'db> { let enclosing_scope_kind = self.scopes[enclosing_scope_id].kind(); let enclosing_symbol_table = &self.symbol_tables[enclosing_scope_id]; - // Names bound in class scopes are never visible to nested scopes, so we never need to - // save eager scope bindings in a class scope. - if enclosing_scope_kind.is_class() { - continue; - } - for nested_symbol in self.symbol_tables[popped_scope_id].symbols() { // Skip this symbol if this enclosing scope doesn't contain any bindings for it. // Note that even if this symbol is bound in the popped scope, @@ -321,24 +317,26 @@ impl<'db> SemanticIndexBuilder<'db> { continue; }; let enclosing_symbol = enclosing_symbol_table.symbol(enclosing_symbol_id); - if !enclosing_symbol.is_bound() { - continue; - } - // Snapshot the bindings of this symbol that are visible at this point in this + // Snapshot the state of this symbol that are visible at this point in this // enclosing scope. - let key = EagerBindingsKey { + let key = EagerSnapshotKey { enclosing_scope: enclosing_scope_id, enclosing_symbol: enclosing_symbol_id, nested_scope: popped_scope_id, }; - let eager_bindings = self.use_def_maps[enclosing_scope_id] - .snapshot_eager_bindings(enclosing_symbol_id); - self.eager_bindings.insert(key, eager_bindings); + let eager_snapshot = self.use_def_maps[enclosing_scope_id].snapshot_eager_state( + enclosing_symbol_id, + enclosing_scope_kind, + enclosing_symbol.is_bound(), + ); + self.eager_snapshots.insert(key, eager_snapshot); } // Lazy scopes are "sticky": once we see a lazy scope we stop doing lookups // eagerly, even if we would encounter another eager enclosing scope later on. + // Also, narrowing constraints outside a lazy scope are not applicable. + // TODO: If the symbol has never been rewritten, they are applicable. if !enclosing_scope_kind.is_eager() { break; } @@ -1085,8 +1083,8 @@ impl<'db> SemanticIndexBuilder<'db> { self.scope_ids_by_scope.shrink_to_fit(); self.scopes_by_node.shrink_to_fit(); - self.eager_bindings.shrink_to_fit(); self.generator_functions.shrink_to_fit(); + self.eager_snapshots.shrink_to_fit(); SemanticIndex { symbol_tables, @@ -1101,7 +1099,7 @@ impl<'db> SemanticIndexBuilder<'db> { use_def_maps, imported_modules: Arc::new(self.imported_modules), has_future_annotations: self.has_future_annotations, - eager_bindings: self.eager_bindings, + eager_snapshots: self.eager_snapshots, semantic_syntax_errors: self.semantic_syntax_errors.into_inner(), generator_functions: self.generator_functions, } diff --git a/crates/ty_python_semantic/src/semantic_index/narrowing_constraints.rs b/crates/ty_python_semantic/src/semantic_index/narrowing_constraints.rs index 6ed80e7ebf..83bfb0d25d 100644 --- a/crates/ty_python_semantic/src/semantic_index/narrowing_constraints.rs +++ b/crates/ty_python_semantic/src/semantic_index/narrowing_constraints.rs @@ -29,6 +29,7 @@ //! [`Predicate`]: crate::semantic_index::predicate::Predicate use crate::list::{List, ListBuilder, ListSetReverseIterator, ListStorage}; +use crate::semantic_index::ast_ids::ScopedUseId; use crate::semantic_index::predicate::ScopedPredicateId; /// A narrowing constraint associated with a live binding. @@ -38,6 +39,12 @@ use crate::semantic_index::predicate::ScopedPredicateId; /// [`Predicate`]: crate::semantic_index::predicate::Predicate pub(crate) type ScopedNarrowingConstraint = List; +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(crate) enum ConstraintKey { + NarrowingConstraint(ScopedNarrowingConstraint), + UseId(ScopedUseId), +} + /// One of the [`Predicate`]s in a narrowing constraint, which constraints the type of the /// binding's symbol. /// diff --git a/crates/ty_python_semantic/src/semantic_index/use_def.rs b/crates/ty_python_semantic/src/semantic_index/use_def.rs index c9e62422db..c53c8eb468 100644 --- a/crates/ty_python_semantic/src/semantic_index/use_def.rs +++ b/crates/ty_python_semantic/src/semantic_index/use_def.rs @@ -259,25 +259,25 @@ use ruff_index::{newtype_index, IndexVec}; use rustc_hash::FxHashMap; -use self::symbol_state::ScopedDefinitionId; use self::symbol_state::{ - LiveBindingsIterator, LiveDeclaration, LiveDeclarationsIterator, SymbolBindings, - SymbolDeclarations, SymbolState, + EagerSnapshot, LiveBindingsIterator, LiveDeclaration, LiveDeclarationsIterator, + ScopedDefinitionId, SymbolBindings, SymbolDeclarations, SymbolState, }; use crate::node_key::NodeKey; use crate::semantic_index::ast_ids::ScopedUseId; use crate::semantic_index::definition::Definition; use crate::semantic_index::narrowing_constraints::{ - NarrowingConstraints, NarrowingConstraintsBuilder, NarrowingConstraintsIterator, + ConstraintKey, NarrowingConstraints, NarrowingConstraintsBuilder, NarrowingConstraintsIterator, }; use crate::semantic_index::predicate::{ Predicate, Predicates, PredicatesBuilder, ScopedPredicateId, StarImportPlaceholderPredicate, }; -use crate::semantic_index::symbol::{FileScopeId, ScopedSymbolId}; +use crate::semantic_index::symbol::{FileScopeId, ScopeKind, ScopedSymbolId}; use crate::semantic_index::visibility_constraints::{ ScopedVisibilityConstraintId, VisibilityConstraints, VisibilityConstraintsBuilder, }; -use crate::types::Truthiness; +use crate::semantic_index::EagerSnapshotResult; +use crate::types::{infer_narrowing_constraint, IntersectionBuilder, Truthiness, Type}; mod symbol_state; @@ -328,7 +328,7 @@ pub(crate) struct UseDefMap<'db> { /// Snapshot of bindings in this scope that can be used to resolve a reference in a nested /// eager scope. - eager_bindings: EagerBindings, + eager_snapshots: EagerSnapshots, /// Whether or not the start of the scope is visible. /// This is used to check if the function can implicitly return `None`. @@ -354,6 +354,22 @@ impl<'db> UseDefMap<'db> { self.bindings_iterator(&self.bindings_by_use[use_id]) } + pub(crate) fn narrowing_constraints_at_use( + &self, + constraint_key: ConstraintKey, + ) -> ConstraintsIterator<'_, 'db> { + let constraint = match constraint_key { + ConstraintKey::NarrowingConstraint(constraint) => constraint, + ConstraintKey::UseId(use_id) => { + self.bindings_by_use[use_id].unbound_narrowing_constraint() + } + }; + ConstraintsIterator { + predicates: &self.predicates, + constraint_ids: self.narrowing_constraints.iter_predicates(constraint), + } + } + pub(super) fn is_reachable( &self, db: &dyn crate::Db, @@ -398,13 +414,19 @@ impl<'db> UseDefMap<'db> { self.bindings_iterator(self.instance_attributes[symbol].bindings()) } - pub(crate) fn eager_bindings( + pub(crate) fn eager_snapshot( &self, - eager_bindings: ScopedEagerBindingsId, - ) -> Option> { - self.eager_bindings - .get(eager_bindings) - .map(|symbol_bindings| self.bindings_iterator(symbol_bindings)) + eager_bindings: ScopedEagerSnapshotId, + ) -> EagerSnapshotResult<'_, 'db> { + match self.eager_snapshots.get(eager_bindings) { + Some(EagerSnapshot::Constraint(constraint)) => { + EagerSnapshotResult::FoundConstraint(*constraint) + } + Some(EagerSnapshot::Bindings(symbol_bindings)) => { + EagerSnapshotResult::FoundBindings(self.bindings_iterator(symbol_bindings)) + } + None => EagerSnapshotResult::NotFound, + } } pub(crate) fn bindings_at_declaration( @@ -489,19 +511,19 @@ impl<'db> UseDefMap<'db> { } } -/// Uniquely identifies a snapshot of bindings that can be used to resolve a reference in a nested -/// eager scope. +/// Uniquely identifies a snapshot of a symbol state that can be used to resolve a reference in a +/// nested eager scope. /// /// An eager scope has its entire body executed immediately at the location where it is defined. /// For any free references in the nested scope, we use the bindings that are visible at the point /// where the nested scope is defined, instead of using the public type of the symbol. /// -/// There is a unique ID for each distinct [`EagerBindingsKey`] in the file. +/// There is a unique ID for each distinct [`EagerSnapshotKey`] in the file. #[newtype_index] -pub(crate) struct ScopedEagerBindingsId; +pub(crate) struct ScopedEagerSnapshotId; #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] -pub(crate) struct EagerBindingsKey { +pub(crate) struct EagerSnapshotKey { /// The enclosing scope containing the bindings pub(crate) enclosing_scope: FileScopeId, /// The referenced symbol (in the enclosing scope) @@ -510,8 +532,8 @@ pub(crate) struct EagerBindingsKey { pub(crate) nested_scope: FileScopeId, } -/// A snapshot of bindings that can be used to resolve a reference in a nested eager scope. -type EagerBindings = IndexVec; +/// A snapshot of symbol states that can be used to resolve a reference in a nested eager scope. +type EagerSnapshots = IndexVec; #[derive(Debug)] pub(crate) struct BindingWithConstraintsIterator<'map, 'db> { @@ -568,6 +590,33 @@ impl<'db> Iterator for ConstraintsIterator<'_, 'db> { impl std::iter::FusedIterator for ConstraintsIterator<'_, '_> {} +impl<'db> ConstraintsIterator<'_, 'db> { + pub(crate) fn narrow( + self, + db: &'db dyn crate::Db, + base_ty: Type<'db>, + symbol: ScopedSymbolId, + ) -> Type<'db> { + let constraint_tys: Vec<_> = self + .filter_map(|constraint| infer_narrowing_constraint(db, constraint, symbol)) + .collect(); + + if constraint_tys.is_empty() { + base_ty + } else { + let intersection_ty = constraint_tys + .into_iter() + .rev() + .fold( + IntersectionBuilder::new(db).add_positive(base_ty), + IntersectionBuilder::add_positive, + ) + .build(); + intersection_ty + } + } +} + #[derive(Clone)] pub(crate) struct DeclarationsIterator<'map, 'db> { all_definitions: &'map IndexVec>>, @@ -688,13 +737,16 @@ pub(super) struct UseDefMapBuilder<'db> { /// Currently live bindings for each instance attribute. instance_attribute_states: IndexVec, - /// Snapshot of bindings in this scope that can be used to resolve a reference in a nested - /// eager scope. - eager_bindings: EagerBindings, + /// Snapshots of symbol states in this scope that can be used to resolve a reference in a + /// nested eager scope. + eager_snapshots: EagerSnapshots, + + /// Is this a class scope? + is_class_scope: bool, } -impl Default for UseDefMapBuilder<'_> { - fn default() -> Self { +impl<'db> UseDefMapBuilder<'db> { + pub(super) fn new(is_class_scope: bool) -> Self { Self { all_definitions: IndexVec::from_iter([None]), predicates: PredicatesBuilder::default(), @@ -707,13 +759,11 @@ impl Default for UseDefMapBuilder<'_> { declarations_by_binding: FxHashMap::default(), bindings_by_declaration: FxHashMap::default(), symbol_states: IndexVec::new(), - eager_bindings: EagerBindings::default(), + eager_snapshots: EagerSnapshots::default(), instance_attribute_states: IndexVec::new(), + is_class_scope, } } -} - -impl<'db> UseDefMapBuilder<'db> { pub(super) fn mark_unreachable(&mut self) { self.record_visibility_constraint(ScopedVisibilityConstraintId::ALWAYS_FALSE); self.reachability = ScopedVisibilityConstraintId::ALWAYS_FALSE; @@ -738,7 +788,7 @@ impl<'db> UseDefMapBuilder<'db> { let symbol_state = &mut self.symbol_states[symbol]; self.declarations_by_binding .insert(binding, symbol_state.declarations().clone()); - symbol_state.record_binding(def_id, self.scope_start_visibility); + symbol_state.record_binding(def_id, self.scope_start_visibility, self.is_class_scope); } pub(super) fn record_attribute_binding( @@ -750,7 +800,7 @@ impl<'db> UseDefMapBuilder<'db> { let attribute_state = &mut self.instance_attribute_states[symbol]; self.declarations_by_binding .insert(binding, attribute_state.declarations().clone()); - attribute_state.record_binding(def_id, self.scope_start_visibility); + attribute_state.record_binding(def_id, self.scope_start_visibility, self.is_class_scope); } pub(super) fn add_predicate(&mut self, predicate: Predicate<'db>) -> ScopedPredicateId { @@ -936,7 +986,7 @@ impl<'db> UseDefMapBuilder<'db> { let def_id = self.all_definitions.push(Some(definition)); let symbol_state = &mut self.symbol_states[symbol]; symbol_state.record_declaration(def_id); - symbol_state.record_binding(def_id, self.scope_start_visibility); + symbol_state.record_binding(def_id, self.scope_start_visibility, self.is_class_scope); } pub(super) fn record_use( @@ -961,12 +1011,25 @@ impl<'db> UseDefMapBuilder<'db> { self.node_reachability.insert(node_key, self.reachability); } - pub(super) fn snapshot_eager_bindings( + pub(super) fn snapshot_eager_state( &mut self, enclosing_symbol: ScopedSymbolId, - ) -> ScopedEagerBindingsId { - self.eager_bindings - .push(self.symbol_states[enclosing_symbol].bindings().clone()) + scope: ScopeKind, + is_bound: bool, + ) -> ScopedEagerSnapshotId { + // Names bound in class scopes are never visible to nested scopes, so we never need to + // save eager scope bindings in a class scope. + if scope.is_class() || !is_bound { + self.eager_snapshots.push(EagerSnapshot::Constraint( + self.symbol_states[enclosing_symbol] + .bindings() + .unbound_narrowing_constraint(), + )) + } else { + self.eager_snapshots.push(EagerSnapshot::Bindings( + self.symbol_states[enclosing_symbol].bindings().clone(), + )) + } } /// Take a snapshot of the current visible-symbols state. @@ -1086,7 +1149,7 @@ impl<'db> UseDefMapBuilder<'db> { self.node_reachability.shrink_to_fit(); self.declarations_by_binding.shrink_to_fit(); self.bindings_by_declaration.shrink_to_fit(); - self.eager_bindings.shrink_to_fit(); + self.eager_snapshots.shrink_to_fit(); UseDefMap { all_definitions: self.all_definitions, @@ -1099,7 +1162,7 @@ impl<'db> UseDefMapBuilder<'db> { instance_attributes: self.instance_attribute_states, declarations_by_binding: self.declarations_by_binding, bindings_by_declaration: self.bindings_by_declaration, - eager_bindings: self.eager_bindings, + eager_snapshots: self.eager_snapshots, scope_start_visibility: self.scope_start_visibility, } } diff --git a/crates/ty_python_semantic/src/semantic_index/use_def/symbol_state.rs b/crates/ty_python_semantic/src/semantic_index/use_def/symbol_state.rs index 6807baf8e0..02c4e3e682 100644 --- a/crates/ty_python_semantic/src/semantic_index/use_def/symbol_state.rs +++ b/crates/ty_python_semantic/src/semantic_index/use_def/symbol_state.rs @@ -65,6 +65,10 @@ impl ScopedDefinitionId { /// When creating a use-def-map builder, we always add an empty `None` definition /// at index 0, so this ID is always present. pub(super) const UNBOUND: ScopedDefinitionId = ScopedDefinitionId::from_u32(0); + + fn is_unbound(self) -> bool { + self == Self::UNBOUND + } } /// Can keep inline this many live bindings or declarations per symbol at a given time; more will @@ -177,14 +181,41 @@ impl SymbolDeclarations { } } +/// A snapshot of a symbol state that can be used to resolve a reference in a nested eager scope. +/// If there are bindings in a (non-class) scope , they are stored in `Bindings`. +/// Even if it's a class scope (class variables are not visible to nested scopes) or there are no +/// bindings, the current narrowing constraint is necessary for narrowing, so it's stored in +/// `Constraint`. +#[derive(Clone, Debug, PartialEq, Eq, salsa::Update)] +pub(super) enum EagerSnapshot { + Constraint(ScopedNarrowingConstraint), + Bindings(SymbolBindings), +} + /// Live bindings for a single symbol at some point in control flow. Each live binding comes /// with a set of narrowing constraints and a visibility constraint. #[derive(Clone, Debug, Default, PartialEq, Eq, salsa::Update)] pub(super) struct SymbolBindings { + /// The narrowing constraint applicable to the "unbound" binding, if we need access to it even + /// when it's not visible. This happens in class scopes, where local bindings are not visible + /// to nested scopes, but we still need to know what narrowing constraints were applied to the + /// "unbound" binding. + unbound_narrowing_constraint: Option, /// A list of live bindings for this symbol, sorted by their `ScopedDefinitionId` live_bindings: SmallVec<[LiveBinding; INLINE_DEFINITIONS_PER_SYMBOL]>, } +impl SymbolBindings { + pub(super) fn unbound_narrowing_constraint(&self) -> ScopedNarrowingConstraint { + debug_assert!( + self.unbound_narrowing_constraint.is_some() + || self.live_bindings[0].binding.is_unbound() + ); + self.unbound_narrowing_constraint + .unwrap_or(self.live_bindings[0].narrowing_constraint) + } +} + /// One of the live bindings for a single symbol at some point in control flow. #[derive(Clone, Debug, PartialEq, Eq)] pub(super) struct LiveBinding { @@ -203,6 +234,7 @@ impl SymbolBindings { visibility_constraint: scope_start_visibility, }; Self { + unbound_narrowing_constraint: None, live_bindings: smallvec![initial_binding], } } @@ -212,7 +244,13 @@ impl SymbolBindings { &mut self, binding: ScopedDefinitionId, visibility_constraint: ScopedVisibilityConstraintId, + is_class_scope: bool, ) { + // If we are in a class scope, and the unbound binding was previously visible, but we will + // now replace it, record the narrowing constraints on it: + if is_class_scope && self.live_bindings[0].binding.is_unbound() { + self.unbound_narrowing_constraint = Some(self.live_bindings[0].narrowing_constraint); + } // The new binding replaces all previous live bindings in this path, and has no // constraints. self.live_bindings.clear(); @@ -278,6 +316,14 @@ impl SymbolBindings { ) { let a = std::mem::take(self); + if let Some((a, b)) = a + .unbound_narrowing_constraint + .zip(b.unbound_narrowing_constraint) + { + self.unbound_narrowing_constraint = + Some(narrowing_constraints.intersect_constraints(a, b)); + } + // Invariant: merge_join_by consumes the two iterators in sorted order, which ensures that // the merged `live_bindings` vec remains sorted. If a definition is found in both `a` and // `b`, we compose the constraints from the two paths in an appropriate way (intersection @@ -333,10 +379,11 @@ impl SymbolState { &mut self, binding_id: ScopedDefinitionId, visibility_constraint: ScopedVisibilityConstraintId, + is_class_scope: bool, ) { debug_assert_ne!(binding_id, ScopedDefinitionId::UNBOUND); self.bindings - .record_binding(binding_id, visibility_constraint); + .record_binding(binding_id, visibility_constraint, is_class_scope); } /// Add given constraint to all live bindings. @@ -467,6 +514,7 @@ mod tests { sym.record_binding( ScopedDefinitionId::from_u32(1), ScopedVisibilityConstraintId::ALWAYS_TRUE, + false, ); assert_bindings(&narrowing_constraints, &sym, &["1<>"]); @@ -479,6 +527,7 @@ mod tests { sym.record_binding( ScopedDefinitionId::from_u32(1), ScopedVisibilityConstraintId::ALWAYS_TRUE, + false, ); let predicate = ScopedPredicateId::from_u32(0).into(); sym.record_narrowing_constraint(&mut narrowing_constraints, predicate); @@ -496,6 +545,7 @@ mod tests { sym1a.record_binding( ScopedDefinitionId::from_u32(1), ScopedVisibilityConstraintId::ALWAYS_TRUE, + false, ); let predicate = ScopedPredicateId::from_u32(0).into(); sym1a.record_narrowing_constraint(&mut narrowing_constraints, predicate); @@ -504,6 +554,7 @@ mod tests { sym1b.record_binding( ScopedDefinitionId::from_u32(1), ScopedVisibilityConstraintId::ALWAYS_TRUE, + false, ); let predicate = ScopedPredicateId::from_u32(0).into(); sym1b.record_narrowing_constraint(&mut narrowing_constraints, predicate); @@ -521,6 +572,7 @@ mod tests { sym2a.record_binding( ScopedDefinitionId::from_u32(2), ScopedVisibilityConstraintId::ALWAYS_TRUE, + false, ); let predicate = ScopedPredicateId::from_u32(1).into(); sym2a.record_narrowing_constraint(&mut narrowing_constraints, predicate); @@ -529,6 +581,7 @@ mod tests { sym1b.record_binding( ScopedDefinitionId::from_u32(2), ScopedVisibilityConstraintId::ALWAYS_TRUE, + false, ); let predicate = ScopedPredicateId::from_u32(2).into(); sym1b.record_narrowing_constraint(&mut narrowing_constraints, predicate); @@ -546,6 +599,7 @@ mod tests { sym3a.record_binding( ScopedDefinitionId::from_u32(3), ScopedVisibilityConstraintId::ALWAYS_TRUE, + false, ); let predicate = ScopedPredicateId::from_u32(3).into(); sym3a.record_narrowing_constraint(&mut narrowing_constraints, predicate); diff --git a/crates/ty_python_semantic/src/symbol.rs b/crates/ty_python_semantic/src/symbol.rs index 49b404eb71..767d4392f2 100644 --- a/crates/ty_python_semantic/src/symbol.rs +++ b/crates/ty_python_semantic/src/symbol.rs @@ -8,8 +8,8 @@ use crate::semantic_index::{ symbol_table, BindingWithConstraints, BindingWithConstraintsIterator, DeclarationsIterator, }; use crate::types::{ - binding_type, declaration_type, infer_narrowing_constraint, todo_type, IntersectionBuilder, - KnownClass, Truthiness, Type, TypeAndQualifiers, TypeQualifiers, UnionBuilder, UnionType, + binding_type, declaration_type, todo_type, KnownClass, Truthiness, Type, TypeAndQualifiers, + TypeQualifiers, UnionBuilder, UnionType, }; use crate::{resolve_module, Db, KnownModule, Program}; @@ -791,24 +791,8 @@ fn symbol_from_bindings_impl<'db>( return None; } - let constraint_tys: Vec<_> = narrowing_constraint - .filter_map(|constraint| infer_narrowing_constraint(db, constraint, binding)) - .collect(); - let binding_ty = binding_type(db, binding); - if constraint_tys.is_empty() { - Some(binding_ty) - } else { - let intersection_ty = constraint_tys - .into_iter() - .rev() - .fold( - IntersectionBuilder::new(db).add_positive(binding_ty), - IntersectionBuilder::add_positive, - ) - .build(); - Some(intersection_ty) - } + Some(narrowing_constraint.narrow(db, binding_ty, binding.symbol(db))) }, ); diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 87d9a44fb7..0719800bfc 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -54,10 +54,11 @@ use crate::semantic_index::definition::{ ForStmtDefinitionKind, TargetKind, WithItemDefinitionKind, }; use crate::semantic_index::expression::{Expression, ExpressionKind}; +use crate::semantic_index::narrowing_constraints::ConstraintKey; use crate::semantic_index::symbol::{ FileScopeId, NodeWithScopeKind, NodeWithScopeRef, ScopeId, ScopeKind, }; -use crate::semantic_index::{semantic_index, EagerBindingsResult, SemanticIndex}; +use crate::semantic_index::{semantic_index, EagerSnapshotResult, SemanticIndex}; use crate::symbol::{ builtins_module_scope, builtins_symbol, explicit_global_symbol, module_type_implicit_global_symbol, symbol, symbol_from_bindings, symbol_from_declarations, @@ -5146,9 +5147,23 @@ impl<'db> TypeInferenceBuilder<'db> { let symbol_table = self.index.symbol_table(file_scope_id); let use_def = self.index.use_def_map(file_scope_id); + let mut constraint_keys = vec![]; + // Perform narrowing with applicable constraints between the current scope and the enclosing scope. + let narrow_with_applicable_constraints = |mut ty, constraint_keys: &[_]| { + for (enclosing_scope_file_id, constraint_key) in constraint_keys { + let use_def = self.index.use_def_map(*enclosing_scope_file_id); + let constraints = use_def.narrowing_constraints_at_use(*constraint_key); + let symbol_table = self.index.symbol_table(*enclosing_scope_file_id); + let symbol = symbol_table.symbol_id_by_name(symbol_name).unwrap(); + + ty = constraints.narrow(db, ty, symbol); + } + ty + }; + // If we're inferring types of deferred expressions, always treat them as public symbols - let local_scope_symbol = if self.is_deferred() { - if let Some(symbol_id) = symbol_table.symbol_id_by_name(symbol_name) { + let (local_scope_symbol, use_id) = if self.is_deferred() { + let symbol = if let Some(symbol_id) = symbol_table.symbol_id_by_name(symbol_name) { symbol_from_bindings(db, use_def.public_bindings(symbol_id)) } else { assert!( @@ -5156,10 +5171,12 @@ impl<'db> TypeInferenceBuilder<'db> { "Expected the symbol table to create a symbol for every Name node" ); Symbol::Unbound - } + }; + (symbol, None) } else { let use_id = name_node.scoped_use_id(db, scope); - symbol_from_bindings(db, use_def.bindings_at_use(use_id)) + let symbol = symbol_from_bindings(db, use_def.bindings_at_use(use_id)); + (symbol, Some(use_id)) }; let symbol = SymbolAndQualifiers::from(local_scope_symbol).or_fall_back_to(db, || { @@ -5187,6 +5204,10 @@ impl<'db> TypeInferenceBuilder<'db> { return Symbol::Unbound.into(); } + if let Some(use_id) = use_id { + constraint_keys.push((file_scope_id, ConstraintKey::UseId(use_id))); + } + let current_file = self.file(); // Walk up parent scopes looking for a possible enclosing scope that may have a @@ -5200,14 +5221,12 @@ impl<'db> TypeInferenceBuilder<'db> { // There is one exception to this rule: type parameter scopes can see // names defined in an immediately-enclosing class scope. let enclosing_scope_id = enclosing_scope_file_id.to_scope_id(db, current_file); + let is_immediately_enclosing_scope = scope.is_type_parameter(db) && scope .scope(db) .parent() .is_some_and(|parent| parent == enclosing_scope_file_id); - if !enclosing_scope_id.is_function_like(db) && !is_immediately_enclosing_scope { - continue; - } // If the reference is in a nested eager scope, we need to look for the symbol at // the point where the previous enclosing scope was defined, instead of at the end @@ -5216,23 +5235,42 @@ impl<'db> TypeInferenceBuilder<'db> { // enclosing scopes that actually contain bindings that we should use when // resolving the reference.) if !self.is_deferred() { - match self.index.eager_bindings( + match self.index.eager_snapshot( enclosing_scope_file_id, symbol_name, file_scope_id, ) { - EagerBindingsResult::Found(bindings) => { - return symbol_from_bindings(db, bindings).into(); + EagerSnapshotResult::FoundConstraint(constraint) => { + constraint_keys.push(( + enclosing_scope_file_id, + ConstraintKey::NarrowingConstraint(constraint), + )); } - // There are no visible bindings here. + EagerSnapshotResult::FoundBindings(bindings) => { + if !enclosing_scope_id.is_function_like(db) + && !is_immediately_enclosing_scope + { + continue; + } + return symbol_from_bindings(db, bindings) + .map_type(|ty| { + narrow_with_applicable_constraints(ty, &constraint_keys) + }) + .into(); + } + // There are no visible bindings / constraint here. // Don't fall back to non-eager symbol resolution. - EagerBindingsResult::NotFound => { + EagerSnapshotResult::NotFound => { continue; } - EagerBindingsResult::NoLongerInEagerContext => {} + EagerSnapshotResult::NoLongerInEagerContext => {} } } + if !enclosing_scope_id.is_function_like(db) && !is_immediately_enclosing_scope { + continue; + } + let enclosing_symbol_table = self.index.symbol_table(enclosing_scope_file_id); let Some(enclosing_symbol) = enclosing_symbol_table.symbol_by_name(symbol_name) else { @@ -5244,7 +5282,8 @@ impl<'db> TypeInferenceBuilder<'db> { // runtime, it is the scope that creates the cell for our closure.) If the name // isn't bound in that scope, we should get an unbound name, not continue // falling back to other scopes / globals / builtins. - return symbol(db, enclosing_scope_id, symbol_name); + return symbol(db, enclosing_scope_id, symbol_name) + .map_type(|ty| narrow_with_applicable_constraints(ty, &constraint_keys)); } } @@ -5257,28 +5296,42 @@ impl<'db> TypeInferenceBuilder<'db> { } if !self.is_deferred() { - match self.index.eager_bindings( + match self.index.eager_snapshot( FileScopeId::global(), symbol_name, file_scope_id, ) { - EagerBindingsResult::Found(bindings) => { - return symbol_from_bindings(db, bindings).into(); + EagerSnapshotResult::FoundConstraint(constraint) => { + constraint_keys.push(( + FileScopeId::global(), + ConstraintKey::NarrowingConstraint(constraint), + )); } - // There are no visible bindings here. - EagerBindingsResult::NotFound => { + EagerSnapshotResult::FoundBindings(bindings) => { + return symbol_from_bindings(db, bindings) + .map_type(|ty| { + narrow_with_applicable_constraints(ty, &constraint_keys) + }) + .into(); + } + // There are no visible bindings / constraint here. + EagerSnapshotResult::NotFound => { return Symbol::Unbound.into(); } - EagerBindingsResult::NoLongerInEagerContext => {} + EagerSnapshotResult::NoLongerInEagerContext => {} } } explicit_global_symbol(db, self.file(), symbol_name) + .map_type(|ty| narrow_with_applicable_constraints(ty, &constraint_keys)) }) // Not found in the module's explicitly declared global symbols? // Check the "implicit globals" such as `__doc__`, `__file__`, `__name__`, etc. // These are looked up as attributes on `types.ModuleType`. - .or_fall_back_to(db, || module_type_implicit_global_symbol(db, symbol_name)) + .or_fall_back_to(db, || { + module_type_implicit_global_symbol(db, symbol_name) + .map_type(|ty| narrow_with_applicable_constraints(ty, &constraint_keys)) + }) // Not found in globals? Fallback to builtins // (without infinite recursion if we're already in builtins.) .or_fall_back_to(db, || { diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index 71ed52f665..b5a96c1a70 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -1,5 +1,4 @@ use crate::semantic_index::ast_ids::HasScopedExpressionId; -use crate::semantic_index::definition::Definition; use crate::semantic_index::expression::Expression; use crate::semantic_index::predicate::{ PatternPredicate, PatternPredicateKind, Predicate, PredicateNode, @@ -21,7 +20,7 @@ use std::sync::Arc; use super::UnionType; -/// Return the type constraint that `test` (if true) would place on `definition`, if any. +/// Return the type constraint that `test` (if true) would place on `symbol`, if any. /// /// For example, if we have this code: /// @@ -35,12 +34,12 @@ use super::UnionType; /// The `test` expression `x is not None` places the constraint "not None" on the definition of /// `x`, so in that case we'd return `Some(Type::Intersection(negative=[Type::None]))`. /// -/// But if we called this with the same `test` expression, but the `definition` of `y`, no -/// constraint is applied to that definition, so we'd just return `None`. +/// But if we called this with the same `test` expression, but the `symbol` of `y`, no +/// constraint is applied to that symbol, so we'd just return `None`. pub(crate) fn infer_narrowing_constraint<'db>( db: &'db dyn Db, predicate: Predicate<'db>, - definition: Definition<'db>, + symbol: ScopedSymbolId, ) -> Option> { let constraints = match predicate.node { PredicateNode::Expression(expression) => { @@ -60,7 +59,7 @@ pub(crate) fn infer_narrowing_constraint<'db>( PredicateNode::StarImportPlaceholder(_) => return None, }; if let Some(constraints) = constraints { - constraints.get(&definition.symbol(db)).copied() + constraints.get(&symbol).copied() } else { None }