diff --git a/crates/ty_python_semantic/src/semantic_index.rs b/crates/ty_python_semantic/src/semantic_index.rs index 440fc1b8a4..b99b1cfb56 100644 --- a/crates/ty_python_semantic/src/semantic_index.rs +++ b/crates/ty_python_semantic/src/semantic_index.rs @@ -193,6 +193,12 @@ pub(crate) enum EagerSnapshotResult<'map, 'db> { NoLongerInEagerContext, } +#[derive(Debug, Update, get_size2::GetSize)] +pub(crate) enum NotLocalVariableKind { + Nonlocal, + Global, +} + /// The place tables and use-def maps for all scopes in a file. #[derive(Debug, Update, get_size2::GetSize)] pub(crate) struct SemanticIndex<'db> { @@ -217,11 +223,8 @@ pub(crate) struct SemanticIndex<'db> { /// Map from the file-local [`FileScopeId`] to the salsa-ingredient [`ScopeId`]. scope_ids_by_scope: IndexVec>, - /// Map from the file-local [`FileScopeId`] to the set of explicit-global symbols it contains. - globals_by_scope: FxHashMap>, - /// Map from the file-local [`FileScopeId`] to the set of explicit-nonlocal symbols it contains. - nonlocals_by_scope: FxHashMap>, + not_locals_by_scope: FxHashMap>, /// Use-def map for each scope in this file. use_def_maps: IndexVec>, @@ -311,9 +314,10 @@ impl<'db> SemanticIndex<'db> { symbol: ScopedPlaceId, scope: FileScopeId, ) -> bool { - self.globals_by_scope - .get(&scope) - .is_some_and(|globals| globals.contains(&symbol)) + let Some(scope) = self.not_locals_by_scope.get(&scope) else { + return false; + }; + matches!(scope.get(&symbol), Some(NotLocalVariableKind::Global)) } pub(crate) fn symbol_is_nonlocal_in_scope( @@ -321,9 +325,10 @@ impl<'db> SemanticIndex<'db> { symbol: ScopedPlaceId, scope: FileScopeId, ) -> bool { - self.nonlocals_by_scope - .get(&scope) - .is_some_and(|nonlocals| nonlocals.contains(&symbol)) + let Some(scope) = self.not_locals_by_scope.get(&scope) else { + return false; + }; + matches!(scope.get(&symbol), Some(NotLocalVariableKind::Nonlocal)) } /// Returns the id of the parent scope. diff --git a/crates/ty_python_semantic/src/semantic_index/builder.rs b/crates/ty_python_semantic/src/semantic_index/builder.rs index 9d838d2177..e7de02c6a3 100644 --- a/crates/ty_python_semantic/src/semantic_index/builder.rs +++ b/crates/ty_python_semantic/src/semantic_index/builder.rs @@ -45,7 +45,7 @@ use crate::semantic_index::reachability_constraints::{ use crate::semantic_index::use_def::{ EagerSnapshotKey, FlowSnapshot, ScopedEagerSnapshotId, UseDefMapBuilder, }; -use crate::semantic_index::{ArcUseDefMap, SemanticIndex}; +use crate::semantic_index::{ArcUseDefMap, NotLocalVariableKind, SemanticIndex}; use crate::unpack::{Unpack, UnpackKind, UnpackPosition, UnpackValue}; use crate::{Db, Program}; @@ -103,8 +103,7 @@ pub(super) struct SemanticIndexBuilder<'db, 'ast> { use_def_maps: IndexVec>, scopes_by_node: FxHashMap, scopes_by_expression: FxHashMap, - globals_by_scope: FxHashMap>, - nonlocals_by_scope: FxHashMap>, + not_locals_by_scope: FxHashMap>, definitions_by_node: FxHashMap>, expressions_by_node: FxHashMap>, imported_modules: FxHashSet, @@ -142,8 +141,7 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> { scopes_by_node: FxHashMap::default(), definitions_by_node: FxHashMap::default(), expressions_by_node: FxHashMap::default(), - globals_by_scope: FxHashMap::default(), - nonlocals_by_scope: FxHashMap::default(), + not_locals_by_scope: FxHashMap::default(), imported_modules: FxHashSet::default(), generator_functions: FxHashSet::default(), @@ -1048,8 +1046,7 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> { self.scopes_by_node.shrink_to_fit(); self.generator_functions.shrink_to_fit(); self.eager_snapshots.shrink_to_fit(); - self.globals_by_scope.shrink_to_fit(); - self.nonlocals_by_scope.shrink_to_fit(); + self.not_locals_by_scope.shrink_to_fit(); SemanticIndex { place_tables, @@ -1057,8 +1054,7 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> { definitions_by_node: self.definitions_by_node, expressions_by_node: self.expressions_by_node, scope_ids_by_scope: self.scope_ids_by_scope, - globals_by_scope: self.globals_by_scope, - nonlocals_by_scope: self.nonlocals_by_scope, + not_locals_by_scope: self.not_locals_by_scope, ast_ids, scopes_by_expression: self.scopes_by_expression, scopes_by_node: self.scopes_by_node, @@ -1429,25 +1425,22 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { if let ast::Expr::Name(name) = &*node.target { let symbol_id = self.add_symbol(name.id.clone()); let scope_id = self.current_scope(); - // Check whether the variable has been declared global. - if let Some(globals) = self.globals_by_scope.get(&scope_id) { - if globals.contains(&symbol_id) { + // Check whether the variable has been declared global or nonlocal. + if let Some(not_locals) = self.not_locals_by_scope.get(&scope_id) { + if let Some(not_local_kind) = not_locals.get(&symbol_id) { self.report_semantic_error(SemanticSyntaxError { - kind: SemanticSyntaxErrorKind::AnnotatedGlobal( - name.id.as_str().into(), - ), - range: name.range, - python_version: self.python_version, - }); - } - } - // Check whether the variable has been declared nonlocal. - if let Some(nonlocals) = self.nonlocals_by_scope.get(&scope_id) { - if nonlocals.contains(&symbol_id) { - self.report_semantic_error(SemanticSyntaxError { - kind: SemanticSyntaxErrorKind::AnnotatedNonlocal( - name.id.as_str().into(), - ), + kind: match not_local_kind { + NotLocalVariableKind::Global => { + SemanticSyntaxErrorKind::AnnotatedGlobal( + name.id.as_str().into(), + ) + } + NotLocalVariableKind::Nonlocal => { + SemanticSyntaxErrorKind::AnnotatedNonlocal( + name.id.as_str().into(), + ) + } + }, range: name.range, python_version: self.python_version, }); @@ -1910,8 +1903,8 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { } let scope_id = self.current_scope(); // Check whether the variable has also been declared nonlocal. - if let Some(nonlocals) = self.nonlocals_by_scope.get(&scope_id) { - if nonlocals.contains(&symbol_id) { + if let Some(not_locals) = self.not_locals_by_scope.get(&scope_id) { + if let Some(NotLocalVariableKind::Nonlocal) = not_locals.get(&symbol_id) { self.report_semantic_error(SemanticSyntaxError { kind: SemanticSyntaxErrorKind::NonlocalAndGlobal(name.to_string()), range: name.range, @@ -1919,10 +1912,10 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { }); } } - self.globals_by_scope + self.not_locals_by_scope .entry(scope_id) .or_default() - .insert(symbol_id); + .insert(symbol_id, NotLocalVariableKind::Global); } walk_stmt(self, stmt); } @@ -1951,8 +1944,8 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { } let scope_id = self.current_scope(); // Check whether the variable has also been declared global. - if let Some(globals) = self.globals_by_scope.get(&scope_id) { - if globals.contains(&symbol_id) { + if let Some(not_locals) = self.not_locals_by_scope.get(&scope_id) { + if let Some(NotLocalVariableKind::Global) = not_locals.get(&symbol_id) { self.report_semantic_error(SemanticSyntaxError { kind: SemanticSyntaxErrorKind::NonlocalAndGlobal(name.to_string()), range: name.range, @@ -1960,10 +1953,10 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { }); } } - self.nonlocals_by_scope + self.not_locals_by_scope .entry(scope_id) .or_default() - .insert(symbol_id); + .insert(symbol_id, NotLocalVariableKind::Nonlocal); } walk_stmt(self, stmt); }