From db9eee7b06ae7b220b1425cdb2db0641ca7d1705 Mon Sep 17 00:00:00 2001 From: Andrew Gallant Date: Tue, 13 Jan 2026 12:09:06 -0500 Subject: [PATCH] [ty] Attach origin module on to re-exported symbols This information should let us filter out (or rather, merge) re-exported symbols across a package hierarchy for the purposes of auto-completions. --- crates/ty_ide/src/all_symbols.rs | 14 +- crates/ty_ide/src/symbols.rs | 242 +++++++++++++++++++++---------- 2 files changed, 179 insertions(+), 77 deletions(-) diff --git a/crates/ty_ide/src/all_symbols.rs b/crates/ty_ide/src/all_symbols.rs index d3c9a0d59a..4f0e4da23c 100644 --- a/crates/ty_ide/src/all_symbols.rs +++ b/crates/ty_ide/src/all_symbols.rs @@ -5,7 +5,7 @@ use ty_project::Db; use crate::{ SymbolKind, - symbols::{QueryPattern, SymbolInfo, symbols_for_file_global_only}, + symbols::{ImportedFrom, QueryPattern, SymbolInfo, symbols_for_file_global_only}, }; /// Get all symbols matching the query string. @@ -199,6 +199,16 @@ impl<'db> AllSymbolInfo<'db> { pub fn file(&self) -> File { self.file } + + /// Returns the module that this symbol was re-exported from. + /// + /// This is only available for symbols that have been imported + /// into `Self::module()` *and* are determined to be re-exports. + pub(crate) fn imported_from(&self) -> Option<&ImportedFrom> { + self.symbol + .as_ref() + .and_then(|symbol| symbol.imported_from.as_ref()) + } } #[cfg(test)] @@ -213,7 +223,7 @@ mod tests { }; #[test] - fn test_all_symbols_multi_file() { + fn all_symbols_multi_file() { // We use odd symbol names here so that we can // write queries that target them specifically // and (hopefully) nothing else. diff --git a/crates/ty_ide/src/symbols.rs b/crates/ty_ide/src/symbols.rs index 2bb0bda7fc..0d09fb1a7b 100644 --- a/crates/ty_ide/src/symbols.rs +++ b/crates/ty_ide/src/symbols.rs @@ -258,6 +258,10 @@ pub struct SymbolInfo<'a> { pub name: Cow<'a, str>, /// The kind of symbol (function, class, variable, etc.) pub kind: SymbolKind, + /// Whether this symbol was imported from another module. + /// + /// And if so, this includes the name of that module. + pub imported_from: Option, /// The range of the symbol name pub name_range: TextRange, /// The full range of the symbol (including body) @@ -269,6 +273,7 @@ impl SymbolInfo<'_> { SymbolInfo { name: Cow::Owned(self.name.to_string()), kind: self.kind, + imported_from: self.imported_from.clone(), name_range: self.name_range, full_range: self.full_range, } @@ -280,6 +285,14 @@ impl<'a> From<&'a SymbolTree> for SymbolInfo<'a> { SymbolInfo { name: Cow::Borrowed(&symbol.name), kind: symbol.kind, + // The clone here isn't great, but doing actual work here + // probably isn't the super common case. Namely, most + // imports aren't re-exports and get filtered out before + // we construct a `SymbolInfo`. This should only do actual + // work (like cloning a `ModuleName`) when this symbol + // is both imported from another module *and* determined + // to be a re-export. ---AG + imported_from: symbol.imported_from.clone(), name_range: symbol.name_range, full_range: symbol.full_range, } @@ -413,7 +426,34 @@ struct SymbolTree { kind: SymbolKind, name_range: TextRange, full_range: TextRange, - import_kind: Option, + imported_from: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq, get_size2::GetSize)] +pub struct ImportedFrom { + module_name: ModuleName, + kind: ImportKind, +} + +impl ImportedFrom { + fn import(alias: &ast::Alias, kind: ImportKind) -> Option { + let module_name = ModuleName::new(&alias.name)?; + Some(ImportedFrom { module_name, kind }) + } + + fn import_from( + db: &dyn Db, + importing_file: File, + ast: &ast::StmtImportFrom, + kind: ImportKind, + ) -> Option { + let module_name = ModuleName::from_import_statement(db, importing_file, ast).ok()?; + Some(ImportedFrom { module_name, kind }) + } + + pub fn module_name(&self) -> &ModuleName { + &self.module_name + } } #[derive(Debug, Clone, Copy, PartialEq, Eq, get_size2::GetSize)] @@ -423,6 +463,16 @@ enum ImportKind { Wildcard, } +impl From<&ast::Alias> for ImportKind { + fn from(alias: &ast::Alias) -> ImportKind { + if alias.asname.as_ref().map(ast::Identifier::as_str) == Some(alias.name.as_str()) { + ImportKind::RedundantAlias + } else { + ImportKind::Normal + } + } +} + /// An abstraction for managing module scope imports. /// /// This is meant to recognize the following idioms for updating @@ -606,6 +656,21 @@ impl<'db> ImportModuleName<'db> { } } +#[derive(Clone, Copy, Debug)] +enum AstImport<'a> { + Import(&'a ast::StmtImport), + ImportFrom(&'a ast::StmtImportFrom), +} + +impl Ranged for AstImport<'_> { + fn range(&self) -> TextRange { + match *self { + AstImport::Import(ast) => ast.range(), + AstImport::ImportFrom(ast) => ast.range(), + } + } +} + /// A visitor over all symbols in a single file. /// /// This guarantees that child symbols have a symbol ID greater @@ -775,36 +840,42 @@ impl<'db> SymbolVisitor<'db> { kind, name_range: name.range(), full_range: stmt.range(), - import_kind: None, + imported_from: None, }; self.add_symbol(symbol) } /// Adds a symbol introduced via an import `stmt`. - fn add_import_alias(&mut self, stmt: &ast::Stmt, alias: &ast::Alias) -> SymbolId { + fn add_import_alias(&mut self, import: AstImport<'_>, alias: &ast::Alias) -> Option { let name = alias.asname.as_ref().unwrap_or(&alias.name); - let kind = if stmt.is_import_stmt() { + let kind = if matches!(import, AstImport::Import(_)) { SymbolKind::Module } else if Self::is_constant_name(name.as_str()) { SymbolKind::Constant } else { SymbolKind::Variable }; - let re_export = Some( - if alias.asname.as_ref().map(ast::Identifier::as_str) == Some(alias.name.as_str()) { - ImportKind::RedundantAlias - } else { - ImportKind::Normal - }, - ); - self.add_symbol(SymbolTree { + let import_kind = ImportKind::from(alias); + let full_range = import.range(); + let Some(imported_from) = (match import { + AstImport::Import(_) => ImportedFrom::import(alias, import_kind), + AstImport::ImportFrom(ast) => { + ImportedFrom::import_from(self.db, self.file, ast, import_kind) + } + }) else { + tracing::debug!( + "Dropping imported symbol {name} since its module name could not be discovered", + ); + return None; + }; + Some(self.add_symbol(SymbolTree { parent: None, name: name.id.to_string(), kind, name_range: name.range(), - full_range: stmt.range(), - import_kind: re_export, - }) + full_range, + imported_from: Some(imported_from), + })) } /// Extracts `__all__` names from the given assignment. @@ -955,7 +1026,20 @@ impl<'db> SymbolVisitor<'db> { return None; } let mut symbol = symbol.clone(); - symbol.import_kind = Some(ImportKind::Wildcard); + let Some(imported_from) = ImportedFrom::import_from( + self.db, + self.file, + import_from, + ImportKind::Wildcard, + ) else { + tracing::debug!( + "Dropping wildcard imported symbol {name} since \ + its module name could not be discovered", + name = symbol.name, + ); + return None; + }; + symbol.imported_from = Some(imported_from); Some(symbol) })); // If the imported module defines an `__all__` AND `__all__` is @@ -1070,8 +1154,8 @@ impl<'db> SymbolVisitor<'db> { // * `import X as X` // * `from Y import X as X` // * `from Y import *` - if let Some(kind) = symbol.import_kind { - return match kind { + if let Some(ref imported_from) = symbol.imported_from { + return match imported_from.kind { ImportKind::RedundantAlias | ImportKind::Wildcard => true, ImportKind::Normal => false, }; @@ -1115,7 +1199,7 @@ impl<'db> SourceOrderVisitor<'db> for SymbolVisitor<'db> { kind, name_range: func_def.name.range(), full_range: stmt.range(), - import_kind: None, + imported_from: None, }; if self.exports_only { @@ -1144,7 +1228,7 @@ impl<'db> SourceOrderVisitor<'db> for SymbolVisitor<'db> { kind: SymbolKind::Class, name_range: class_def.name.range(), full_range: stmt.range(), - import_kind: None, + imported_from: None, }; if self.exports_only { @@ -1267,7 +1351,7 @@ impl<'db> SourceOrderVisitor<'db> for SymbolVisitor<'db> { } self.imports.add_import(import); for alias in &import.names { - self.add_import_alias(stmt, alias); + self.add_import_alias(AstImport::Import(import), alias); } } ast::Stmt::ImportFrom(import_from) => { @@ -1294,7 +1378,7 @@ impl<'db> SourceOrderVisitor<'db> for SymbolVisitor<'db> { { self.add_all_from_import(import_from); } - self.add_import_alias(stmt, alias); + self.add_import_alias(AstImport::ImportFrom(import_from), alias); } } } @@ -1566,7 +1650,7 @@ from collections import defaultdict as dd public_test("\ import numpy as numpy ").exports(), - @"numpy :: Module", + @"numpy :: Module :: Re-exported from `numpy`", ); } @@ -1576,7 +1660,7 @@ import numpy as numpy public_test("\ from collections import defaultdict as defaultdict ").exports(), - @"defaultdict :: Variable", + @"defaultdict :: Variable :: Re-exported from `collections`", ); } @@ -1587,7 +1671,7 @@ from collections import defaultdict as defaultdict from collections import defaultdict __all__ = ['defaultdict'] ").exports(), - @"defaultdict :: Variable", + @"defaultdict :: Variable :: Re-exported from `collections`", ); insta::assert_snapshot!( @@ -1595,7 +1679,7 @@ __all__ = ['defaultdict'] from collections import defaultdict __all__ = ('defaultdict',) ").exports(), - @"defaultdict :: Variable", + @"defaultdict :: Variable :: Re-exported from `collections`", ); } @@ -1606,7 +1690,7 @@ __all__ = ('defaultdict',) from collections import defaultdict __all__: list[str] = ['defaultdict'] ").exports(), - @"defaultdict :: Variable", + @"defaultdict :: Variable :: Re-exported from `collections`", ); insta::assert_snapshot!( @@ -1614,7 +1698,7 @@ __all__: list[str] = ['defaultdict'] from collections import defaultdict __all__: tuple[str, ...] = ('defaultdict',) ").exports(), - @"defaultdict :: Variable", + @"defaultdict :: Variable :: Re-exported from `collections`", ); } @@ -1626,7 +1710,7 @@ from collections import defaultdict __all__ = [] __all__ += ['defaultdict'] ").exports(), - @"defaultdict :: Variable", + @"defaultdict :: Variable :: Re-exported from `collections`", ); insta::assert_snapshot!( @@ -1635,7 +1719,7 @@ from collections import defaultdict __all__ = [] __all__ += ('defaultdict',) ").exports(), - @"defaultdict :: Variable", + @"defaultdict :: Variable :: Re-exported from `collections`", ); insta::assert_snapshot!( @@ -1644,7 +1728,7 @@ from collections import defaultdict __all__ = [] __all__ += {'defaultdict'} ").exports(), - @"defaultdict :: Variable", + @"defaultdict :: Variable :: Re-exported from `collections`", ); } @@ -1667,7 +1751,7 @@ from collections import defaultdict __all__ = [] __all__.extend(['defaultdict']) ").exports(), - @"defaultdict :: Variable", + @"defaultdict :: Variable :: Re-exported from `collections`", ); } @@ -1690,7 +1774,7 @@ from collections import defaultdict __all__ = [] __all__.append('defaultdict') ").exports(), - @"defaultdict :: Variable", + @"defaultdict :: Variable :: Re-exported from `collections`", ); } @@ -1702,7 +1786,7 @@ from collections import defaultdict __all__ = [] __all__ += ['defaultdict'] ").exports(), - @"defaultdict :: Variable", + @"defaultdict :: Variable :: Re-exported from `collections`", ); } @@ -1763,7 +1847,7 @@ class X: .build(); insta::assert_snapshot!( test.exports_for("test.py"), - @"ZQZQZQ :: Constant", + @"ZQZQZQ :: Constant :: Re-exported from `foo`", ); } @@ -1823,7 +1907,7 @@ class X: .build(); insta::assert_snapshot!( test.exports_for("test.py"), - @"collections :: Module", + @"collections :: Module :: Re-exported from `foo`", ); } @@ -1854,7 +1938,7 @@ class X: .build(); insta::assert_snapshot!( test.exports_for("test.py"), - @"defaultdict :: Variable", + @"defaultdict :: Variable :: Re-exported from `foo`", ); } @@ -1872,7 +1956,7 @@ class X: .build(); insta::assert_snapshot!( test.exports_for("test.py"), - @"ZQZQZQ :: Constant", + @"ZQZQZQ :: Constant :: Re-exported from `foo`", ); } @@ -1890,9 +1974,9 @@ class X: .build(); insta::assert_snapshot!( test.exports_for("test.py"), - @" - ZQZQZQ :: Constant - __all__ :: Variable + @r" + ZQZQZQ :: Constant :: Re-exported from `foo` + __all__ :: Variable :: Re-exported from `foo` ", ); } @@ -1960,7 +2044,7 @@ class X: // import) and does not itself include `TRICKSY`. insta::assert_snapshot!( test.exports_for("test.py"), - @"__all__ :: Variable", + @"__all__ :: Variable :: Re-exported from `foo`", ); } @@ -2006,8 +2090,8 @@ class X: .build(); insta::assert_snapshot!( test.exports_for("test.py"), - @" - __all__ :: Variable + @r" + __all__ :: Variable :: Re-exported from `foo` TRICKSY :: Constant ", ); @@ -2060,9 +2144,9 @@ class X: .build(); insta::assert_snapshot!( test.exports_for("test.py"), - @" - __all__ :: Variable - defaultdict :: Variable + @r" + __all__ :: Variable :: Re-exported from `foo` + defaultdict :: Variable :: Re-exported from `collections` ", ); } @@ -2105,7 +2189,7 @@ class X: // `from foo import *` will try to import it anyway. insta::assert_snapshot!( test.exports_for("test.py"), - @"__all__ :: Variable", + @"__all__ :: Variable :: Re-exported from `foo`", ); } @@ -2155,8 +2239,8 @@ class X: // `from foo import *` will try to import it anyway. insta::assert_snapshot!( test.exports_for("test.py"), - @" - __all__ :: Variable + @r" + __all__ :: Variable :: Re-exported from `foo` TRICKSY :: Constant ", ); @@ -2274,8 +2358,8 @@ class X: .build(); insta::assert_snapshot!( test.exports_for("test.py"), - @" - _ZQZQZQ :: Constant + @r" + _ZQZQZQ :: Constant :: Re-exported from `foo` _ZYZYZY :: Constant ", ); @@ -2303,8 +2387,8 @@ class X: .build(); insta::assert_snapshot!( test.exports_for("test.py"), - @" - _ZQZQZQ :: Constant + @r" + _ZQZQZQ :: Constant :: Re-exported from `foo` _ZYZYZY :: Constant ", ); @@ -2332,8 +2416,8 @@ class X: .build(); insta::assert_snapshot!( test.exports_for("test.py"), - @" - _ZQZQZQ :: Constant + @r" + _ZQZQZQ :: Constant :: Re-exported from `foo` _ZYZYZY :: Constant ", ); @@ -2362,8 +2446,8 @@ class X: .build(); insta::assert_snapshot!( test.exports_for("test.py"), - @" - _ZQZQZQ :: Constant + @r" + _ZQZQZQ :: Constant :: Re-exported from `parent.foo` _ZYZYZY :: Constant ", ); @@ -2392,8 +2476,8 @@ class X: .build(); insta::assert_snapshot!( test.exports_for("test.py"), - @" - _ZQZQZQ :: Constant + @r" + _ZQZQZQ :: Constant :: Re-exported from `parent.foo` _ZYZYZY :: Constant ", ); @@ -2422,8 +2506,8 @@ class X: .build(); insta::assert_snapshot!( test.exports_for("test.py"), - @" - _ZQZQZQ :: Constant + @r" + _ZQZQZQ :: Constant :: Re-exported from `parent.foo` _ZYZYZY :: Constant ", ); @@ -2452,8 +2536,8 @@ class X: .build(); insta::assert_snapshot!( test.exports_for("test.py"), - @" - _ZQZQZQ :: Constant + @r" + _ZQZQZQ :: Constant :: Re-exported from `parent.foo` _ZYZYZY :: Constant ", ); @@ -2482,8 +2566,8 @@ class X: .build(); insta::assert_snapshot!( test.exports_for("test.py"), - @" - _ZQZQZQ :: Constant + @r" + _ZQZQZQ :: Constant :: Re-exported from `parent.foo` _ZYZYZY :: Constant ", ); @@ -2514,9 +2598,9 @@ class X: .build(); insta::assert_snapshot!( test.exports_for("a.py"), - @" - _ZBZBZB :: Constant - _ZAZAZA :: Constant + @r" + _ZBZBZB :: Constant :: Re-exported from `b` + _ZAZAZA :: Constant :: Re-exported from `b` ", ); } @@ -2559,8 +2643,8 @@ class X: // `_ZBZBZB` instead of `_ZFZFZF`. insta::assert_snapshot!( test.exports_for("test.py"), - @" - _ZFZFZF :: Constant + @r" + _ZFZFZF :: Constant :: Re-exported from `foo` _ZYZYZY :: Constant ", ); @@ -2600,9 +2684,9 @@ class X: // answer should just be `_ZFZFZF` and `_ZYZYZY`. insta::assert_snapshot!( test.exports_for("test.py"), - @" - _ZFZFZF :: Constant - foo :: Module + @r" + _ZFZFZF :: Constant :: Re-exported from `parent.foo` + foo :: Module :: Re-exported from `parent` _ZYZYZY :: Constant __all__ :: Variable ", @@ -2640,7 +2724,15 @@ class X: symbols .iter() .map(|(_, symbol)| { - format!("{name} :: {kind:?}", name = symbol.name, kind = symbol.kind) + let mut snapshot = + format!("{name} :: {kind:?}", name = symbol.name, kind = symbol.kind,); + if let Some(ref imported_from) = symbol.imported_from { + snapshot = format!( + "{snapshot} :: Re-exported from `{module_name}`", + module_name = imported_from.module_name() + ); + } + snapshot }) .collect::>() .join("\n")