From 3ac58b47bd09106ca7a60ddd58a3e0580de7ef77 Mon Sep 17 00:00:00 2001 From: Andrew Gallant Date: Wed, 10 Dec 2025 15:51:10 -0500 Subject: [PATCH] [ty] Support `__all__ += submodule.__all__` ... and also `__all__.extend(submodule.__all__)`. I originally left out support for this since I was unclear on whether we'd really need it. But it turns out this is used somewhat frequently. For example, in `numpy`. See the comments on the new `Imports` type for how we approach this. --- crates/ty_ide/src/symbols.rs | 615 ++++++++++++++++++++++++++++++++++- 1 file changed, 608 insertions(+), 7 deletions(-) diff --git a/crates/ty_ide/src/symbols.rs b/crates/ty_ide/src/symbols.rs index 08f659debe..7d3f87d516 100644 --- a/crates/ty_ide/src/symbols.rs +++ b/crates/ty_ide/src/symbols.rs @@ -10,10 +10,10 @@ use ruff_db::files::File; use ruff_db::parsed::parsed_module; use ruff_index::{IndexVec, newtype_index}; use ruff_python_ast as ast; -use ruff_python_ast::name::Name; +use ruff_python_ast::name::{Name, UnqualifiedName}; use ruff_python_ast::visitor::source_order::{self, SourceOrderVisitor}; use ruff_text_size::{Ranged, TextRange}; -use rustc_hash::FxHashSet; +use rustc_hash::{FxHashMap, FxHashSet}; use ty_project::Db; use ty_python_semantic::{ModuleName, resolve_module}; @@ -375,7 +375,11 @@ pub(crate) fn symbols_for_file(db: &dyn Db, file: File) -> FlatSymbols { /// While callers can convert this into a hierarchical collection of /// symbols, it won't result in anything meaningful since the flat list /// returned doesn't include children. -#[salsa::tracked(returns(ref), heap_size=ruff_memory_usage::heap_size)] +#[salsa::tracked( + returns(ref), + cycle_initial=symbols_for_file_global_only_cycle_initial, + heap_size=ruff_memory_usage::heap_size, +)] pub(crate) fn symbols_for_file_global_only(db: &dyn Db, file: File) -> FlatSymbols { let parsed = parsed_module(db, file); let module = parsed.load(db); @@ -394,6 +398,14 @@ pub(crate) fn symbols_for_file_global_only(db: &dyn Db, file: File) -> FlatSymbo visitor.into_flat_symbols() } +fn symbols_for_file_global_only_cycle_initial( + _db: &dyn Db, + _id: salsa::Id, + _file: File, +) -> FlatSymbols { + FlatSymbols::default() +} + #[derive(Debug, Clone, PartialEq, Eq, get_size2::GetSize)] struct SymbolTree { parent: Option, @@ -411,6 +423,189 @@ enum ImportKind { Wildcard, } +/// An abstraction for managing module scope imports. +/// +/// This is meant to recognize the following idioms for updating +/// `__all__` in module scope: +/// +/// ```ignore +/// __all__ += submodule.__all__ +/// __all__.extend(submodule.__all__) +/// ``` +/// +/// # Correctness +/// +/// The approach used here is not correct 100% of the time. +/// For example, it is somewhat easy to defeat it: +/// +/// ```ignore +/// from numpy import * +/// from importlib import resources +/// import numpy as np +/// np = resources +/// __all__ = [] +/// __all__ += np.__all__ +/// ``` +/// +/// In this example, `np` will still be resolved to the `numpy` +/// module instead of the `importlib.resources` module. Namely, this +/// abstraction doesn't track all definitions. This would result in a +/// silently incorrect `__all__`. +/// +/// This abstraction does handle the case when submodules are imported. +/// Namely, we do get this case correct: +/// +/// ```ignore +/// from importlib.resources import * +/// from importlib import resources +/// __all__ = [] +/// __all__ += resources.__all__ +/// ``` +/// +/// We do this by treating all imports in a `from ... import ...` +/// statement as *possible* modules. Then when we lookup `resources`, +/// we attempt to resolve it to an actual module. If that fails, then +/// we consider `__all__` invalid. +/// +/// There are likely many many other cases that we don't handle as +/// well, which ty does (it has its own `__all__` parsing using types +/// to deal with this case). We can add handling for those as they +/// come up in real world examples. +/// +/// # Performance +/// +/// This abstraction recognizes that, compared to all possible imports, +/// it is very rare to use one of them to update `__all__`. Therefore, +/// we are careful not to do too much work up-front (like eagerly +/// manifesting `ModuleName` values). +#[derive(Clone, Debug, Default, get_size2::GetSize)] +struct Imports<'db> { + /// A map from the name that a module is available + /// under to its actual module name (and our level + /// of certainty that it ought to be treated as a module). + module_names: FxHashMap<&'db str, ImportModuleKind<'db>>, +} + +impl<'db> Imports<'db> { + /// Track the imports from the given `import ...` statement. + fn add_import(&mut self, import: &'db ast::StmtImport) { + for alias in &import.names { + let asname = alias + .asname + .as_ref() + .map(|ident| &ident.id) + .unwrap_or(&alias.name.id); + let module_name = ImportModuleName::Import(&alias.name.id); + self.module_names + .insert(asname, ImportModuleKind::Definitive(module_name)); + } + } + + /// Track the imports from the given `from ... import ...` statement. + fn add_import_from(&mut self, import_from: &'db ast::StmtImportFrom) { + for alias in &import_from.names { + if &alias.name == "*" { + // FIXME: We'd ideally include the names + // imported from the module, but we don't + // want to do this eagerly. So supporting + // this requires more infrastructure in + // `Imports`. + continue; + } + + let asname = alias + .asname + .as_ref() + .map(|ident| &ident.id) + .unwrap_or(&alias.name.id); + let module_name = ImportModuleName::ImportFrom { + parent: import_from, + child: &alias.name.id, + }; + self.module_names + .insert(asname, ImportModuleKind::Possible(module_name)); + } + } + + /// Return the symbols exported by the module referred to by `name`. + /// + /// e.g., This can be used to resolve `__all__ += submodule.__all__`, + /// where `name` is `submodule`. + fn get_module_symbols( + &self, + db: &'db dyn Db, + importing_file: File, + name: &Name, + ) -> Option<&'db FlatSymbols> { + let module_name = match self.module_names.get(name.as_str())? { + ImportModuleKind::Definitive(name) | ImportModuleKind::Possible(name) => { + name.to_module_name(db, importing_file)? + } + }; + let module = resolve_module(db, importing_file, &module_name)?; + Some(symbols_for_file_global_only(db, module.file(db)?)) + } +} + +/// Describes the level of certainty that an import is a module. +/// +/// For example, `import foo`, then `foo` is definitively a module. +/// But `from quux import foo`, then `quux.foo` is possibly a module. +#[derive(Debug, Clone, Copy, get_size2::GetSize)] +enum ImportModuleKind<'db> { + Definitive(ImportModuleName<'db>), + Possible(ImportModuleName<'db>), +} + +/// A representation of something that can be turned into a +/// `ModuleName`. +/// +/// We don't do this eagerly, and instead represent the constituent +/// pieces, in order to avoid the work needed to build a `ModuleName`. +/// In particular, it is somewhat rare for the visitor to need +/// to access the imports found in a module. At time of writing +/// (2025-12-10), this only happens when referencing a submodule +/// to augment an `__all__` definition. For example, as found in +/// `matplotlib`: +/// +/// ```ignore +/// import numpy as np +/// __all__ = ['rand', 'randn', 'repmat'] +/// __all__ += np.__all__ +/// ``` +/// +/// This construct is somewhat rare and it would be sad to allocate a +/// `ModuleName` for every imported item unnecessarily. +#[derive(Debug, Clone, Copy, get_size2::GetSize)] +enum ImportModuleName<'db> { + /// The `foo` in `import quux, foo as blah, baz`. + Import(&'db Name), + /// A possible module in a `from ... import ...` statement. + ImportFrom { + /// The `..foo` in `from ..foo import quux`. + parent: &'db ast::StmtImportFrom, + /// The `foo` in `from quux import foo`. + child: &'db Name, + }, +} + +impl<'db> ImportModuleName<'db> { + /// Converts the lazy representation of a module name into an + /// actual `ModuleName` that can be used for module resolution. + fn to_module_name(self, db: &'db dyn Db, importing_file: File) -> Option { + match self { + ImportModuleName::Import(name) => ModuleName::new(name), + ImportModuleName::ImportFrom { parent, child } => { + let mut module_name = + ModuleName::from_import_statement(db, importing_file, parent).ok()?; + let child_module_name = ModuleName::new(child)?; + module_name.extend(&child_module_name); + Some(module_name) + } + } + } +} + /// A visitor over all symbols in a single file. /// /// This guarantees that child symbols have a symbol ID greater @@ -444,6 +639,11 @@ struct SymbolVisitor<'db> { /// `__all__` idioms or there are any invalid elements in /// `__all__`. all_invalid: bool, + /// A collection of imports found while visiting the AST. + /// + /// These are used to help resolve references to modules + /// in some limited cases. + imports: Imports<'db>, } impl<'db> SymbolVisitor<'db> { @@ -459,6 +659,7 @@ impl<'db> SymbolVisitor<'db> { all_origin: None, all_names: FxHashSet::default(), all_invalid: false, + imports: Imports::default(), } } @@ -483,12 +684,28 @@ impl<'db> SymbolVisitor<'db> { // their position in a sequence. So when we filter some // out, we need to remap the identifiers. // - // N.B. The remapping could be skipped when `global_only` is + // We also want to deduplicate when `exports_only` is + // `true`. In particular, dealing with `__all__` can + // result in cycles, and we need to make sure our output + // is stable for that reason. + // + // N.B. The remapping could be skipped when `exports_only` is // true, since in that case, none of the symbols have a parent // ID by construction. let mut remap = IndexVec::with_capacity(self.symbols.len()); + let mut seen = self.exports_only.then(FxHashSet::default); let mut new = IndexVec::with_capacity(self.symbols.len()); for mut symbol in std::mem::take(&mut self.symbols) { + // If we're deduplicating and we've already seen + // this symbol, then skip it. + // + // FIXME: We should do this without copying every + // symbol name. ---AG + if let Some(ref mut seen) = seen { + if !seen.insert(symbol.name.clone()) { + continue; + } + } if !self.is_part_of_library_interface(&symbol) { remap.push(None); continue; @@ -519,7 +736,7 @@ impl<'db> SymbolVisitor<'db> { } } - fn visit_body(&mut self, body: &[ast::Stmt]) { + fn visit_body(&mut self, body: &'db [ast::Stmt]) { for stmt in body { self.visit_stmt(stmt); } @@ -649,6 +866,31 @@ impl<'db> SymbolVisitor<'db> { ast::Expr::List(ast::ExprList { elts, .. }) | ast::Expr::Tuple(ast::ExprTuple { elts, .. }) | ast::Expr::Set(ast::ExprSet { elts, .. }) => self.add_all_names(elts), + // `__all__ += module.__all__` + // `__all__.extend(module.__all__)` + ast::Expr::Attribute(ast::ExprAttribute { .. }) => { + let Some(unqualified) = UnqualifiedName::from_expr(expr) else { + return false; + }; + let Some((&attr, rest)) = unqualified.segments().split_last() else { + return false; + }; + if attr != "__all__" { + return false; + } + let possible_module_name = Name::new(rest.join(".")); + let Some(symbols) = + self.imports + .get_module_symbols(self.db, self.file, &possible_module_name) + else { + return false; + }; + let Some(ref all) = symbols.all_names else { + return false; + }; + self.all_names.extend(all.iter().cloned()); + true + } _ => false, } } @@ -850,8 +1092,8 @@ impl<'db> SymbolVisitor<'db> { } } -impl SourceOrderVisitor<'_> for SymbolVisitor<'_> { - fn visit_stmt(&mut self, stmt: &ast::Stmt) { +impl<'db> SourceOrderVisitor<'db> for SymbolVisitor<'db> { + fn visit_stmt(&mut self, stmt: &'db ast::Stmt) { match stmt { ast::Stmt::FunctionDef(func_def) => { let kind = if self @@ -1023,6 +1265,7 @@ impl SourceOrderVisitor<'_> for SymbolVisitor<'_> { if self.in_function { return; } + self.imports.add_import(import); for alias in &import.names { self.add_import_alias(stmt, alias); } @@ -1038,6 +1281,7 @@ impl SourceOrderVisitor<'_> for SymbolVisitor<'_> { if self.in_function { return; } + self.imports.add_import_from(import_from); for alias in &import_from.names { if &alias.name == "*" { self.add_exported_from_wildcard(import_from); @@ -2010,6 +2254,363 @@ class X: ); } + #[test] + fn reexport_and_extend_from_submodule_import_statement_plus_equals() { + let test = PublicTestBuilder::default() + .source( + "foo.py", + " + _ZQZQZQ = 1 + __all__ = ['_ZQZQZQ'] + ", + ) + .source( + "test.py", + "import foo + from foo import * + _ZYZYZY = 1 + __all__ = ['_ZYZYZY'] + __all__ += foo.__all__ + ", + ) + .build(); + insta::assert_snapshot!( + test.exports_for("test.py"), + @r" + _ZQZQZQ :: Constant + _ZYZYZY :: Constant + ", + ); + } + + #[test] + fn reexport_and_extend_from_submodule_import_statement_extend() { + let test = PublicTestBuilder::default() + .source( + "foo.py", + " + _ZQZQZQ = 1 + __all__ = ['_ZQZQZQ'] + ", + ) + .source( + "test.py", + "import foo + from foo import * + _ZYZYZY = 1 + __all__ = ['_ZYZYZY'] + __all__.extend(foo.__all__) + ", + ) + .build(); + insta::assert_snapshot!( + test.exports_for("test.py"), + @r" + _ZQZQZQ :: Constant + _ZYZYZY :: Constant + ", + ); + } + + #[test] + fn reexport_and_extend_from_submodule_import_statement_alias() { + let test = PublicTestBuilder::default() + .source( + "foo.py", + " + _ZQZQZQ = 1 + __all__ = ['_ZQZQZQ'] + ", + ) + .source( + "test.py", + "import foo as blah + from foo import * + _ZYZYZY = 1 + __all__ = ['_ZYZYZY'] + __all__ += blah.__all__ + ", + ) + .build(); + insta::assert_snapshot!( + test.exports_for("test.py"), + @r" + _ZQZQZQ :: Constant + _ZYZYZY :: Constant + ", + ); + } + + #[test] + fn reexport_and_extend_from_submodule_import_statement_nested_alias() { + let test = PublicTestBuilder::default() + .source("parent/__init__.py", "") + .source( + "parent/foo.py", + " + _ZQZQZQ = 1 + __all__ = ['_ZQZQZQ'] + ", + ) + .source( + "test.py", + "import parent.foo as blah + from parent.foo import * + _ZYZYZY = 1 + __all__ = ['_ZYZYZY'] + __all__ += blah.__all__ + ", + ) + .build(); + insta::assert_snapshot!( + test.exports_for("test.py"), + @r" + _ZQZQZQ :: Constant + _ZYZYZY :: Constant + ", + ); + } + + #[test] + fn reexport_and_extend_from_submodule_import_from_statement_plus_equals() { + let test = PublicTestBuilder::default() + .source("parent/__init__.py", "") + .source( + "parent/foo.py", + " + _ZQZQZQ = 1 + __all__ = ['_ZQZQZQ'] + ", + ) + .source( + "test.py", + "from parent import foo + from parent.foo import * + _ZYZYZY = 1 + __all__ = ['_ZYZYZY'] + __all__ += foo.__all__ + ", + ) + .build(); + insta::assert_snapshot!( + test.exports_for("test.py"), + @r" + _ZQZQZQ :: Constant + _ZYZYZY :: Constant + ", + ); + } + + #[test] + fn reexport_and_extend_from_submodule_import_from_statement_nested_module_reference() { + let test = PublicTestBuilder::default() + .source("parent/__init__.py", "") + .source( + "parent/foo.py", + " + _ZQZQZQ = 1 + __all__ = ['_ZQZQZQ'] + ", + ) + .source( + "test.py", + "import parent.foo + from parent.foo import * + _ZYZYZY = 1 + __all__ = ['_ZYZYZY'] + __all__ += parent.foo.__all__ + ", + ) + .build(); + insta::assert_snapshot!( + test.exports_for("test.py"), + @r" + _ZQZQZQ :: Constant + _ZYZYZY :: Constant + ", + ); + } + + #[test] + fn reexport_and_extend_from_submodule_import_from_statement_extend() { + let test = PublicTestBuilder::default() + .source("parent/__init__.py", "") + .source( + "parent/foo.py", + " + _ZQZQZQ = 1 + __all__ = ['_ZQZQZQ'] + ", + ) + .source( + "test.py", + "import parent.foo + from parent.foo import * + _ZYZYZY = 1 + __all__ = ['_ZYZYZY'] + __all__.extend(parent.foo.__all__) + ", + ) + .build(); + insta::assert_snapshot!( + test.exports_for("test.py"), + @r" + _ZQZQZQ :: Constant + _ZYZYZY :: Constant + ", + ); + } + + #[test] + fn reexport_and_extend_from_submodule_import_from_statement_alias() { + let test = PublicTestBuilder::default() + .source("parent/__init__.py", "") + .source( + "parent/foo.py", + " + _ZQZQZQ = 1 + __all__ = ['_ZQZQZQ'] + ", + ) + .source( + "test.py", + "from parent import foo as blah + from parent.foo import * + _ZYZYZY = 1 + __all__ = ['_ZYZYZY'] + __all__ += blah.__all__ + ", + ) + .build(); + insta::assert_snapshot!( + test.exports_for("test.py"), + @r" + _ZQZQZQ :: Constant + _ZYZYZY :: Constant + ", + ); + } + + #[test] + fn reexport_and_extend_from_submodule_cycle1() { + let test = PublicTestBuilder::default() + .source( + "a.py", + "from b import * + import b + _ZAZAZA = 1 + __all__ = ['_ZAZAZA'] + __all__ += b.__all__ + ", + ) + .source( + "b.py", + " + from a import * + import a + _ZBZBZB = 1 + __all__ = ['_ZBZBZB'] + __all__ += a.__all__ + ", + ) + .build(); + insta::assert_snapshot!( + test.exports_for("a.py"), + @r" + _ZBZBZB :: Constant + _ZAZAZA :: Constant + ", + ); + } + + #[test] + fn reexport_and_extend_from_submodule_import_statement_failure1() { + let test = PublicTestBuilder::default() + .source( + "foo.py", + " + _ZFZFZF = 1 + __all__ = ['_ZFZFZF'] + ", + ) + .source( + "bar.py", + " + _ZBZBZB = 1 + __all__ = ['_ZBZBZB'] + ", + ) + .source( + "test.py", + "import foo + import bar + from foo import * + from bar import * + + foo = bar + _ZYZYZY = 1 + __all__ = ['_ZYZYZY'] + __all__.extend(foo.__all__) + ", + ) + .build(); + // In this test, we resolve `foo.__all__` to the `__all__` + // attribute in module `foo` instead of in `bar`. This is + // because we don't track redefinitions of imports (as of + // 2025-12-11). Handling this correctly would mean exporting + // `_ZBZBZB` instead of `_ZFZFZF`. + insta::assert_snapshot!( + test.exports_for("test.py"), + @r" + _ZFZFZF :: Constant + _ZYZYZY :: Constant + ", + ); + } + + #[test] + fn reexport_and_extend_from_submodule_import_statement_failure2() { + let test = PublicTestBuilder::default() + .source( + "parent/__init__.py", + "import parent.foo as foo + __all__ = ['foo'] + ", + ) + .source( + "parent/foo.py", + " + _ZFZFZF = 1 + __all__ = ['_ZFZFZF'] + ", + ) + .source( + "test.py", + "from parent.foo import * + from parent import * + + _ZYZYZY = 1 + __all__ = ['_ZYZYZY'] + __all__.extend(foo.__all__) + ", + ) + .build(); + // This is not quite right either because we end up + // considering the `__all__` in `test.py` to be invalid. + // Namely, we don't pick up the `foo` that is in scope + // from the `from parent import *` import. The correct + // answer should just be `_ZFZFZF` and `_ZYZYZY`. + insta::assert_snapshot!( + test.exports_for("test.py"), + @r" + _ZFZFZF :: Constant + foo :: Module + _ZYZYZY :: Constant + __all__ :: Variable + ", + ); + } + fn matches(query: &str, symbol: &str) -> bool { super::QueryPattern::fuzzy(query).is_match_symbol_name(symbol) }