[ty] Attach origin module on to re-exported symbols

This information should let us filter out (or rather, merge)
re-exported symbols across a package hierarchy for the purposes
of auto-completions.
This commit is contained in:
Andrew Gallant
2026-01-13 12:09:06 -05:00
committed by Andrew Gallant
parent 2cbd68ab70
commit db9eee7b06
2 changed files with 179 additions and 77 deletions

View File

@@ -5,7 +5,7 @@ use ty_project::Db;
use crate::{
SymbolKind,
symbols::{QueryPattern, SymbolInfo, symbols_for_file_global_only},
symbols::{ImportedFrom, QueryPattern, SymbolInfo, symbols_for_file_global_only},
};
/// Get all symbols matching the query string.
@@ -199,6 +199,16 @@ impl<'db> AllSymbolInfo<'db> {
pub fn file(&self) -> File {
self.file
}
/// Returns the module that this symbol was re-exported from.
///
/// This is only available for symbols that have been imported
/// into `Self::module()` *and* are determined to be re-exports.
pub(crate) fn imported_from(&self) -> Option<&ImportedFrom> {
self.symbol
.as_ref()
.and_then(|symbol| symbol.imported_from.as_ref())
}
}
#[cfg(test)]
@@ -213,7 +223,7 @@ mod tests {
};
#[test]
fn test_all_symbols_multi_file() {
fn all_symbols_multi_file() {
// We use odd symbol names here so that we can
// write queries that target them specifically
// and (hopefully) nothing else.

View File

@@ -258,6 +258,10 @@ pub struct SymbolInfo<'a> {
pub name: Cow<'a, str>,
/// The kind of symbol (function, class, variable, etc.)
pub kind: SymbolKind,
/// Whether this symbol was imported from another module.
///
/// And if so, this includes the name of that module.
pub imported_from: Option<ImportedFrom>,
/// The range of the symbol name
pub name_range: TextRange,
/// The full range of the symbol (including body)
@@ -269,6 +273,7 @@ impl SymbolInfo<'_> {
SymbolInfo {
name: Cow::Owned(self.name.to_string()),
kind: self.kind,
imported_from: self.imported_from.clone(),
name_range: self.name_range,
full_range: self.full_range,
}
@@ -280,6 +285,14 @@ impl<'a> From<&'a SymbolTree> for SymbolInfo<'a> {
SymbolInfo {
name: Cow::Borrowed(&symbol.name),
kind: symbol.kind,
// The clone here isn't great, but doing actual work here
// probably isn't the super common case. Namely, most
// imports aren't re-exports and get filtered out before
// we construct a `SymbolInfo`. This should only do actual
// work (like cloning a `ModuleName`) when this symbol
// is both imported from another module *and* determined
// to be a re-export. ---AG
imported_from: symbol.imported_from.clone(),
name_range: symbol.name_range,
full_range: symbol.full_range,
}
@@ -413,7 +426,34 @@ struct SymbolTree {
kind: SymbolKind,
name_range: TextRange,
full_range: TextRange,
import_kind: Option<ImportKind>,
imported_from: Option<ImportedFrom>,
}
#[derive(Debug, Clone, PartialEq, Eq, get_size2::GetSize)]
pub struct ImportedFrom {
module_name: ModuleName,
kind: ImportKind,
}
impl ImportedFrom {
fn import(alias: &ast::Alias, kind: ImportKind) -> Option<ImportedFrom> {
let module_name = ModuleName::new(&alias.name)?;
Some(ImportedFrom { module_name, kind })
}
fn import_from(
db: &dyn Db,
importing_file: File,
ast: &ast::StmtImportFrom,
kind: ImportKind,
) -> Option<ImportedFrom> {
let module_name = ModuleName::from_import_statement(db, importing_file, ast).ok()?;
Some(ImportedFrom { module_name, kind })
}
pub fn module_name(&self) -> &ModuleName {
&self.module_name
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, get_size2::GetSize)]
@@ -423,6 +463,16 @@ enum ImportKind {
Wildcard,
}
impl From<&ast::Alias> for ImportKind {
fn from(alias: &ast::Alias) -> ImportKind {
if alias.asname.as_ref().map(ast::Identifier::as_str) == Some(alias.name.as_str()) {
ImportKind::RedundantAlias
} else {
ImportKind::Normal
}
}
}
/// An abstraction for managing module scope imports.
///
/// This is meant to recognize the following idioms for updating
@@ -606,6 +656,21 @@ impl<'db> ImportModuleName<'db> {
}
}
#[derive(Clone, Copy, Debug)]
enum AstImport<'a> {
Import(&'a ast::StmtImport),
ImportFrom(&'a ast::StmtImportFrom),
}
impl Ranged for AstImport<'_> {
fn range(&self) -> TextRange {
match *self {
AstImport::Import(ast) => ast.range(),
AstImport::ImportFrom(ast) => ast.range(),
}
}
}
/// A visitor over all symbols in a single file.
///
/// This guarantees that child symbols have a symbol ID greater
@@ -775,36 +840,42 @@ impl<'db> SymbolVisitor<'db> {
kind,
name_range: name.range(),
full_range: stmt.range(),
import_kind: None,
imported_from: None,
};
self.add_symbol(symbol)
}
/// Adds a symbol introduced via an import `stmt`.
fn add_import_alias(&mut self, stmt: &ast::Stmt, alias: &ast::Alias) -> SymbolId {
fn add_import_alias(&mut self, import: AstImport<'_>, alias: &ast::Alias) -> Option<SymbolId> {
let name = alias.asname.as_ref().unwrap_or(&alias.name);
let kind = if stmt.is_import_stmt() {
let kind = if matches!(import, AstImport::Import(_)) {
SymbolKind::Module
} else if Self::is_constant_name(name.as_str()) {
SymbolKind::Constant
} else {
SymbolKind::Variable
};
let re_export = Some(
if alias.asname.as_ref().map(ast::Identifier::as_str) == Some(alias.name.as_str()) {
ImportKind::RedundantAlias
} else {
ImportKind::Normal
},
);
self.add_symbol(SymbolTree {
let import_kind = ImportKind::from(alias);
let full_range = import.range();
let Some(imported_from) = (match import {
AstImport::Import(_) => ImportedFrom::import(alias, import_kind),
AstImport::ImportFrom(ast) => {
ImportedFrom::import_from(self.db, self.file, ast, import_kind)
}
}) else {
tracing::debug!(
"Dropping imported symbol {name} since its module name could not be discovered",
);
return None;
};
Some(self.add_symbol(SymbolTree {
parent: None,
name: name.id.to_string(),
kind,
name_range: name.range(),
full_range: stmt.range(),
import_kind: re_export,
})
full_range,
imported_from: Some(imported_from),
}))
}
/// Extracts `__all__` names from the given assignment.
@@ -955,7 +1026,20 @@ impl<'db> SymbolVisitor<'db> {
return None;
}
let mut symbol = symbol.clone();
symbol.import_kind = Some(ImportKind::Wildcard);
let Some(imported_from) = ImportedFrom::import_from(
self.db,
self.file,
import_from,
ImportKind::Wildcard,
) else {
tracing::debug!(
"Dropping wildcard imported symbol {name} since \
its module name could not be discovered",
name = symbol.name,
);
return None;
};
symbol.imported_from = Some(imported_from);
Some(symbol)
}));
// If the imported module defines an `__all__` AND `__all__` is
@@ -1070,8 +1154,8 @@ impl<'db> SymbolVisitor<'db> {
// * `import X as X`
// * `from Y import X as X`
// * `from Y import *`
if let Some(kind) = symbol.import_kind {
return match kind {
if let Some(ref imported_from) = symbol.imported_from {
return match imported_from.kind {
ImportKind::RedundantAlias | ImportKind::Wildcard => true,
ImportKind::Normal => false,
};
@@ -1115,7 +1199,7 @@ impl<'db> SourceOrderVisitor<'db> for SymbolVisitor<'db> {
kind,
name_range: func_def.name.range(),
full_range: stmt.range(),
import_kind: None,
imported_from: None,
};
if self.exports_only {
@@ -1144,7 +1228,7 @@ impl<'db> SourceOrderVisitor<'db> for SymbolVisitor<'db> {
kind: SymbolKind::Class,
name_range: class_def.name.range(),
full_range: stmt.range(),
import_kind: None,
imported_from: None,
};
if self.exports_only {
@@ -1267,7 +1351,7 @@ impl<'db> SourceOrderVisitor<'db> for SymbolVisitor<'db> {
}
self.imports.add_import(import);
for alias in &import.names {
self.add_import_alias(stmt, alias);
self.add_import_alias(AstImport::Import(import), alias);
}
}
ast::Stmt::ImportFrom(import_from) => {
@@ -1294,7 +1378,7 @@ impl<'db> SourceOrderVisitor<'db> for SymbolVisitor<'db> {
{
self.add_all_from_import(import_from);
}
self.add_import_alias(stmt, alias);
self.add_import_alias(AstImport::ImportFrom(import_from), alias);
}
}
}
@@ -1566,7 +1650,7 @@ from collections import defaultdict as dd
public_test("\
import numpy as numpy
").exports(),
@"numpy :: Module",
@"numpy :: Module :: Re-exported from `numpy`",
);
}
@@ -1576,7 +1660,7 @@ import numpy as numpy
public_test("\
from collections import defaultdict as defaultdict
").exports(),
@"defaultdict :: Variable",
@"defaultdict :: Variable :: Re-exported from `collections`",
);
}
@@ -1587,7 +1671,7 @@ from collections import defaultdict as defaultdict
from collections import defaultdict
__all__ = ['defaultdict']
").exports(),
@"defaultdict :: Variable",
@"defaultdict :: Variable :: Re-exported from `collections`",
);
insta::assert_snapshot!(
@@ -1595,7 +1679,7 @@ __all__ = ['defaultdict']
from collections import defaultdict
__all__ = ('defaultdict',)
").exports(),
@"defaultdict :: Variable",
@"defaultdict :: Variable :: Re-exported from `collections`",
);
}
@@ -1606,7 +1690,7 @@ __all__ = ('defaultdict',)
from collections import defaultdict
__all__: list[str] = ['defaultdict']
").exports(),
@"defaultdict :: Variable",
@"defaultdict :: Variable :: Re-exported from `collections`",
);
insta::assert_snapshot!(
@@ -1614,7 +1698,7 @@ __all__: list[str] = ['defaultdict']
from collections import defaultdict
__all__: tuple[str, ...] = ('defaultdict',)
").exports(),
@"defaultdict :: Variable",
@"defaultdict :: Variable :: Re-exported from `collections`",
);
}
@@ -1626,7 +1710,7 @@ from collections import defaultdict
__all__ = []
__all__ += ['defaultdict']
").exports(),
@"defaultdict :: Variable",
@"defaultdict :: Variable :: Re-exported from `collections`",
);
insta::assert_snapshot!(
@@ -1635,7 +1719,7 @@ from collections import defaultdict
__all__ = []
__all__ += ('defaultdict',)
").exports(),
@"defaultdict :: Variable",
@"defaultdict :: Variable :: Re-exported from `collections`",
);
insta::assert_snapshot!(
@@ -1644,7 +1728,7 @@ from collections import defaultdict
__all__ = []
__all__ += {'defaultdict'}
").exports(),
@"defaultdict :: Variable",
@"defaultdict :: Variable :: Re-exported from `collections`",
);
}
@@ -1667,7 +1751,7 @@ from collections import defaultdict
__all__ = []
__all__.extend(['defaultdict'])
").exports(),
@"defaultdict :: Variable",
@"defaultdict :: Variable :: Re-exported from `collections`",
);
}
@@ -1690,7 +1774,7 @@ from collections import defaultdict
__all__ = []
__all__.append('defaultdict')
").exports(),
@"defaultdict :: Variable",
@"defaultdict :: Variable :: Re-exported from `collections`",
);
}
@@ -1702,7 +1786,7 @@ from collections import defaultdict
__all__ = []
__all__ += ['defaultdict']
").exports(),
@"defaultdict :: Variable",
@"defaultdict :: Variable :: Re-exported from `collections`",
);
}
@@ -1763,7 +1847,7 @@ class X:
.build();
insta::assert_snapshot!(
test.exports_for("test.py"),
@"ZQZQZQ :: Constant",
@"ZQZQZQ :: Constant :: Re-exported from `foo`",
);
}
@@ -1823,7 +1907,7 @@ class X:
.build();
insta::assert_snapshot!(
test.exports_for("test.py"),
@"collections :: Module",
@"collections :: Module :: Re-exported from `foo`",
);
}
@@ -1854,7 +1938,7 @@ class X:
.build();
insta::assert_snapshot!(
test.exports_for("test.py"),
@"defaultdict :: Variable",
@"defaultdict :: Variable :: Re-exported from `foo`",
);
}
@@ -1872,7 +1956,7 @@ class X:
.build();
insta::assert_snapshot!(
test.exports_for("test.py"),
@"ZQZQZQ :: Constant",
@"ZQZQZQ :: Constant :: Re-exported from `foo`",
);
}
@@ -1890,9 +1974,9 @@ class X:
.build();
insta::assert_snapshot!(
test.exports_for("test.py"),
@"
ZQZQZQ :: Constant
__all__ :: Variable
@r"
ZQZQZQ :: Constant :: Re-exported from `foo`
__all__ :: Variable :: Re-exported from `foo`
",
);
}
@@ -1960,7 +2044,7 @@ class X:
// import) and does not itself include `TRICKSY`.
insta::assert_snapshot!(
test.exports_for("test.py"),
@"__all__ :: Variable",
@"__all__ :: Variable :: Re-exported from `foo`",
);
}
@@ -2006,8 +2090,8 @@ class X:
.build();
insta::assert_snapshot!(
test.exports_for("test.py"),
@"
__all__ :: Variable
@r"
__all__ :: Variable :: Re-exported from `foo`
TRICKSY :: Constant
",
);
@@ -2060,9 +2144,9 @@ class X:
.build();
insta::assert_snapshot!(
test.exports_for("test.py"),
@"
__all__ :: Variable
defaultdict :: Variable
@r"
__all__ :: Variable :: Re-exported from `foo`
defaultdict :: Variable :: Re-exported from `collections`
",
);
}
@@ -2105,7 +2189,7 @@ class X:
// `from foo import *` will try to import it anyway.
insta::assert_snapshot!(
test.exports_for("test.py"),
@"__all__ :: Variable",
@"__all__ :: Variable :: Re-exported from `foo`",
);
}
@@ -2155,8 +2239,8 @@ class X:
// `from foo import *` will try to import it anyway.
insta::assert_snapshot!(
test.exports_for("test.py"),
@"
__all__ :: Variable
@r"
__all__ :: Variable :: Re-exported from `foo`
TRICKSY :: Constant
",
);
@@ -2274,8 +2358,8 @@ class X:
.build();
insta::assert_snapshot!(
test.exports_for("test.py"),
@"
_ZQZQZQ :: Constant
@r"
_ZQZQZQ :: Constant :: Re-exported from `foo`
_ZYZYZY :: Constant
",
);
@@ -2303,8 +2387,8 @@ class X:
.build();
insta::assert_snapshot!(
test.exports_for("test.py"),
@"
_ZQZQZQ :: Constant
@r"
_ZQZQZQ :: Constant :: Re-exported from `foo`
_ZYZYZY :: Constant
",
);
@@ -2332,8 +2416,8 @@ class X:
.build();
insta::assert_snapshot!(
test.exports_for("test.py"),
@"
_ZQZQZQ :: Constant
@r"
_ZQZQZQ :: Constant :: Re-exported from `foo`
_ZYZYZY :: Constant
",
);
@@ -2362,8 +2446,8 @@ class X:
.build();
insta::assert_snapshot!(
test.exports_for("test.py"),
@"
_ZQZQZQ :: Constant
@r"
_ZQZQZQ :: Constant :: Re-exported from `parent.foo`
_ZYZYZY :: Constant
",
);
@@ -2392,8 +2476,8 @@ class X:
.build();
insta::assert_snapshot!(
test.exports_for("test.py"),
@"
_ZQZQZQ :: Constant
@r"
_ZQZQZQ :: Constant :: Re-exported from `parent.foo`
_ZYZYZY :: Constant
",
);
@@ -2422,8 +2506,8 @@ class X:
.build();
insta::assert_snapshot!(
test.exports_for("test.py"),
@"
_ZQZQZQ :: Constant
@r"
_ZQZQZQ :: Constant :: Re-exported from `parent.foo`
_ZYZYZY :: Constant
",
);
@@ -2452,8 +2536,8 @@ class X:
.build();
insta::assert_snapshot!(
test.exports_for("test.py"),
@"
_ZQZQZQ :: Constant
@r"
_ZQZQZQ :: Constant :: Re-exported from `parent.foo`
_ZYZYZY :: Constant
",
);
@@ -2482,8 +2566,8 @@ class X:
.build();
insta::assert_snapshot!(
test.exports_for("test.py"),
@"
_ZQZQZQ :: Constant
@r"
_ZQZQZQ :: Constant :: Re-exported from `parent.foo`
_ZYZYZY :: Constant
",
);
@@ -2514,9 +2598,9 @@ class X:
.build();
insta::assert_snapshot!(
test.exports_for("a.py"),
@"
_ZBZBZB :: Constant
_ZAZAZA :: Constant
@r"
_ZBZBZB :: Constant :: Re-exported from `b`
_ZAZAZA :: Constant :: Re-exported from `b`
",
);
}
@@ -2559,8 +2643,8 @@ class X:
// `_ZBZBZB` instead of `_ZFZFZF`.
insta::assert_snapshot!(
test.exports_for("test.py"),
@"
_ZFZFZF :: Constant
@r"
_ZFZFZF :: Constant :: Re-exported from `foo`
_ZYZYZY :: Constant
",
);
@@ -2600,9 +2684,9 @@ class X:
// answer should just be `_ZFZFZF` and `_ZYZYZY`.
insta::assert_snapshot!(
test.exports_for("test.py"),
@"
_ZFZFZF :: Constant
foo :: Module
@r"
_ZFZFZF :: Constant :: Re-exported from `parent.foo`
foo :: Module :: Re-exported from `parent`
_ZYZYZY :: Constant
__all__ :: Variable
",
@@ -2640,7 +2724,15 @@ class X:
symbols
.iter()
.map(|(_, symbol)| {
format!("{name} :: {kind:?}", name = symbol.name, kind = symbol.kind)
let mut snapshot =
format!("{name} :: {kind:?}", name = symbol.name, kind = symbol.kind,);
if let Some(ref imported_from) = symbol.imported_from {
snapshot = format!(
"{snapshot} :: Re-exported from `{module_name}`",
module_name = imported_from.module_name()
);
}
snapshot
})
.collect::<Vec<String>>()
.join("\n")