[ty] Support `__all__ += submodule.__all__`

... and also `__all__.extend(submodule.__all__)`.

I originally left out support for this since I was unclear on whether
we'd really need it. But it turns out this is used somewhat frequently.
For example, in `numpy`.

See the comments on the new `Imports` type for how we approach this.
This commit is contained in:
Andrew Gallant 2025-12-10 15:51:10 -05:00 committed by Andrew Gallant
parent a2b138e789
commit 3ac58b47bd
1 changed files with 608 additions and 7 deletions

View File

@ -10,10 +10,10 @@ use ruff_db::files::File;
use ruff_db::parsed::parsed_module;
use ruff_index::{IndexVec, newtype_index};
use ruff_python_ast as ast;
use ruff_python_ast::name::Name;
use ruff_python_ast::name::{Name, UnqualifiedName};
use ruff_python_ast::visitor::source_order::{self, SourceOrderVisitor};
use ruff_text_size::{Ranged, TextRange};
use rustc_hash::FxHashSet;
use rustc_hash::{FxHashMap, FxHashSet};
use ty_project::Db;
use ty_python_semantic::{ModuleName, resolve_module};
@ -375,7 +375,11 @@ pub(crate) fn symbols_for_file(db: &dyn Db, file: File) -> FlatSymbols {
/// While callers can convert this into a hierarchical collection of
/// symbols, it won't result in anything meaningful since the flat list
/// returned doesn't include children.
#[salsa::tracked(returns(ref), heap_size=ruff_memory_usage::heap_size)]
#[salsa::tracked(
returns(ref),
cycle_initial=symbols_for_file_global_only_cycle_initial,
heap_size=ruff_memory_usage::heap_size,
)]
pub(crate) fn symbols_for_file_global_only(db: &dyn Db, file: File) -> FlatSymbols {
let parsed = parsed_module(db, file);
let module = parsed.load(db);
@ -394,6 +398,14 @@ pub(crate) fn symbols_for_file_global_only(db: &dyn Db, file: File) -> FlatSymbo
visitor.into_flat_symbols()
}
fn symbols_for_file_global_only_cycle_initial(
_db: &dyn Db,
_id: salsa::Id,
_file: File,
) -> FlatSymbols {
FlatSymbols::default()
}
#[derive(Debug, Clone, PartialEq, Eq, get_size2::GetSize)]
struct SymbolTree {
parent: Option<SymbolId>,
@ -411,6 +423,189 @@ enum ImportKind {
Wildcard,
}
/// An abstraction for managing module scope imports.
///
/// This is meant to recognize the following idioms for updating
/// `__all__` in module scope:
///
/// ```ignore
/// __all__ += submodule.__all__
/// __all__.extend(submodule.__all__)
/// ```
///
/// # Correctness
///
/// The approach used here is not correct 100% of the time.
/// For example, it is somewhat easy to defeat it:
///
/// ```ignore
/// from numpy import *
/// from importlib import resources
/// import numpy as np
/// np = resources
/// __all__ = []
/// __all__ += np.__all__
/// ```
///
/// In this example, `np` will still be resolved to the `numpy`
/// module instead of the `importlib.resources` module. Namely, this
/// abstraction doesn't track all definitions. This would result in a
/// silently incorrect `__all__`.
///
/// This abstraction does handle the case when submodules are imported.
/// Namely, we do get this case correct:
///
/// ```ignore
/// from importlib.resources import *
/// from importlib import resources
/// __all__ = []
/// __all__ += resources.__all__
/// ```
///
/// We do this by treating all imports in a `from ... import ...`
/// statement as *possible* modules. Then when we lookup `resources`,
/// we attempt to resolve it to an actual module. If that fails, then
/// we consider `__all__` invalid.
///
/// There are likely many many other cases that we don't handle as
/// well, which ty does (it has its own `__all__` parsing using types
/// to deal with this case). We can add handling for those as they
/// come up in real world examples.
///
/// # Performance
///
/// This abstraction recognizes that, compared to all possible imports,
/// it is very rare to use one of them to update `__all__`. Therefore,
/// we are careful not to do too much work up-front (like eagerly
/// manifesting `ModuleName` values).
#[derive(Clone, Debug, Default, get_size2::GetSize)]
struct Imports<'db> {
/// A map from the name that a module is available
/// under to its actual module name (and our level
/// of certainty that it ought to be treated as a module).
module_names: FxHashMap<&'db str, ImportModuleKind<'db>>,
}
impl<'db> Imports<'db> {
/// Track the imports from the given `import ...` statement.
fn add_import(&mut self, import: &'db ast::StmtImport) {
for alias in &import.names {
let asname = alias
.asname
.as_ref()
.map(|ident| &ident.id)
.unwrap_or(&alias.name.id);
let module_name = ImportModuleName::Import(&alias.name.id);
self.module_names
.insert(asname, ImportModuleKind::Definitive(module_name));
}
}
/// Track the imports from the given `from ... import ...` statement.
fn add_import_from(&mut self, import_from: &'db ast::StmtImportFrom) {
for alias in &import_from.names {
if &alias.name == "*" {
// FIXME: We'd ideally include the names
// imported from the module, but we don't
// want to do this eagerly. So supporting
// this requires more infrastructure in
// `Imports`.
continue;
}
let asname = alias
.asname
.as_ref()
.map(|ident| &ident.id)
.unwrap_or(&alias.name.id);
let module_name = ImportModuleName::ImportFrom {
parent: import_from,
child: &alias.name.id,
};
self.module_names
.insert(asname, ImportModuleKind::Possible(module_name));
}
}
/// Return the symbols exported by the module referred to by `name`.
///
/// e.g., This can be used to resolve `__all__ += submodule.__all__`,
/// where `name` is `submodule`.
fn get_module_symbols(
&self,
db: &'db dyn Db,
importing_file: File,
name: &Name,
) -> Option<&'db FlatSymbols> {
let module_name = match self.module_names.get(name.as_str())? {
ImportModuleKind::Definitive(name) | ImportModuleKind::Possible(name) => {
name.to_module_name(db, importing_file)?
}
};
let module = resolve_module(db, importing_file, &module_name)?;
Some(symbols_for_file_global_only(db, module.file(db)?))
}
}
/// Describes the level of certainty that an import is a module.
///
/// For example, `import foo`, then `foo` is definitively a module.
/// But `from quux import foo`, then `quux.foo` is possibly a module.
#[derive(Debug, Clone, Copy, get_size2::GetSize)]
enum ImportModuleKind<'db> {
Definitive(ImportModuleName<'db>),
Possible(ImportModuleName<'db>),
}
/// A representation of something that can be turned into a
/// `ModuleName`.
///
/// We don't do this eagerly, and instead represent the constituent
/// pieces, in order to avoid the work needed to build a `ModuleName`.
/// In particular, it is somewhat rare for the visitor to need
/// to access the imports found in a module. At time of writing
/// (2025-12-10), this only happens when referencing a submodule
/// to augment an `__all__` definition. For example, as found in
/// `matplotlib`:
///
/// ```ignore
/// import numpy as np
/// __all__ = ['rand', 'randn', 'repmat']
/// __all__ += np.__all__
/// ```
///
/// This construct is somewhat rare and it would be sad to allocate a
/// `ModuleName` for every imported item unnecessarily.
#[derive(Debug, Clone, Copy, get_size2::GetSize)]
enum ImportModuleName<'db> {
/// The `foo` in `import quux, foo as blah, baz`.
Import(&'db Name),
/// A possible module in a `from ... import ...` statement.
ImportFrom {
/// The `..foo` in `from ..foo import quux`.
parent: &'db ast::StmtImportFrom,
/// The `foo` in `from quux import foo`.
child: &'db Name,
},
}
impl<'db> ImportModuleName<'db> {
/// Converts the lazy representation of a module name into an
/// actual `ModuleName` that can be used for module resolution.
fn to_module_name(self, db: &'db dyn Db, importing_file: File) -> Option<ModuleName> {
match self {
ImportModuleName::Import(name) => ModuleName::new(name),
ImportModuleName::ImportFrom { parent, child } => {
let mut module_name =
ModuleName::from_import_statement(db, importing_file, parent).ok()?;
let child_module_name = ModuleName::new(child)?;
module_name.extend(&child_module_name);
Some(module_name)
}
}
}
}
/// A visitor over all symbols in a single file.
///
/// This guarantees that child symbols have a symbol ID greater
@ -444,6 +639,11 @@ struct SymbolVisitor<'db> {
/// `__all__` idioms or there are any invalid elements in
/// `__all__`.
all_invalid: bool,
/// A collection of imports found while visiting the AST.
///
/// These are used to help resolve references to modules
/// in some limited cases.
imports: Imports<'db>,
}
impl<'db> SymbolVisitor<'db> {
@ -459,6 +659,7 @@ impl<'db> SymbolVisitor<'db> {
all_origin: None,
all_names: FxHashSet::default(),
all_invalid: false,
imports: Imports::default(),
}
}
@ -483,12 +684,28 @@ impl<'db> SymbolVisitor<'db> {
// their position in a sequence. So when we filter some
// out, we need to remap the identifiers.
//
// N.B. The remapping could be skipped when `global_only` is
// We also want to deduplicate when `exports_only` is
// `true`. In particular, dealing with `__all__` can
// result in cycles, and we need to make sure our output
// is stable for that reason.
//
// N.B. The remapping could be skipped when `exports_only` is
// true, since in that case, none of the symbols have a parent
// ID by construction.
let mut remap = IndexVec::with_capacity(self.symbols.len());
let mut seen = self.exports_only.then(FxHashSet::default);
let mut new = IndexVec::with_capacity(self.symbols.len());
for mut symbol in std::mem::take(&mut self.symbols) {
// If we're deduplicating and we've already seen
// this symbol, then skip it.
//
// FIXME: We should do this without copying every
// symbol name. ---AG
if let Some(ref mut seen) = seen {
if !seen.insert(symbol.name.clone()) {
continue;
}
}
if !self.is_part_of_library_interface(&symbol) {
remap.push(None);
continue;
@ -519,7 +736,7 @@ impl<'db> SymbolVisitor<'db> {
}
}
fn visit_body(&mut self, body: &[ast::Stmt]) {
fn visit_body(&mut self, body: &'db [ast::Stmt]) {
for stmt in body {
self.visit_stmt(stmt);
}
@ -649,6 +866,31 @@ impl<'db> SymbolVisitor<'db> {
ast::Expr::List(ast::ExprList { elts, .. })
| ast::Expr::Tuple(ast::ExprTuple { elts, .. })
| ast::Expr::Set(ast::ExprSet { elts, .. }) => self.add_all_names(elts),
// `__all__ += module.__all__`
// `__all__.extend(module.__all__)`
ast::Expr::Attribute(ast::ExprAttribute { .. }) => {
let Some(unqualified) = UnqualifiedName::from_expr(expr) else {
return false;
};
let Some((&attr, rest)) = unqualified.segments().split_last() else {
return false;
};
if attr != "__all__" {
return false;
}
let possible_module_name = Name::new(rest.join("."));
let Some(symbols) =
self.imports
.get_module_symbols(self.db, self.file, &possible_module_name)
else {
return false;
};
let Some(ref all) = symbols.all_names else {
return false;
};
self.all_names.extend(all.iter().cloned());
true
}
_ => false,
}
}
@ -850,8 +1092,8 @@ impl<'db> SymbolVisitor<'db> {
}
}
impl SourceOrderVisitor<'_> for SymbolVisitor<'_> {
fn visit_stmt(&mut self, stmt: &ast::Stmt) {
impl<'db> SourceOrderVisitor<'db> for SymbolVisitor<'db> {
fn visit_stmt(&mut self, stmt: &'db ast::Stmt) {
match stmt {
ast::Stmt::FunctionDef(func_def) => {
let kind = if self
@ -1023,6 +1265,7 @@ impl SourceOrderVisitor<'_> for SymbolVisitor<'_> {
if self.in_function {
return;
}
self.imports.add_import(import);
for alias in &import.names {
self.add_import_alias(stmt, alias);
}
@ -1038,6 +1281,7 @@ impl SourceOrderVisitor<'_> for SymbolVisitor<'_> {
if self.in_function {
return;
}
self.imports.add_import_from(import_from);
for alias in &import_from.names {
if &alias.name == "*" {
self.add_exported_from_wildcard(import_from);
@ -2010,6 +2254,363 @@ class X:
);
}
#[test]
fn reexport_and_extend_from_submodule_import_statement_plus_equals() {
let test = PublicTestBuilder::default()
.source(
"foo.py",
"
_ZQZQZQ = 1
__all__ = ['_ZQZQZQ']
",
)
.source(
"test.py",
"import foo
from foo import *
_ZYZYZY = 1
__all__ = ['_ZYZYZY']
__all__ += foo.__all__
",
)
.build();
insta::assert_snapshot!(
test.exports_for("test.py"),
@r"
_ZQZQZQ :: Constant
_ZYZYZY :: Constant
",
);
}
#[test]
fn reexport_and_extend_from_submodule_import_statement_extend() {
let test = PublicTestBuilder::default()
.source(
"foo.py",
"
_ZQZQZQ = 1
__all__ = ['_ZQZQZQ']
",
)
.source(
"test.py",
"import foo
from foo import *
_ZYZYZY = 1
__all__ = ['_ZYZYZY']
__all__.extend(foo.__all__)
",
)
.build();
insta::assert_snapshot!(
test.exports_for("test.py"),
@r"
_ZQZQZQ :: Constant
_ZYZYZY :: Constant
",
);
}
#[test]
fn reexport_and_extend_from_submodule_import_statement_alias() {
let test = PublicTestBuilder::default()
.source(
"foo.py",
"
_ZQZQZQ = 1
__all__ = ['_ZQZQZQ']
",
)
.source(
"test.py",
"import foo as blah
from foo import *
_ZYZYZY = 1
__all__ = ['_ZYZYZY']
__all__ += blah.__all__
",
)
.build();
insta::assert_snapshot!(
test.exports_for("test.py"),
@r"
_ZQZQZQ :: Constant
_ZYZYZY :: Constant
",
);
}
#[test]
fn reexport_and_extend_from_submodule_import_statement_nested_alias() {
let test = PublicTestBuilder::default()
.source("parent/__init__.py", "")
.source(
"parent/foo.py",
"
_ZQZQZQ = 1
__all__ = ['_ZQZQZQ']
",
)
.source(
"test.py",
"import parent.foo as blah
from parent.foo import *
_ZYZYZY = 1
__all__ = ['_ZYZYZY']
__all__ += blah.__all__
",
)
.build();
insta::assert_snapshot!(
test.exports_for("test.py"),
@r"
_ZQZQZQ :: Constant
_ZYZYZY :: Constant
",
);
}
#[test]
fn reexport_and_extend_from_submodule_import_from_statement_plus_equals() {
let test = PublicTestBuilder::default()
.source("parent/__init__.py", "")
.source(
"parent/foo.py",
"
_ZQZQZQ = 1
__all__ = ['_ZQZQZQ']
",
)
.source(
"test.py",
"from parent import foo
from parent.foo import *
_ZYZYZY = 1
__all__ = ['_ZYZYZY']
__all__ += foo.__all__
",
)
.build();
insta::assert_snapshot!(
test.exports_for("test.py"),
@r"
_ZQZQZQ :: Constant
_ZYZYZY :: Constant
",
);
}
#[test]
fn reexport_and_extend_from_submodule_import_from_statement_nested_module_reference() {
let test = PublicTestBuilder::default()
.source("parent/__init__.py", "")
.source(
"parent/foo.py",
"
_ZQZQZQ = 1
__all__ = ['_ZQZQZQ']
",
)
.source(
"test.py",
"import parent.foo
from parent.foo import *
_ZYZYZY = 1
__all__ = ['_ZYZYZY']
__all__ += parent.foo.__all__
",
)
.build();
insta::assert_snapshot!(
test.exports_for("test.py"),
@r"
_ZQZQZQ :: Constant
_ZYZYZY :: Constant
",
);
}
#[test]
fn reexport_and_extend_from_submodule_import_from_statement_extend() {
let test = PublicTestBuilder::default()
.source("parent/__init__.py", "")
.source(
"parent/foo.py",
"
_ZQZQZQ = 1
__all__ = ['_ZQZQZQ']
",
)
.source(
"test.py",
"import parent.foo
from parent.foo import *
_ZYZYZY = 1
__all__ = ['_ZYZYZY']
__all__.extend(parent.foo.__all__)
",
)
.build();
insta::assert_snapshot!(
test.exports_for("test.py"),
@r"
_ZQZQZQ :: Constant
_ZYZYZY :: Constant
",
);
}
#[test]
fn reexport_and_extend_from_submodule_import_from_statement_alias() {
let test = PublicTestBuilder::default()
.source("parent/__init__.py", "")
.source(
"parent/foo.py",
"
_ZQZQZQ = 1
__all__ = ['_ZQZQZQ']
",
)
.source(
"test.py",
"from parent import foo as blah
from parent.foo import *
_ZYZYZY = 1
__all__ = ['_ZYZYZY']
__all__ += blah.__all__
",
)
.build();
insta::assert_snapshot!(
test.exports_for("test.py"),
@r"
_ZQZQZQ :: Constant
_ZYZYZY :: Constant
",
);
}
#[test]
fn reexport_and_extend_from_submodule_cycle1() {
let test = PublicTestBuilder::default()
.source(
"a.py",
"from b import *
import b
_ZAZAZA = 1
__all__ = ['_ZAZAZA']
__all__ += b.__all__
",
)
.source(
"b.py",
"
from a import *
import a
_ZBZBZB = 1
__all__ = ['_ZBZBZB']
__all__ += a.__all__
",
)
.build();
insta::assert_snapshot!(
test.exports_for("a.py"),
@r"
_ZBZBZB :: Constant
_ZAZAZA :: Constant
",
);
}
#[test]
fn reexport_and_extend_from_submodule_import_statement_failure1() {
let test = PublicTestBuilder::default()
.source(
"foo.py",
"
_ZFZFZF = 1
__all__ = ['_ZFZFZF']
",
)
.source(
"bar.py",
"
_ZBZBZB = 1
__all__ = ['_ZBZBZB']
",
)
.source(
"test.py",
"import foo
import bar
from foo import *
from bar import *
foo = bar
_ZYZYZY = 1
__all__ = ['_ZYZYZY']
__all__.extend(foo.__all__)
",
)
.build();
// In this test, we resolve `foo.__all__` to the `__all__`
// attribute in module `foo` instead of in `bar`. This is
// because we don't track redefinitions of imports (as of
// 2025-12-11). Handling this correctly would mean exporting
// `_ZBZBZB` instead of `_ZFZFZF`.
insta::assert_snapshot!(
test.exports_for("test.py"),
@r"
_ZFZFZF :: Constant
_ZYZYZY :: Constant
",
);
}
#[test]
fn reexport_and_extend_from_submodule_import_statement_failure2() {
let test = PublicTestBuilder::default()
.source(
"parent/__init__.py",
"import parent.foo as foo
__all__ = ['foo']
",
)
.source(
"parent/foo.py",
"
_ZFZFZF = 1
__all__ = ['_ZFZFZF']
",
)
.source(
"test.py",
"from parent.foo import *
from parent import *
_ZYZYZY = 1
__all__ = ['_ZYZYZY']
__all__.extend(foo.__all__)
",
)
.build();
// This is not quite right either because we end up
// considering the `__all__` in `test.py` to be invalid.
// Namely, we don't pick up the `foo` that is in scope
// from the `from parent import *` import. The correct
// answer should just be `_ZFZFZF` and `_ZYZYZY`.
insta::assert_snapshot!(
test.exports_for("test.py"),
@r"
_ZFZFZF :: Constant
foo :: Module
_ZYZYZY :: Constant
__all__ :: Variable
",
);
}
fn matches(query: &str, symbol: &str) -> bool {
super::QueryPattern::fuzzy(query).is_match_symbol_name(symbol)
}