From 8c72b296c9895b9e40dacf69c584a62b84f09803 Mon Sep 17 00:00:00 2001 From: Andrew Gallant Date: Tue, 25 Nov 2025 14:03:19 -0500 Subject: [PATCH] [ty] Add support for re-exports and `__all__` to auto-import This commit (mostly) re-implements the support for `__all__` in ty-proper, but inside the auto-import AST scanner. When `__all__` isn't present in a module, we fall back to conventions to determine whether a symbol is exported or not: https://docs.python.org/3/library/index.html However, in keeping with current practice for non-auto-import completions, we continue to provide sunder and dunder names as re-exports. When `__all__` is present, we respect it strictly. That is, a symbol is exported *if and only if* it's in `__all__`. This is somewhat stricter than pylance seemingly is. I felt like it was a good idea to start here, and we can relax it based on user demand (perhaps through a setting). --- crates/ty_ide/src/symbols.rs | 1308 ++++++++++++++++- .../snapshots/e2e__notebook__auto_import.snap | 4 +- .../e2e__notebook__auto_import_docstring.snap | 4 +- ...2e__notebook__auto_import_from_future.snap | 4 +- .../e2e__notebook__auto_import_same_cell.snap | 4 +- 5 files changed, 1293 insertions(+), 31 deletions(-) diff --git a/crates/ty_ide/src/symbols.rs b/crates/ty_ide/src/symbols.rs index 295ccad5c3..a80a9ed56d 100644 --- a/crates/ty_ide/src/symbols.rs +++ b/crates/ty_ide/src/symbols.rs @@ -13,7 +13,9 @@ use ruff_python_ast as ast; use ruff_python_ast::name::Name; use ruff_python_ast::visitor::source_order::{self, SourceOrderVisitor}; use ruff_text_size::{Ranged, TextRange}; +use rustc_hash::FxHashSet; use ty_project::Db; +use ty_python_semantic::{ModuleName, resolve_module}; use crate::completion::CompletionKind; @@ -111,7 +113,13 @@ impl PartialEq for QueryPattern { /// A flat list of indexed symbols for a single file. #[derive(Clone, Debug, Default, PartialEq, Eq, get_size2::GetSize)] pub struct FlatSymbols { + /// The symbols exported by a module. symbols: IndexVec, + /// The names found in an `__all__` for a module. + /// + /// This is `None` if the module has no `__all__` at module + /// scope. + all_names: Option>, } impl FlatSymbols { @@ -351,16 +359,9 @@ pub(crate) fn symbols_for_file(db: &dyn Db, file: File) -> FlatSymbols { let parsed = parsed_module(db, file); let module = parsed.load(db); - let mut visitor = SymbolVisitor { - symbols: IndexVec::new(), - symbol_stack: vec![], - in_function: false, - global_only: false, - }; + let mut visitor = SymbolVisitor::tree(db, file); visitor.visit_body(&module.syntax().body); - FlatSymbols { - symbols: visitor.symbols, - } + visitor.into_flat_symbols() } /// Returns a flat list of *only global* symbols in the file given. @@ -373,12 +374,7 @@ pub(crate) fn symbols_for_file_global_only(db: &dyn Db, file: File) -> FlatSymbo let parsed = parsed_module(db, file); let module = parsed.load(db); - let mut visitor = SymbolVisitor { - symbols: IndexVec::new(), - symbol_stack: vec![], - in_function: false, - global_only: true, - }; + let mut visitor = SymbolVisitor::globals(db, file); visitor.visit_body(&module.syntax().body); if file @@ -389,10 +385,7 @@ pub(crate) fn symbols_for_file_global_only(db: &dyn Db, file: File) -> FlatSymbo // Eagerly clear ASTs of third party files. parsed.clear(); } - - FlatSymbols { - symbols: visitor.symbols, - } + visitor.into_flat_symbols() } #[derive(Debug, Clone, PartialEq, Eq, get_size2::GetSize)] @@ -402,27 +395,122 @@ struct SymbolTree { kind: SymbolKind, name_range: TextRange, full_range: TextRange, + import_kind: Option, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, get_size2::GetSize)] +enum ImportKind { + Normal, + RedundantAlias, + Wildcard, } /// A visitor over all symbols in a single file. /// /// This guarantees that child symbols have a symbol ID greater /// than all of its parents. -struct SymbolVisitor { +#[allow(clippy::struct_excessive_bools)] +struct SymbolVisitor<'db> { + db: &'db dyn Db, + file: File, symbols: IndexVec, symbol_stack: Vec, - /// Track if we're currently inside a function (to exclude local variables) + /// Track if we're currently inside a function at any point. + /// + /// This is true even when we're inside a class definition + /// that is inside a class. in_function: bool, + /// Track if we're currently inside a class at any point. + /// + /// This is true even when we're inside a function definition + /// that is inside a class. + in_class: bool, global_only: bool, + /// The origin of an `__all__` variable, if found. + all_origin: Option, + /// A set of names extracted from `__all__`. + all_names: FxHashSet, + /// A flag indicating whether the module uses unrecognized + /// `__all__` idioms or there are any invalid elements in + /// `__all__`. + all_invalid: bool, } -impl SymbolVisitor { +impl<'db> SymbolVisitor<'db> { + fn tree(db: &'db dyn Db, file: File) -> Self { + Self { + db, + file, + symbols: IndexVec::new(), + symbol_stack: vec![], + in_function: false, + in_class: false, + global_only: false, + all_origin: None, + all_names: FxHashSet::default(), + all_invalid: false, + } + } + + fn globals(db: &'db dyn Db, file: File) -> Self { + Self { + global_only: true, + ..Self::tree(db, file) + } + } + + fn into_flat_symbols(mut self) -> FlatSymbols { + // We want to filter out some of the symbols we collected. + // Specifically, to respect conventions around library + // interface. + // + // But, we always assigned IDs to each symbol based on + // their position in a sequence. So when we filter some + // out, we need to remap the identifiers. + // + // N.B. The remapping could be skipped when `global_only` is + // true, since in that case, none of the symbols have a parent + // ID by construction. + let mut remap = IndexVec::with_capacity(self.symbols.len()); + let mut new = IndexVec::with_capacity(self.symbols.len()); + for mut symbol in std::mem::take(&mut self.symbols) { + if !self.is_part_of_library_interface(&symbol) { + remap.push(None); + continue; + } + + if let Some(ref mut parent) = symbol.parent { + // OK because the visitor guarantees that + // all parents have IDs less than their + // children. So its ID has already been + // remapped. + if let Some(new_parent) = remap[*parent] { + *parent = new_parent; + } else { + // The parent symbol was dropped, so + // all of its children should be as + // well. + remap.push(None); + continue; + } + } + let new_id = new.next_index(); + remap.push(Some(new_id)); + new.push(symbol); + } + FlatSymbols { + symbols: new, + all_names: self.all_origin.map(|_| self.all_names), + } + } + fn visit_body(&mut self, body: &[ast::Stmt]) { for stmt in body { self.visit_stmt(stmt); } } + /// Add a new symbol and return its ID. fn add_symbol(&mut self, mut symbol: SymbolTree) -> SymbolId { if let Some(&parent_id) = self.symbol_stack.last() { symbol.parent = Some(parent_id); @@ -436,6 +524,7 @@ impl SymbolVisitor { symbol_id } + /// Adds a symbol introduced via an assignment. fn add_assignment(&mut self, stmt: &ast::Stmt, name: &ast::ExprName) -> SymbolId { let kind = if Self::is_constant_name(name.id.as_str()) { SymbolKind::Constant @@ -454,10 +543,222 @@ impl SymbolVisitor { kind, name_range: name.range(), full_range: stmt.range(), + import_kind: None, }; self.add_symbol(symbol) } + /// Adds a symbol introduced via an import `stmt`. + fn add_import_alias(&mut self, stmt: &ast::Stmt, alias: &ast::Alias) -> SymbolId { + let name = alias.asname.as_ref().unwrap_or(&alias.name); + let kind = if stmt.is_import_stmt() { + SymbolKind::Module + } else if Self::is_constant_name(name.as_str()) { + SymbolKind::Constant + } else { + SymbolKind::Variable + }; + let re_export = Some( + if alias.asname.as_ref().map(ast::Identifier::as_str) == Some(alias.name.as_str()) { + ImportKind::RedundantAlias + } else { + ImportKind::Normal + }, + ); + self.add_symbol(SymbolTree { + parent: None, + name: name.id.to_string(), + kind, + name_range: name.range(), + full_range: stmt.range(), + import_kind: re_export, + }) + } + + /// Extracts `__all__` names from the given assignment. + /// + /// If the assignment isn't for `__all__`, then this is a no-op. + fn add_all_assignment(&mut self, targets: &[ast::Expr], value: Option<&ast::Expr>) { + if self.in_function || self.in_class { + return; + } + let Some(target) = targets.first() else { + return; + }; + if !is_dunder_all(target) { + return; + } + + let Some(value) = value else { return }; + match *value { + // `__all__ = [...]` + // `__all__ = (...)` + ast::Expr::List(ast::ExprList { ref elts, .. }) + | ast::Expr::Tuple(ast::ExprTuple { ref elts, .. }) => { + self.update_all_origin(DunderAllOrigin::CurrentModule); + if !self.add_all_names(elts) { + self.all_invalid = true; + } + } + _ => { + self.all_invalid = true; + } + } + } + + /// Extends the current set of names with the names from the + /// given expression which currently must be a list/tuple/set of + /// string-literal names. This currently does not support using a + /// submodule's `__all__` variable. + /// + /// Returns `true` if the expression is a valid list/tuple/set or + /// module `__all__`, `false` otherwise. + /// + /// N.B. Supporting all instances of `__all__ += submodule.__all__` + /// and `__all__.extend(submodule.__all__)` is likely difficult + /// in this context. Namely, `submodule` needs to be resolved + /// to a particular module. ty proper can do this (by virtue + /// of inferring the type of `submodule`). With that said, we + /// could likely support a subset of cases here without too much + /// ceremony. ---AG + fn extend_all(&mut self, expr: &ast::Expr) -> bool { + match expr { + // `__all__ += [...]` + // `__all__ += (...)` + // `__all__ += {...}` + ast::Expr::List(ast::ExprList { elts, .. }) + | ast::Expr::Tuple(ast::ExprTuple { elts, .. }) + | ast::Expr::Set(ast::ExprSet { elts, .. }) => self.add_all_names(elts), + _ => false, + } + } + + /// Processes a call idiom for `__all__` and updates the set of + /// names accordingly. + /// + /// Returns `true` if the call idiom is recognized and valid, + /// `false` otherwise. + fn update_all_by_call_idiom( + &mut self, + function_name: &ast::Identifier, + arguments: &ast::Arguments, + ) -> bool { + if arguments.len() != 1 { + return false; + } + let Some(argument) = arguments.find_positional(0) else { + return false; + }; + match function_name.as_str() { + // `__all__.extend([...])` + // `__all__.extend(module.__all__)` + "extend" => { + if !self.extend_all(argument) { + return false; + } + } + // `__all__.append(...)` + "append" => { + let Some(name) = create_all_name(argument) else { + return false; + }; + self.all_names.insert(name); + } + // `__all__.remove(...)` + "remove" => { + let Some(name) = create_all_name(argument) else { + return false; + }; + self.all_names.remove(&name); + } + _ => return false, + } + true + } + + /// Adds all of the names exported from the module + /// imported by `import_from`. i.e., This implements + /// `from module import *` semantics. + fn add_exported_from_wildcard(&mut self, import_from: &ast::StmtImportFrom) { + let Some(symbols) = self.get_names_from_wildcard(import_from) else { + self.all_invalid = true; + return; + }; + self.symbols + .extend(symbols.symbols.iter().filter_map(|symbol| { + // If there's no `__all__`, then names with an underscore + // are never pulled in via a wildcard import. Otherwise, + // we defer to `__all__` filtering. + if symbols.all_names.is_none() && symbol.name.starts_with('_') { + return None; + } + let mut symbol = symbol.clone(); + symbol.import_kind = Some(ImportKind::Wildcard); + Some(symbol) + })); + // If the imported module defines an `__all__` AND `__all__` is + // in `__all__`, then the importer gets it too. + if let Some(ref all) = symbols.all_names + && all.contains("__all__") + { + self.update_all_origin(DunderAllOrigin::StarImport); + self.all_names.extend(all.iter().cloned()); + } + } + + /// Adds `__all__` from the module imported by `import_from`. i.e., + /// This implements `from module import __all__` semantics. + fn add_all_from_import(&mut self, import_from: &ast::StmtImportFrom) { + let Some(symbols) = self.get_names_from_wildcard(import_from) else { + self.all_invalid = true; + return; + }; + // If the imported module defines an `__all__`, + // then the importer gets it too. + if let Some(ref all) = symbols.all_names { + self.update_all_origin(DunderAllOrigin::ExternalModule); + self.all_names.extend(all.iter().cloned()); + } + } + + /// Returns the exported symbols (along with `__all__`) from the + /// module imported in `import_from`. + fn get_names_from_wildcard( + &self, + import_from: &ast::StmtImportFrom, + ) -> Option<&'db FlatSymbols> { + let module_name = + ModuleName::from_import_statement(self.db, self.file, import_from).ok()?; + let module = resolve_module(self.db, self.file, &module_name)?; + Some(symbols_for_file_global_only(self.db, module.file(self.db)?)) + } + + /// Add valid names from `__all__` to the set of existing `__all__` + /// names. + /// + /// Returns `false` if any of the names are invalid. + fn add_all_names(&mut self, exprs: &[ast::Expr]) -> bool { + for expr in exprs { + let Some(name) = create_all_name(expr) else { + return false; + }; + self.all_names.insert(name); + } + true + } + + /// Updates the origin of `__all__` in the current module. + /// + /// This will clear existing names if the origin is changed to + /// mimic the behavior of overriding `__all__` in the current + /// module. + fn update_all_origin(&mut self, origin: DunderAllOrigin) { + if self.all_origin.is_some() { + self.all_names.clear(); + } + self.all_origin = Some(origin); + } + fn push_symbol(&mut self, symbol: SymbolTree) { let symbol_id = self.add_symbol(symbol); self.symbol_stack.push(symbol_id); @@ -477,9 +778,62 @@ impl SymbolVisitor { fn is_constant_name(name: &str) -> bool { name.chars().all(|c| c.is_ascii_uppercase() || c == '_') } + + /// This routine determines whether the given symbol should be + /// considered part of the public API of this module. The given + /// symbol should defined or imported into this module. + /// + /// See: + fn is_part_of_library_interface(&self, symbol: &SymbolTree) -> bool { + // If this is a child of something else, then we always + // defer its visibility to the parent. + if symbol.parent.is_some() { + return true; + } + + // When there's no `__all__`, we use conventions to determine + // if a name should be part of the exported API of a module + // or not. When there is `__all__`, we currently follow it + // strictly. + if self.all_origin.is_some() { + // If `__all__` is somehow invalid, ignore it and fall + // through as-if `__all__` didn't exist. + if self.all_invalid { + tracing::debug!("Invalid `__all__` in `{}`", self.file.path(self.db)); + } else { + return self.all_names.contains(&*symbol.name); + } + } + + // "Imported symbols are considered private by default. A fixed + // set of import forms re-export imported symbols." Specifically: + // + // * `import X as X` + // * `from Y import X as X` + // * `from Y import *` + if let Some(kind) = symbol.import_kind { + return match kind { + ImportKind::RedundantAlias | ImportKind::Wildcard => true, + ImportKind::Normal => false, + }; + } + // "Symbols whose names begin with an underscore (but are not + // dunder names) are considered private." + // + // ... however, we currently include these as part of the public + // API. The only extant (2025-12-03) consumer is completions, and + // completions will rank these names lower than others. + if symbol.name.starts_with('_') + && !(symbol.name.starts_with("__") && symbol.name.ends_with("__")) + { + return true; + } + // ... otherwise, it's exported! + true + } } -impl SourceOrderVisitor<'_> for SymbolVisitor { +impl SourceOrderVisitor<'_> for SymbolVisitor<'_> { fn visit_stmt(&mut self, stmt: &ast::Stmt) { match stmt { ast::Stmt::FunctionDef(func_def) => { @@ -502,6 +856,7 @@ impl SourceOrderVisitor<'_> for SymbolVisitor { kind, name_range: func_def.name.range(), full_range: stmt.range(), + import_kind: None, }; if self.global_only { @@ -530,6 +885,7 @@ impl SourceOrderVisitor<'_> for SymbolVisitor { kind: SymbolKind::Class, name_range: class_def.name.range(), full_range: stmt.range(), + import_kind: None, }; if self.global_only { @@ -538,11 +894,20 @@ impl SourceOrderVisitor<'_> for SymbolVisitor { return; } + // Mark that we're entering a class scope + let was_in_class = self.in_class; + self.in_class = true; + self.push_symbol(symbol); source_order::walk_stmt(self, stmt); self.pop_symbol(); + + // Restore the previous class scope state + self.in_class = was_in_class; } ast::Stmt::Assign(assign) => { + self.add_all_assignment(&assign.targets, Some(&assign.value)); + // Include assignments only when we're in global or class scope if self.in_function { return; @@ -555,6 +920,11 @@ impl SourceOrderVisitor<'_> for SymbolVisitor { } } ast::Stmt::AnnAssign(ann_assign) => { + self.add_all_assignment( + std::slice::from_ref(&ann_assign.target), + ann_assign.value.as_deref(), + ); + // Include assignments only when we're in global or class scope if self.in_function { return; @@ -564,6 +934,89 @@ impl SourceOrderVisitor<'_> for SymbolVisitor { }; self.add_assignment(stmt, name); } + ast::Stmt::AugAssign(ast::StmtAugAssign { + target, op, value, .. + }) => { + if self.all_origin.is_none() { + // We can't update `__all__` if it doesn't already + // exist. + return; + } + if !is_dunder_all(target) { + return; + } + // Anything other than `+=` is not valid. + if !matches!(op, ast::Operator::Add) { + self.all_invalid = true; + return; + } + if !self.extend_all(value) { + self.all_invalid = true; + } + } + ast::Stmt::Expr(expr) => { + if self.all_origin.is_none() { + // We can't update `__all__` if it doesn't already exist. + return; + } + let Some(ast::ExprCall { + func, arguments, .. + }) = expr.value.as_call_expr() + else { + return; + }; + let Some(ast::ExprAttribute { + value, + attr, + ctx: ast::ExprContext::Load, + .. + }) = func.as_attribute_expr() + else { + return; + }; + if !is_dunder_all(value) { + return; + } + if !self.update_all_by_call_idiom(attr, arguments) { + self.all_invalid = true; + } + + source_order::walk_stmt(self, stmt); + } + ast::Stmt::Import(import) => { + // We only consider imports in global scope. + if self.in_function { + return; + } + for alias in &import.names { + self.add_import_alias(stmt, alias); + } + } + ast::Stmt::ImportFrom(import_from) => { + // We only consider imports in global scope. + if self.in_function { + return; + } + for alias in &import_from.names { + if &alias.name == "*" { + self.add_exported_from_wildcard(import_from); + } else { + if &alias.name == "__all__" + && alias + .asname + .as_ref() + .is_none_or(|asname| asname == "__all__") + { + self.add_all_from_import(import_from); + } + self.add_import_alias(stmt, alias); + } + } + } + // FIXME: We don't currently try to evaluate `if` + // statements. We just assume that all `if` statements are + // always `True`. This applies to symbols in general but + // also `__all__`. _ => { source_order::walk_stmt(self, stmt); } @@ -575,6 +1028,29 @@ impl SourceOrderVisitor<'_> for SymbolVisitor { fn visit_expr(&mut self, _expr: &ast::Expr) {} } +/// Represents where an `__all__` has been defined. +#[derive(Debug, Clone)] +enum DunderAllOrigin { + /// The `__all__` variable is defined in the current module. + CurrentModule, + /// The `__all__` variable is imported from another module. + ExternalModule, + /// The `__all__` variable is imported from a module via a `*`-import. + StarImport, +} + +/// Checks if the given expression is a name expression for `__all__`. +fn is_dunder_all(expr: &ast::Expr) -> bool { + matches!(expr, ast::Expr::Name(ast::ExprName { id, .. }) if id == "__all__") +} + +/// Create and return a string representing a name from the given +/// expression, or `None` if it is an invalid expression for a +/// `__all__` element. +fn create_all_name(expr: &ast::Expr) -> Option { + Some(expr.as_string_literal_expr()?.value.to_str().into()) +} + #[cfg(test)] mod tests { use camino::Utf8Component; @@ -641,6 +1117,25 @@ def quux(): ); } + /// The typing spec says that names beginning with an underscore + /// ought to be considered unexported[1]. However, at present, we + /// currently include them in completions but rank them lower than + /// non-underscore names. So this tests that we return underscore + /// names. + /// + /// [1]: https://typing.python.org/en/latest/spec/distributing.html#library-interface-public-and-private-symbols + #[test] + fn exports_underscore() { + insta::assert_snapshot!( + public_test("\ +_foo = 1 +").exports(), + @r" + _foo :: Variable + ", + ); + } + #[test] fn exports_conditional_true() { insta::assert_snapshot!( @@ -707,6 +1202,773 @@ if TYPE_CHECKING: ); } + #[test] + fn exports_conditional_always_else() { + // FIXME: This shouldn't include `bar`. + insta::assert_snapshot!( + public_test("\ +foo = 1 +bar = 1 +if True: + __all__ = ['foo'] +else: + __all__ = ['foo', 'bar'] +").exports(), + @r" + foo :: Variable + bar :: Variable + ", + ); + } + + #[test] + fn exports_all_overwrites_previous() { + insta::assert_snapshot!( + public_test("\ +foo = 1 +bar = 1 +__all__ = ['foo'] +__all__ = ['foo', 'bar'] +").exports(), + @r" + foo :: Variable + bar :: Variable + ", + ); + } + + #[test] + fn exports_import_no_reexport() { + insta::assert_snapshot!( + public_test("\ +import collections +").exports(), + @r"", + ); + } + + #[test] + fn exports_import_as_no_reexport() { + insta::assert_snapshot!( + public_test("\ +import numpy as np +").exports(), + @r"", + ); + } + + #[test] + fn exports_from_import_no_reexport() { + insta::assert_snapshot!( + public_test("\ +from collections import defaultdict +").exports(), + @r"", + ); + } + + #[test] + fn exports_from_import_as_no_reexport() { + insta::assert_snapshot!( + public_test("\ +from collections import defaultdict as dd +").exports(), + @r"", + ); + } + + #[test] + fn exports_import_reexport() { + insta::assert_snapshot!( + public_test("\ +import numpy as numpy +").exports(), + @"numpy :: Module", + ); + } + + #[test] + fn exports_from_import_reexport() { + insta::assert_snapshot!( + public_test("\ +from collections import defaultdict as defaultdict +").exports(), + @"defaultdict :: Variable", + ); + } + + #[test] + fn exports_from_import_all_reexport_assignment() { + insta::assert_snapshot!( + public_test("\ +from collections import defaultdict +__all__ = ['defaultdict'] +").exports(), + @"defaultdict :: Variable", + ); + + insta::assert_snapshot!( + public_test("\ +from collections import defaultdict +__all__ = ('defaultdict',) +").exports(), + @"defaultdict :: Variable", + ); + } + + #[test] + fn exports_from_import_all_reexport_annotated_assignment() { + insta::assert_snapshot!( + public_test("\ +from collections import defaultdict +__all__: list[str] = ['defaultdict'] +").exports(), + @"defaultdict :: Variable", + ); + + insta::assert_snapshot!( + public_test("\ +from collections import defaultdict +__all__: tuple[str, ...] = ('defaultdict',) +").exports(), + @"defaultdict :: Variable", + ); + } + + #[test] + fn exports_from_import_all_reexport_augmented_assignment() { + insta::assert_snapshot!( + public_test("\ +from collections import defaultdict +__all__ = [] +__all__ += ['defaultdict'] +").exports(), + @"defaultdict :: Variable", + ); + + insta::assert_snapshot!( + public_test("\ +from collections import defaultdict +__all__ = [] +__all__ += ('defaultdict',) +").exports(), + @"defaultdict :: Variable", + ); + + insta::assert_snapshot!( + public_test("\ +from collections import defaultdict +__all__ = [] +__all__ += {'defaultdict'} +").exports(), + @"defaultdict :: Variable", + ); + } + + #[test] + fn exports_from_import_all_reexport_invalid_augmented_assignment() { + insta::assert_snapshot!( + public_test("\ +from collections import defaultdict +__all__ += ['defaultdict'] +").exports(), + @"", + ); + } + + #[test] + fn exports_from_import_all_reexport_extend() { + insta::assert_snapshot!( + public_test("\ +from collections import defaultdict +__all__ = [] +__all__.extend(['defaultdict']) +").exports(), + @"defaultdict :: Variable", + ); + } + + #[test] + fn exports_from_import_all_reexport_invalid_extend() { + insta::assert_snapshot!( + public_test("\ +from collections import defaultdict +__all__.extend(['defaultdict']) +").exports(), + @r"", + ); + } + + #[test] + fn exports_from_import_all_reexport_append() { + insta::assert_snapshot!( + public_test("\ +from collections import defaultdict +__all__ = [] +__all__.append('defaultdict') +").exports(), + @"defaultdict :: Variable", + ); + } + + #[test] + fn exports_from_import_all_reexport_plus_equals() { + insta::assert_snapshot!( + public_test("\ +from collections import defaultdict +__all__ = [] +__all__ += ['defaultdict'] +").exports(), + @"defaultdict :: Variable", + ); + } + + #[test] + fn exports_from_import_all_reexport_star_equals() { + // Confirm that this doesn't work. Only `__all__ += ...` should + // be recognized. This makes the symbol visitor consider + // `__all__` invalid and thus ignore it. And this in turn lets + // `__all__` be exported. This seems like a somewhat degenerate + // case, but is a consequence of us treating sunder and dunder + // symbols as exported when `__all__` isn't present. + insta::assert_snapshot!( + public_test("\ +from collections import defaultdict +__all__ = [] +__all__ *= ['defaultdict'] +").exports(), + @"__all__ :: Variable", + ); + } + + #[test] + fn exports_from_import_all_reexport_remove() { + insta::assert_snapshot!( + public_test("\ +from collections import defaultdict +__all__ = [] +__all__.remove('defaultdict') +").exports(), + @"", + ); + } + + #[test] + fn exports_nested_all() { + insta::assert_snapshot!( + public_test(r#"\ +bar = 1 +baz = 1 +__all__ = [] + +def foo(): + __all__.append("bar") + +class X: + def method(self): + __all__.extend(["baz"]) +"#).exports(), + @"", + ); + } + + #[test] + fn wildcard_reexport_simple_no_all() { + let test = PublicTestBuilder::default() + .source("foo.py", "ZQZQZQ = 1") + .source("test.py", "from foo import *") + .build(); + insta::assert_snapshot!( + test.exports_for("test.py"), + @"ZQZQZQ :: Constant", + ); + } + + #[test] + fn wildcard_reexport_single_underscore_no_all() { + let test = PublicTestBuilder::default() + .source("foo.py", "_ZQZQZQ = 1") + .source("test.py", "from foo import *") + .build(); + // Without `__all__` present, a wildcard import won't include + // names starting with an underscore at runtime. So `_ZQZQZQ` + // should not be present here. + // See: + insta::assert_snapshot!( + test.exports_for("test.py"), + @"", + ); + } + + #[test] + fn wildcard_reexport_double_underscore_no_all() { + let test = PublicTestBuilder::default() + .source("foo.py", "__ZQZQZQ = 1") + .source("test.py", "from foo import *") + .build(); + // Without `__all__` present, a wildcard import won't include + // names starting with an underscore at runtime. So `__ZQZQZQ` + // should not be present here. + // See: + insta::assert_snapshot!( + test.exports_for("test.py"), + @"", + ); + } + + #[test] + fn wildcard_reexport_normal_import_no_all() { + let test = PublicTestBuilder::default() + .source("foo.py", "import collections") + .source("test.py", "from foo import *") + .build(); + // We specifically test for the absence of `collections` + // here. That is, `from foo import *` will import + // `collections` at runtime, but we don't consider it part + // of the exported interface of `foo`. + insta::assert_snapshot!( + test.exports_for("test.py"), + @"", + ); + } + + #[test] + fn wildcard_reexport_redundant_import_no_all() { + let test = PublicTestBuilder::default() + .source("foo.py", "import collections as collections") + .source("test.py", "from foo import *") + .build(); + insta::assert_snapshot!( + test.exports_for("test.py"), + @"collections :: Module", + ); + } + + #[test] + fn wildcard_reexport_normal_from_import_no_all() { + let test = PublicTestBuilder::default() + .source("foo.py", "from collections import defaultdict") + .source("test.py", "from foo import *") + .build(); + // We specifically test for the absence of `defaultdict` + // here. That is, `from foo import *` will import + // `defaultdict` at runtime, but we don't consider it part + // of the exported interface of `foo`. + insta::assert_snapshot!( + test.exports_for("test.py"), + @"", + ); + } + + #[test] + fn wildcard_reexport_redundant_from_import_no_all() { + let test = PublicTestBuilder::default() + .source( + "foo.py", + "from collections import defaultdict as defaultdict", + ) + .source("test.py", "from foo import *") + .build(); + insta::assert_snapshot!( + test.exports_for("test.py"), + @"defaultdict :: Variable", + ); + } + + #[test] + fn wildcard_reexport_all_simple() { + let test = PublicTestBuilder::default() + .source( + "foo.py", + " + ZQZQZQ = 1 + __all__ = ['ZQZQZQ'] + ", + ) + .source("test.py", "from foo import *") + .build(); + insta::assert_snapshot!( + test.exports_for("test.py"), + @"ZQZQZQ :: Constant", + ); + } + + #[test] + fn wildcard_reexport_all_simple_include_all() { + let test = PublicTestBuilder::default() + .source( + "foo.py", + " + ZQZQZQ = 1 + __all__ = ['__all__', 'ZQZQZQ'] + ", + ) + .source("test.py", "from foo import *") + .build(); + insta::assert_snapshot!( + test.exports_for("test.py"), + @r" + ZQZQZQ :: Constant + __all__ :: Variable + ", + ); + } + + #[test] + fn wildcard_reexport_all_empty() { + let test = PublicTestBuilder::default() + .source( + "foo.py", + " + ZQZQZQ = 1 + __all__ = [] + ", + ) + .source("test.py", "from foo import *") + .build(); + // Nothing is exported because `__all__` is defined + // and also empty. + insta::assert_snapshot!( + test.exports_for("test.py"), + @"", + ); + } + + #[test] + fn wildcard_reexport_all_empty_not_applies_to_importer() { + let test = PublicTestBuilder::default() + .source( + "foo.py", + " + ZQZQZQ = 1 + __all__ = [] + ", + ) + .source( + "test.py", + "from foo import * + TRICKSY = 1", + ) + .build(); + insta::assert_snapshot!( + test.exports_for("test.py"), + @"TRICKSY :: Constant", + ); + } + + #[test] + fn wildcard_reexport_all_include_all_applies_to_importer() { + let test = PublicTestBuilder::default() + .source( + "foo.py", + " + ZQZQZQ = 1 + __all__ = ['__all__'] + ", + ) + .source( + "test.py", + "from foo import * + TRICKSY = 1", + ) + .build(); + // TRICKSY should specifically be absent because + // `__all__` is defined in `test.py` (via a wildcard + // import) and does not itself include `TRICKSY`. + insta::assert_snapshot!( + test.exports_for("test.py"), + @"__all__ :: Variable", + ); + } + + #[test] + fn wildcard_reexport_all_empty_then_added_to() { + let test = PublicTestBuilder::default() + .source( + "foo.py", + " + ZQZQZQ = 1 + __all__ = [] + ", + ) + .source( + "test.py", + "from foo import * + TRICKSY = 1 + __all__.append('TRICKSY')", + ) + .build(); + insta::assert_snapshot!( + test.exports_for("test.py"), + @"TRICKSY :: Constant", + ); + } + + #[test] + fn wildcard_reexport_all_include_all_then_added_to() { + let test = PublicTestBuilder::default() + .source( + "foo.py", + " + ZQZQZQ = 1 + __all__ = ['__all__'] + ", + ) + .source( + "test.py", + "from foo import * + TRICKSY = 1 + __all__.append('TRICKSY')", + ) + .build(); + insta::assert_snapshot!( + test.exports_for("test.py"), + @r" + __all__ :: Variable + TRICKSY :: Constant + ", + ); + } + + /// Tests that a `from module import *` doesn't bring an + /// `__all__` into scope if `module` doesn't provide an + /// `__all__` that includes `__all__` AND this causes + /// `__all__.append` to fail in the importing module + /// (because it isn't defined). + #[test] + fn wildcard_reexport_all_empty_then_added_to_incorrect() { + let test = PublicTestBuilder::default() + .source( + "foo.py", + " + ZQZQZQ = 1 + __all__ = [] + ", + ) + .source( + "test.py", + "from foo import * + from collections import defaultdict + __all__.append('defaultdict')", + ) + .build(); + insta::assert_snapshot!( + test.exports_for("test.py"), + @"", + ); + } + + #[test] + fn wildcard_reexport_all_include_all_then_added_to_correct() { + let test = PublicTestBuilder::default() + .source( + "foo.py", + " + ZQZQZQ = 1 + __all__ = ['__all__'] + ", + ) + .source( + "test.py", + "from foo import * + from collections import defaultdict + __all__.append('defaultdict')", + ) + .build(); + insta::assert_snapshot!( + test.exports_for("test.py"), + @r" + __all__ :: Variable + defaultdict :: Variable + ", + ); + } + + #[test] + fn wildcard_reexport_all_non_empty_but_non_existent() { + let test = PublicTestBuilder::default() + .source( + "foo.py", + " + ZQZQZQ = 1 + __all__ = ['TRICKSY'] + ", + ) + .source("test.py", "from foo import *") + .build(); + // `TRICKSY` isn't actually a valid symbol, + // and `ZQZQZQ` isn't in `__all__`, so we get + // no symbols here. + insta::assert_snapshot!( + test.exports_for("test.py"), + @"", + ); + } + + #[test] + fn wildcard_reexport_all_include_all_and_non_existent() { + let test = PublicTestBuilder::default() + .source( + "foo.py", + " + ZQZQZQ = 1 + __all__ = ['__all__', 'TRICKSY'] + ", + ) + .source("test.py", "from foo import *") + .build(); + // Note that this example will actually result in a runtime + // error since `TRICKSY` doesn't exist in `foo.py` and + // `from foo import *` will try to import it anyway. + insta::assert_snapshot!( + test.exports_for("test.py"), + @"__all__ :: Variable", + ); + } + + #[test] + fn wildcard_reexport_all_not_applies_to_importer() { + let test = PublicTestBuilder::default() + .source( + "foo.py", + " + ZQZQZQ = 1 + __all__ = ['TRICKSY'] + ", + ) + .source( + "test.py", + "from foo import * + TRICKSY = 1", + ) + .build(); + // Note that this example will actually result in a runtime + // error since `TRICKSY` doesn't exist in `foo.py` and + // `from foo import *` will try to import it anyway. + insta::assert_snapshot!( + test.exports_for("test.py"), + @"TRICKSY :: Constant", + ); + } + + #[test] + fn wildcard_reexport_all_include_all_with_others_applies_to_importer() { + let test = PublicTestBuilder::default() + .source( + "foo.py", + " + ZQZQZQ = 1 + __all__ = ['__all__', 'TRICKSY'] + ", + ) + .source( + "test.py", + "from foo import * + TRICKSY = 1", + ) + .build(); + // Note that this example will actually result in a runtime + // error since `TRICKSY` doesn't exist in `foo.py` and + // `from foo import *` will try to import it anyway. + insta::assert_snapshot!( + test.exports_for("test.py"), + @r" + __all__ :: Variable + TRICKSY :: Constant + ", + ); + } + + #[test] + fn explicit_reexport_all_empty_applies_to_importer() { + let test = PublicTestBuilder::default() + .source( + "foo.py", + " + ZQZQZQ = 1 + __all__ = [] + ", + ) + .source( + "test.py", + "from foo import __all__ as __all__ + TRICKSY = 1", + ) + .build(); + // `__all__` is imported from `foo.py` but it's + // empty, so `TRICKSY` is not part of the exported + // API. + insta::assert_snapshot!( + test.exports_for("test.py"), + @"", + ); + } + + #[test] + fn explicit_reexport_all_empty_then_added_to() { + let test = PublicTestBuilder::default() + .source( + "foo.py", + " + ZQZQZQ = 1 + __all__ = [] + ", + ) + .source( + "test.py", + "from foo import __all__ + TRICKSY = 1 + __all__.append('TRICKSY')", + ) + .build(); + insta::assert_snapshot!( + test.exports_for("test.py"), + @"TRICKSY :: Constant", + ); + } + + #[test] + fn explicit_reexport_all_non_empty_but_non_existent() { + let test = PublicTestBuilder::default() + .source( + "foo.py", + " + ZQZQZQ = 1 + __all__ = ['TRICKSY'] + ", + ) + .source("test.py", "from foo import __all__ as __all__") + .build(); + // `TRICKSY` is not a valid symbol, so it's not considered + // part of the exports of `test`. + insta::assert_snapshot!( + test.exports_for("test.py"), + @"", + ); + } + + #[test] + fn explicit_reexport_all_applies_to_importer() { + let test = PublicTestBuilder::default() + .source( + "foo.py", + " + ZQZQZQ = 1 + __all__ = ['TRICKSY'] + ", + ) + .source( + "test.py", + "from foo import __all__ + TRICKSY = 1", + ) + .build(); + insta::assert_snapshot!( + test.exports_for("test.py"), + @"TRICKSY :: Constant", + ); + } + fn matches(query: &str, symbol: &str) -> bool { super::QueryPattern::fuzzy(query).is_match_symbol_name(symbol) } diff --git a/crates/ty_server/tests/e2e/snapshots/e2e__notebook__auto_import.snap b/crates/ty_server/tests/e2e/snapshots/e2e__notebook__auto_import.snap index cb2f8c55e3..2bdac4a17e 100644 --- a/crates/ty_server/tests/e2e/snapshots/e2e__notebook__auto_import.snap +++ b/crates/ty_server/tests/e2e/snapshots/e2e__notebook__auto_import.snap @@ -6,7 +6,7 @@ expression: completions { "label": "Literal (import typing)", "kind": 6, - "sortText": " 50", + "sortText": " 58", "insertText": "Literal", "additionalTextEdits": [ { @@ -27,7 +27,7 @@ expression: completions { "label": "LiteralString (import typing)", "kind": 6, - "sortText": " 51", + "sortText": " 59", "insertText": "LiteralString", "additionalTextEdits": [ { diff --git a/crates/ty_server/tests/e2e/snapshots/e2e__notebook__auto_import_docstring.snap b/crates/ty_server/tests/e2e/snapshots/e2e__notebook__auto_import_docstring.snap index cb2f8c55e3..2bdac4a17e 100644 --- a/crates/ty_server/tests/e2e/snapshots/e2e__notebook__auto_import_docstring.snap +++ b/crates/ty_server/tests/e2e/snapshots/e2e__notebook__auto_import_docstring.snap @@ -6,7 +6,7 @@ expression: completions { "label": "Literal (import typing)", "kind": 6, - "sortText": " 50", + "sortText": " 58", "insertText": "Literal", "additionalTextEdits": [ { @@ -27,7 +27,7 @@ expression: completions { "label": "LiteralString (import typing)", "kind": 6, - "sortText": " 51", + "sortText": " 59", "insertText": "LiteralString", "additionalTextEdits": [ { diff --git a/crates/ty_server/tests/e2e/snapshots/e2e__notebook__auto_import_from_future.snap b/crates/ty_server/tests/e2e/snapshots/e2e__notebook__auto_import_from_future.snap index cb2f8c55e3..2bdac4a17e 100644 --- a/crates/ty_server/tests/e2e/snapshots/e2e__notebook__auto_import_from_future.snap +++ b/crates/ty_server/tests/e2e/snapshots/e2e__notebook__auto_import_from_future.snap @@ -6,7 +6,7 @@ expression: completions { "label": "Literal (import typing)", "kind": 6, - "sortText": " 50", + "sortText": " 58", "insertText": "Literal", "additionalTextEdits": [ { @@ -27,7 +27,7 @@ expression: completions { "label": "LiteralString (import typing)", "kind": 6, - "sortText": " 51", + "sortText": " 59", "insertText": "LiteralString", "additionalTextEdits": [ { diff --git a/crates/ty_server/tests/e2e/snapshots/e2e__notebook__auto_import_same_cell.snap b/crates/ty_server/tests/e2e/snapshots/e2e__notebook__auto_import_same_cell.snap index b7d8c9907a..a0ff0b77b6 100644 --- a/crates/ty_server/tests/e2e/snapshots/e2e__notebook__auto_import_same_cell.snap +++ b/crates/ty_server/tests/e2e/snapshots/e2e__notebook__auto_import_same_cell.snap @@ -6,7 +6,7 @@ expression: completions { "label": "Literal (import typing)", "kind": 6, - "sortText": " 50", + "sortText": " 58", "insertText": "Literal", "additionalTextEdits": [ { @@ -27,7 +27,7 @@ expression: completions { "label": "LiteralString (import typing)", "kind": 6, - "sortText": " 51", + "sortText": " 59", "insertText": "LiteralString", "additionalTextEdits": [ {