From 73399029b200f779236bdcf2a6f491bc521bf6b0 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Wed, 9 Apr 2025 17:53:26 +0100 Subject: [PATCH] [red-knot] Optimise visibility constraints for `*`-import definitions (#17317) --- .../src/semantic_index/builder.rs | 75 +++++++++++++------ .../src/semantic_index/use_def.rs | 53 ++++++++++++- 2 files changed, 103 insertions(+), 25 deletions(-) diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index 2eba5016af..34e2755f25 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -331,12 +331,15 @@ impl<'db> SemanticIndexBuilder<'db> { self.current_use_def_map_mut().merge(state); } - fn add_symbol(&mut self, name: Name) -> ScopedSymbolId { + /// Return a 2-element tuple, where the first element is the [`ScopedSymbolId`] of the + /// symbol added, and the second element is a boolean indicating whether the symbol was *newly* + /// added or not + fn add_symbol(&mut self, name: Name) -> (ScopedSymbolId, bool) { let (symbol_id, added) = self.current_symbol_table().add_symbol(name); if added { self.current_use_def_map_mut().add_symbol(symbol_id); } - symbol_id + (symbol_id, added) } fn mark_symbol_bound(&mut self, id: ScopedSymbolId) { @@ -516,6 +519,7 @@ impl<'db> SemanticIndexBuilder<'db> { } /// Records a visibility constraint by applying it to all live bindings and declarations. + #[must_use = "A visibility constraint must always be negated after it is added"] fn record_visibility_constraint( &mut self, predicate: Predicate<'db>, @@ -747,7 +751,7 @@ impl<'db> SemanticIndexBuilder<'db> { .. }) => (name, &None, default), }; - let symbol = self.add_symbol(name.id.clone()); + let (symbol, _) = self.add_symbol(name.id.clone()); // TODO create Definition for PEP 695 typevars // note that the "bound" on the typevar is a totally different thing than whether // or not a name is "bound" by a typevar declaration; the latter is always true. @@ -841,20 +845,20 @@ impl<'db> SemanticIndexBuilder<'db> { self.declare_parameter(parameter); } if let Some(vararg) = parameters.vararg.as_ref() { - let symbol = self.add_symbol(vararg.name.id().clone()); + let (symbol, _) = self.add_symbol(vararg.name.id().clone()); self.add_definition( symbol, DefinitionNodeRef::VariadicPositionalParameter(vararg), ); } if let Some(kwarg) = parameters.kwarg.as_ref() { - let symbol = self.add_symbol(kwarg.name.id().clone()); + let (symbol, _) = self.add_symbol(kwarg.name.id().clone()); self.add_definition(symbol, DefinitionNodeRef::VariadicKeywordParameter(kwarg)); } } fn declare_parameter(&mut self, parameter: &'db ast::ParameterWithDefault) { - let symbol = self.add_symbol(parameter.name().id().clone()); + let (symbol, _) = self.add_symbol(parameter.name().id().clone()); let definition = self.add_definition(symbol, parameter); @@ -1071,7 +1075,7 @@ where // The symbol for the function name itself has to be evaluated // at the end to match the runtime evaluation of parameter defaults // and return-type annotations. - let symbol = self.add_symbol(name.id.clone()); + let (symbol, _) = self.add_symbol(name.id.clone()); self.add_definition(symbol, function_def); } ast::Stmt::ClassDef(class) => { @@ -1095,11 +1099,11 @@ where ); // In Python runtime semantics, a class is registered after its scope is evaluated. - let symbol = self.add_symbol(class.name.id.clone()); + let (symbol, _) = self.add_symbol(class.name.id.clone()); self.add_definition(symbol, class); } ast::Stmt::TypeAlias(type_alias) => { - let symbol = self.add_symbol( + let (symbol, _) = self.add_symbol( type_alias .name .as_name_expr() @@ -1133,7 +1137,7 @@ where (Name::new(alias.name.id.split('.').next().unwrap()), false) }; - let symbol = self.add_symbol(symbol_name); + let (symbol, _) = self.add_symbol(symbol_name); self.add_definition( symbol, ImportDefinitionNodeRef { @@ -1200,7 +1204,7 @@ where // // For more details, see the doc-comment on `StarImportPlaceholderPredicate`. for export in exported_names(self.db, referenced_module) { - let symbol_id = self.add_symbol(export.clone()); + let (symbol_id, newly_added) = self.add_symbol(export.clone()); let node_ref = StarImportDefinitionNodeRef { node, symbol_id }; let star_import = StarImportPlaceholderPredicate::new( self.db, @@ -1210,13 +1214,38 @@ where ); let pre_definition = self.flow_snapshot(); self.push_additional_definition(symbol_id, node_ref); - let constraint_id = - self.record_visibility_constraint(star_import.into()); - let post_definition = self.flow_snapshot(); - self.flow_restore(pre_definition.clone()); - self.record_negated_visibility_constraint(constraint_id); - self.flow_merge(post_definition); - self.simplify_visibility_constraints(pre_definition); + + // Fast path for if there were no previous definitions + // of the symbol defined through the `*` import: + // we can apply the visibility constraint to *only* the added definition, + // rather than all definitions + if newly_added { + let constraint_id = self + .current_use_def_map_mut() + .record_star_import_visibility_constraint( + star_import, + symbol_id, + ); + + let post_definition = self.flow_snapshot(); + self.flow_restore(pre_definition); + + self.current_use_def_map_mut() + .negate_star_import_visibility_constraint( + symbol_id, + constraint_id, + ); + + self.flow_merge(post_definition); + } else { + let constraint_id = + self.record_visibility_constraint(star_import.into()); + let post_definition = self.flow_snapshot(); + self.flow_restore(pre_definition.clone()); + self.record_negated_visibility_constraint(constraint_id); + self.flow_merge(post_definition); + self.simplify_visibility_constraints(pre_definition); + } } continue; @@ -1236,7 +1265,7 @@ where self.has_future_annotations |= alias.name.id == "annotations" && node.module.as_deref() == Some("__future__"); - let symbol = self.add_symbol(symbol_name.clone()); + let (symbol, _) = self.add_symbol(symbol_name.clone()); self.add_definition( symbol, @@ -1636,7 +1665,7 @@ where // which is invalid syntax. However, it's still pretty obvious here that the user // *wanted* `e` to be bound, so we should still create a definition here nonetheless. if let Some(symbol_name) = symbol_name { - let symbol = self.add_symbol(symbol_name.id.clone()); + let (symbol, _) = self.add_symbol(symbol_name.id.clone()); self.add_definition( symbol, @@ -1721,7 +1750,7 @@ where (ast::ExprContext::Del, _) => (false, true), (ast::ExprContext::Invalid, _) => (false, false), }; - let symbol = self.add_symbol(id.clone()); + let (symbol, _) = self.add_symbol(id.clone()); if is_use { self.mark_symbol_used(symbol); @@ -2007,7 +2036,7 @@ where range: _, }) = pattern { - let symbol = self.add_symbol(name.id().clone()); + let (symbol, _) = self.add_symbol(name.id().clone()); let state = self.current_match_case.as_ref().unwrap(); self.add_definition( symbol, @@ -2028,7 +2057,7 @@ where rest: Some(name), .. }) = pattern { - let symbol = self.add_symbol(name.id().clone()); + let (symbol, _) = self.add_symbol(name.id().clone()); let state = self.current_match_case.as_ref().unwrap(); self.add_definition( symbol, diff --git a/crates/red_knot_python_semantic/src/semantic_index/use_def.rs b/crates/red_knot_python_semantic/src/semantic_index/use_def.rs index 147367473d..057fb316c6 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/use_def.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/use_def.rs @@ -269,7 +269,7 @@ use crate::semantic_index::narrowing_constraints::{ NarrowingConstraints, NarrowingConstraintsBuilder, NarrowingConstraintsIterator, }; use crate::semantic_index::predicate::{ - Predicate, Predicates, PredicatesBuilder, ScopedPredicateId, + Predicate, Predicates, PredicatesBuilder, ScopedPredicateId, StarImportPlaceholderPredicate, }; use crate::semantic_index::symbol::{FileScopeId, ScopedSymbolId}; use crate::semantic_index::visibility_constraints::{ @@ -603,7 +603,7 @@ pub(super) struct UseDefMapBuilder<'db> { /// x # we store a reachability constraint of [test] for this use of `x` /// /// y = 2 - /// + /// /// # we record a visibility constraint of [test] here, which retroactively affects /// # the `y = 1` and the `y = 2` binding. /// else: @@ -701,6 +701,34 @@ impl<'db> UseDefMapBuilder<'db> { .add_and_constraint(self.scope_start_visibility, constraint); } + #[must_use = "A `*`-import visibility constraint must always be negated after it is added"] + pub(super) fn record_star_import_visibility_constraint( + &mut self, + star_import: StarImportPlaceholderPredicate<'db>, + symbol: ScopedSymbolId, + ) -> StarImportVisibilityConstraintId { + let predicate_id = self.add_predicate(star_import.into()); + let visibility_id = self.visibility_constraints.add_atom(predicate_id); + self.symbol_states[symbol] + .record_visibility_constraint(&mut self.visibility_constraints, visibility_id); + StarImportVisibilityConstraintId(visibility_id) + } + + pub(super) fn negate_star_import_visibility_constraint( + &mut self, + symbol_id: ScopedSymbolId, + constraint: StarImportVisibilityConstraintId, + ) { + let negated_constraint = self + .visibility_constraints + .add_not_constraint(constraint.into_scoped_constraint_id()); + self.symbol_states[symbol_id] + .record_visibility_constraint(&mut self.visibility_constraints, negated_constraint); + self.scope_start_visibility = self + .visibility_constraints + .add_and_constraint(self.scope_start_visibility, negated_constraint); + } + /// This method resets the visibility constraints for all symbols to a previous state /// *if* there have been no new declarations or bindings since then. Consider the /// following example: @@ -900,3 +928,24 @@ impl<'db> UseDefMapBuilder<'db> { } } } + +/// Newtype wrapper over [`ScopedVisibilityConstraintId`] to improve type safety. +/// +/// By returning this type from [`UseDefMapBuilder::record_star_import_visibility_constraint`] +/// rather than [`ScopedVisibilityConstraintId`] directly, we ensure that +/// [`UseDefMapBuilder::negate_star_import_visibility_constraint`] must be called after the +/// visibility constraint has been added, and we ensure that +/// [`super::SemanticIndexBuilder::record_negated_visibility_constraint`] *cannot* be called with +/// the narrowing constraint (which would lead to incorrect behaviour). +/// +/// This type is defined here rather than in the [`super::visibility_constraints`] module +/// because it should only ever be constructed and deconstructed from methods in the +/// [`UseDefMapBuilder`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(super) struct StarImportVisibilityConstraintId(ScopedVisibilityConstraintId); + +impl StarImportVisibilityConstraintId { + fn into_scoped_constraint_id(self) -> ScopedVisibilityConstraintId { + self.0 + } +}