diff --git a/crates/ty_python_semantic/resources/mdtest/import/star.md b/crates/ty_python_semantic/resources/mdtest/import/star.md index 54d2050259..14cff45efc 100644 --- a/crates/ty_python_semantic/resources/mdtest/import/star.md +++ b/crates/ty_python_semantic/resources/mdtest/import/star.md @@ -1336,6 +1336,69 @@ reveal_type(g) # revealed: Unknown reveal_type(h) # revealed: Unknown ``` +## Star-imports can affect member states + +If a star-import pulls in a symbol that was previously defined in the importing module (e.g. `obj`), +it can affect the state of associated member expressions (e.g. `obj.attr` or `obj[0]`). In the test +below, note how the types of the corresponding attribute expressions change after the star import +affects the object: + +`common.py`: + +```py +class C: + attr: int | None +``` + +`exporter.py`: + +```py +from common import C + +def flag() -> bool: + return True + +should_be_imported: C = C() + +if flag(): + might_be_imported: C = C() + +if False: + should_not_be_imported: C = C() +``` + +`main.py`: + +```py +from common import C + +should_be_imported = C() +might_be_imported = C() +should_not_be_imported = C() + +# We start with the plain attribute types: +reveal_type(should_be_imported.attr) # revealed: int | None +reveal_type(might_be_imported.attr) # revealed: int | None +reveal_type(should_not_be_imported.attr) # revealed: int | None + +# Now we narrow the types by assignment: +should_be_imported.attr = 1 +might_be_imported.attr = 1 +should_not_be_imported.attr = 1 + +reveal_type(should_be_imported.attr) # revealed: Literal[1] +reveal_type(might_be_imported.attr) # revealed: Literal[1] +reveal_type(should_not_be_imported.attr) # revealed: Literal[1] + +# This star import adds bindings for `should_be_imported` and `might_be_imported`: +from exporter import * + +# As expected, narrowing is "reset" for the first two variables, but not for the third: +reveal_type(should_be_imported.attr) # revealed: int | None +reveal_type(might_be_imported.attr) # revealed: int | None +reveal_type(should_not_be_imported.attr) # revealed: Literal[1] +``` + ## Cyclic star imports Believe it or not, this code does *not* raise an exception at runtime! diff --git a/crates/ty_python_semantic/src/semantic_index/builder.rs b/crates/ty_python_semantic/src/semantic_index/builder.rs index 66a0f6f428..9ba2069784 100644 --- a/crates/ty_python_semantic/src/semantic_index/builder.rs +++ b/crates/ty_python_semantic/src/semantic_index/builder.rs @@ -1616,9 +1616,12 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { let star_import_predicate = self.add_predicate(star_import.into()); + let associated_member_ids = self.place_tables[self.current_scope()] + .associated_place_ids(ScopedPlaceId::Symbol(symbol_id)); let pre_definition = self .current_use_def_map() - .single_symbol_place_snapshot(symbol_id); + .single_symbol_snapshot(symbol_id, associated_member_ids); + let pre_definition_reachability = self.current_use_def_map().reachability; diff --git a/crates/ty_python_semantic/src/semantic_index/use_def.rs b/crates/ty_python_semantic/src/semantic_index/use_def.rs index 05fa369521..a7c7520806 100644 --- a/crates/ty_python_semantic/src/semantic_index/use_def.rs +++ b/crates/ty_python_semantic/src/semantic_index/use_def.rs @@ -801,6 +801,13 @@ pub(super) struct FlowSnapshot { reachability: ScopedReachabilityConstraintId, } +/// A snapshot of the state of a single symbol (e.g. `obj`) and all of its associated members +/// (e.g. `obj.attr`, `obj["key"]`). +pub(super) struct SingleSymbolSnapshot { + symbol_state: PlaceState, + associated_member_states: FxHashMap, +} + #[derive(Debug)] pub(super) struct UseDefMapBuilder<'db> { /// Append-only array of [`DefinitionState`]. @@ -991,13 +998,26 @@ impl<'db> UseDefMapBuilder<'db> { } } - /// Snapshot the state of a single place at the current point in control flow. + /// Snapshot the state of a single symbol and all of its associated members, at the current + /// point in control flow. /// /// This is only used for `*`-import reachability constraints, which are handled differently /// to most other reachability constraints. See the doc-comment for /// [`Self::record_and_negate_star_import_reachability_constraint`] for more details. - pub(super) fn single_symbol_place_snapshot(&self, symbol: ScopedSymbolId) -> PlaceState { - self.symbol_states[symbol].clone() + pub(super) fn single_symbol_snapshot( + &self, + symbol: ScopedSymbolId, + associated_member_ids: &[ScopedMemberId], + ) -> SingleSymbolSnapshot { + let symbol_state = self.symbol_states[symbol].clone(); + let mut associated_member_states = FxHashMap::default(); + for &member_id in associated_member_ids { + associated_member_states.insert(member_id, self.member_states[member_id].clone()); + } + SingleSymbolSnapshot { + symbol_state, + associated_member_states, + } } /// This method exists solely for handling `*`-import reachability constraints. @@ -1033,14 +1053,14 @@ impl<'db> UseDefMapBuilder<'db> { &mut self, reachability_id: ScopedReachabilityConstraintId, symbol: ScopedSymbolId, - pre_definition_state: PlaceState, + pre_definition: SingleSymbolSnapshot, ) { let negated_reachability_id = self .reachability_constraints .add_not_constraint(reachability_id); let mut post_definition_state = - std::mem::replace(&mut self.symbol_states[symbol], pre_definition_state); + std::mem::replace(&mut self.symbol_states[symbol], pre_definition.symbol_state); post_definition_state .record_reachability_constraint(&mut self.reachability_constraints, reachability_id); @@ -1055,6 +1075,30 @@ impl<'db> UseDefMapBuilder<'db> { &mut self.narrowing_constraints, &mut self.reachability_constraints, ); + + // And similarly for all associated members: + for (member_id, pre_definition_member_state) in pre_definition.associated_member_states { + let mut post_definition_state = std::mem::replace( + &mut self.member_states[member_id], + pre_definition_member_state, + ); + + post_definition_state.record_reachability_constraint( + &mut self.reachability_constraints, + reachability_id, + ); + + self.member_states[member_id].record_reachability_constraint( + &mut self.reachability_constraints, + negated_reachability_id, + ); + + self.member_states[member_id].merge( + post_definition_state, + &mut self.narrowing_constraints, + &mut self.reachability_constraints, + ); + } } pub(super) fn record_reachability_constraint(