[ty] Make auto import completions include global imports from other modules in suggestions

Previously, our special auto-import code for discovering symbols in
other files quickly (without running ty on them) didn't take imports
into account. This PR makes a small change to do exactly that.

This in particular helps with libraries that build their public API
from submodules. In particular, numpy. This consequently improves our
numpy evaluation tasks (which includes sub-optimal ranking because of
precisely this bug). The ranking still isn't perfect, but at least the
correct result appears in the suggestions. It previously did not.

Unfortunately, this does regress some other tasks. For example, invoking
auto-import on `TypeVa<CURSOR>` now brings up `TypeVar` from a whole
bunch of modules. Presumably because it's imported in those modules.

So I guess that means this heuristic is probably wrong. How does one
differentiate imports that are meant to build out an API and just a
regular old import meant for internal use?

One idea is to perhaps down-rank symbols derived from imports, but above
symbols from private modules (beginning with `_`). This would work for
the numpy case I believe without (hopefully) regressing other tasks.
This commit is contained in:
Andrew Gallant 2025-10-28 13:59:41 -04:00
parent 39f43d888d
commit 2066d35038
No known key found for this signature in database
GPG Key ID: 5518C8B38E0693E0
2 changed files with 174 additions and 2 deletions

View File

@ -4058,6 +4058,85 @@ def f[T](x: T):
test.build().contains("__repr__"); test.build().contains("__repr__");
} }
#[test]
fn reexport_simple_import_noauto() {
let snapshot = CursorTest::builder()
.source(
"main.py",
r#"
import foo
foo.ZQ<CURSOR>
"#,
)
.source("foo.py", r#"from bar import ZQZQ"#)
.source("bar.py", r#"ZQZQ = 1"#)
.completion_test_builder()
.module_names()
.build()
.snapshot();
assert_snapshot!(snapshot, @"ZQZQ :: Current module");
}
#[test]
fn reexport_simple_import_auto() {
let snapshot = CursorTest::builder()
.source(
"main.py",
r#"
ZQ<CURSOR>
"#,
)
.source("foo.py", r#"from bar import ZQZQ"#)
.source("bar.py", r#"ZQZQ = 1"#)
.completion_test_builder()
.auto_import()
.module_names()
.build()
.snapshot();
assert_snapshot!(snapshot, @"ZQZQ :: bar");
}
#[test]
fn reexport_redundant_convention_import_noauto() {
let snapshot = CursorTest::builder()
.source(
"main.py",
r#"
import foo
foo.ZQ<CURSOR>
"#,
)
.source("foo.py", r#"from bar import ZQZQ as ZQZQ"#)
.source("bar.py", r#"ZQZQ = 1"#)
.completion_test_builder()
.module_names()
.build()
.snapshot();
assert_snapshot!(snapshot, @"ZQZQ :: Current module");
}
#[test]
fn reexport_redundant_convention_import_auto() {
let snapshot = CursorTest::builder()
.source(
"main.py",
r#"
ZQ<CURSOR>
"#,
)
.source("foo.py", r#"from bar import ZQZQ as ZQZQ"#)
.source("bar.py", r#"ZQZQ = 1"#)
.completion_test_builder()
.auto_import()
.module_names()
.build()
.snapshot();
assert_snapshot!(snapshot, @r"
ZQZQ :: bar
ZQZQ :: foo
");
}
/// A way to create a simple single-file (named `main.py`) completion test /// A way to create a simple single-file (named `main.py`) completion test
/// builder. /// builder.
/// ///

View File

@ -334,7 +334,7 @@ pub(crate) fn symbols_for_file(db: &dyn Db, file: File) -> FlatSymbols {
}; };
visitor.visit_body(&module.syntax().body); visitor.visit_body(&module.syntax().body);
FlatSymbols { FlatSymbols {
symbols: visitor.symbols, symbols: visitor.into_symbols(),
} }
} }
@ -356,7 +356,7 @@ pub(crate) fn symbols_for_file_global_only(db: &dyn Db, file: File) -> FlatSymbo
}; };
visitor.visit_body(&module.syntax().body); visitor.visit_body(&module.syntax().body);
FlatSymbols { FlatSymbols {
symbols: visitor.symbols, symbols: visitor.into_symbols(),
} }
} }
@ -367,6 +367,13 @@ struct SymbolTree {
kind: SymbolKind, kind: SymbolKind,
name_range: TextRange, name_range: TextRange,
full_range: TextRange, full_range: TextRange,
re_export: Option<ReExport>,
}
#[derive(Debug, Clone, PartialEq, Eq, get_size2::GetSize)]
enum ReExport {
Normal,
RedundantAlias,
} }
/// A visitor over all symbols in a single file. /// A visitor over all symbols in a single file.
@ -382,6 +389,44 @@ struct SymbolVisitor {
} }
impl SymbolVisitor { impl SymbolVisitor {
fn into_symbols(self) -> IndexVec<SymbolId, SymbolTree> {
// We want to filter out some of the symbols we collected.
// But, we always assigned IDs to each symbol based on
// their position in a sequence. So when we filter some
// out, we need to remap the identifiers.
//
// N.B. This can be skipped when `global_only` is true,
// since in that case, none of the symbols have a parent
// ID by construction.
let mut remap = IndexVec::with_capacity(self.symbols.len());
let mut new = IndexVec::with_capacity(self.symbols.len());
for mut symbol in self.symbols {
if symbol.re_export == Some(ReExport::Normal) {
remap.push(None);
continue;
}
if let Some(ref mut parent) = symbol.parent {
// OK because the visitor guarantees that
// all parents have IDs less than their
// children. So its ID has already been
// remapped.
if let Some(new_parent) = remap[*parent] {
*parent = new_parent;
} else {
// The parent symbol was dropped, so
// all of its children should be as
// well.
remap.push(None);
continue;
}
}
let new_id = new.next_index();
remap.push(Some(new_id));
new.push(symbol);
}
new
}
fn visit_body(&mut self, body: &[ast::Stmt]) { fn visit_body(&mut self, body: &[ast::Stmt]) {
for stmt in body { for stmt in body {
self.visit_stmt(stmt); self.visit_stmt(stmt);
@ -401,6 +446,30 @@ impl SymbolVisitor {
symbol_id symbol_id
} }
fn add_import_alias(&mut self, stmt: &ast::Stmt, alias: &ast::Alias) -> SymbolId {
let name = alias.asname.as_ref().unwrap_or(&alias.name);
let kind = if Self::is_constant_name(name.as_str()) {
SymbolKind::Constant
} else {
SymbolKind::Variable
};
let re_export = Some(
if alias.asname.as_ref().map(ast::Identifier::as_str) == Some(alias.name.as_str()) {
ReExport::RedundantAlias
} else {
ReExport::Normal
},
);
self.add_symbol(SymbolTree {
parent: None,
name: name.id.to_string(),
kind,
name_range: name.range(),
full_range: stmt.range(),
re_export,
})
}
fn push_symbol(&mut self, symbol: SymbolTree) { fn push_symbol(&mut self, symbol: SymbolTree) {
let symbol_id = self.add_symbol(symbol); let symbol_id = self.add_symbol(symbol);
self.symbol_stack.push(symbol_id); self.symbol_stack.push(symbol_id);
@ -445,6 +514,7 @@ impl SourceOrderVisitor<'_> for SymbolVisitor {
kind, kind,
name_range: func_def.name.range(), name_range: func_def.name.range(),
full_range: stmt.range(), full_range: stmt.range(),
re_export: None,
}; };
if self.global_only { if self.global_only {
@ -474,6 +544,7 @@ impl SourceOrderVisitor<'_> for SymbolVisitor {
kind: SymbolKind::Class, kind: SymbolKind::Class,
name_range: class_def.name.range(), name_range: class_def.name.range(),
full_range: stmt.range(), full_range: stmt.range(),
re_export: None,
}; };
if self.global_only { if self.global_only {
@ -513,6 +584,7 @@ impl SourceOrderVisitor<'_> for SymbolVisitor {
kind, kind,
name_range: name.range(), name_range: name.range(),
full_range: stmt.range(), full_range: stmt.range(),
re_export: None,
}; };
self.add_symbol(symbol); self.add_symbol(symbol);
} }
@ -543,10 +615,31 @@ impl SourceOrderVisitor<'_> for SymbolVisitor {
kind, kind,
name_range: name.range(), name_range: name.range(),
full_range: stmt.range(), full_range: stmt.range(),
re_export: None,
}; };
self.add_symbol(symbol); self.add_symbol(symbol);
} }
ast::Stmt::Import(import) => {
// We only consider imports in global scope.
if self.in_function {
return;
}
for alias in &import.names {
self.add_import_alias(stmt, alias);
}
}
ast::Stmt::ImportFrom(import_from) => {
// We only consider imports in global scope.
if self.in_function {
return;
}
for alias in &import_from.names {
self.add_import_alias(stmt, alias);
}
}
_ => { _ => {
source_order::walk_stmt(self, stmt); source_order::walk_stmt(self, stmt);
} }