[ty] Simplify auto-import AST visitor slightly and add tests

This simplifies the existing visitor by DRYing it up slightly.
We also add tests for the existing functionality. In particular,
we want to add support for re-export conventions, and that
warrants more careful testing.
This commit is contained in:
Andrew Gallant 2025-11-25 14:02:34 -05:00 committed by Andrew Gallant
parent 62f20b1e86
commit 5da45f8ec7
1 changed files with 221 additions and 45 deletions

View File

@ -436,6 +436,28 @@ impl SymbolVisitor {
symbol_id symbol_id
} }
fn add_assignment(&mut self, stmt: &ast::Stmt, name: &ast::ExprName) -> SymbolId {
let kind = if Self::is_constant_name(name.id.as_str()) {
SymbolKind::Constant
} else if self
.iter_symbol_stack()
.any(|s| s.kind == SymbolKind::Class)
{
SymbolKind::Field
} else {
SymbolKind::Variable
};
let symbol = SymbolTree {
parent: None,
name: name.id.to_string(),
kind,
name_range: name.range(),
full_range: stmt.range(),
};
self.add_symbol(symbol)
}
fn push_symbol(&mut self, symbol: SymbolTree) { fn push_symbol(&mut self, symbol: SymbolTree) {
let symbol_id = self.add_symbol(symbol); let symbol_id = self.add_symbol(symbol);
self.symbol_stack.push(symbol_id); self.symbol_stack.push(symbol_id);
@ -501,7 +523,6 @@ impl SourceOrderVisitor<'_> for SymbolVisitor {
self.pop_symbol(); self.pop_symbol();
} }
ast::Stmt::ClassDef(class_def) => { ast::Stmt::ClassDef(class_def) => {
let symbol = SymbolTree { let symbol = SymbolTree {
parent: None, parent: None,
@ -521,7 +542,6 @@ impl SourceOrderVisitor<'_> for SymbolVisitor {
source_order::walk_stmt(self, stmt); source_order::walk_stmt(self, stmt);
self.pop_symbol(); self.pop_symbol();
} }
ast::Stmt::Assign(assign) => { ast::Stmt::Assign(assign) => {
// Include assignments only when we're in global or class scope // Include assignments only when we're in global or class scope
if self.in_function { if self.in_function {
@ -531,28 +551,9 @@ impl SourceOrderVisitor<'_> for SymbolVisitor {
let ast::Expr::Name(name) = target else { let ast::Expr::Name(name) = target else {
continue; continue;
}; };
let kind = if Self::is_constant_name(name.id.as_str()) { self.add_assignment(stmt, name);
SymbolKind::Constant
} else if self
.iter_symbol_stack()
.any(|s| s.kind == SymbolKind::Class)
{
SymbolKind::Field
} else {
SymbolKind::Variable
};
let symbol = SymbolTree {
parent: None,
name: name.id.to_string(),
kind,
name_range: name.range(),
full_range: stmt.range(),
};
self.add_symbol(symbol);
} }
} }
ast::Stmt::AnnAssign(ann_assign) => { ast::Stmt::AnnAssign(ann_assign) => {
// Include assignments only when we're in global or class scope // Include assignments only when we're in global or class scope
if self.in_function { if self.in_function {
@ -561,27 +562,8 @@ impl SourceOrderVisitor<'_> for SymbolVisitor {
let ast::Expr::Name(name) = &*ann_assign.target else { let ast::Expr::Name(name) = &*ann_assign.target else {
return; return;
}; };
let kind = if Self::is_constant_name(name.id.as_str()) { self.add_assignment(stmt, name);
SymbolKind::Constant
} else if self
.iter_symbol_stack()
.any(|s| s.kind == SymbolKind::Class)
{
SymbolKind::Field
} else {
SymbolKind::Variable
};
let symbol = SymbolTree {
parent: None,
name: name.id.to_string(),
kind,
name_range: name.range(),
full_range: stmt.range(),
};
self.add_symbol(symbol);
} }
_ => { _ => {
source_order::walk_stmt(self, stmt); source_order::walk_stmt(self, stmt);
} }
@ -591,9 +573,16 @@ impl SourceOrderVisitor<'_> for SymbolVisitor {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
fn matches(query: &str, symbol: &str) -> bool { use camino::Utf8Component;
super::QueryPattern::fuzzy(query).is_match_symbol_name(symbol) use insta::internals::SettingsBindDropGuard;
}
use ruff_db::Db;
use ruff_db::files::{FileRootKind, system_path_to_file};
use ruff_db::system::{DbWithWritableSystem, SystemPath, SystemPathBuf};
use ruff_python_trivia::textwrap::dedent;
use ty_project::{ProjectMetadata, TestDb};
use super::symbols_for_file_global_only;
#[test] #[test]
fn various_yes() { fn various_yes() {
@ -625,4 +614,191 @@ mod tests {
assert!(!matches("abcd", "abc")); assert!(!matches("abcd", "abc"));
assert!(!matches("δΘπ", "θΔΠ")); assert!(!matches("δΘπ", "θΔΠ"));
} }
#[test]
fn exports_simple() {
insta::assert_snapshot!(
public_test("\
FOO = 1
foo = 1
frob: int = 1
class Foo:
BAR = 1
def quux():
baz = 1
").exports(),
@r"
FOO :: Constant
foo :: Variable
frob :: Variable
Foo :: Class
quux :: Function
",
);
}
#[test]
fn exports_conditional_true() {
insta::assert_snapshot!(
public_test("\
foo = 1
if True:
bar = 1
").exports(),
@r"
foo :: Variable
bar :: Variable
",
);
}
#[test]
fn exports_conditional_false() {
// FIXME: This shouldn't include `bar`.
insta::assert_snapshot!(
public_test("\
foo = 1
if False:
bar = 1
").exports(),
@r"
foo :: Variable
bar :: Variable
",
);
}
#[test]
fn exports_conditional_sys_version() {
// FIXME: This shouldn't include `bar`.
insta::assert_snapshot!(
public_test("\
import sys
foo = 1
if sys.version < (3, 5):
bar = 1
").exports(),
@r"
foo :: Variable
bar :: Variable
",
);
}
#[test]
fn exports_type_checking() {
insta::assert_snapshot!(
public_test("\
from typing import TYPE_CHECKING
foo = 1
if TYPE_CHECKING:
bar = 1
").exports(),
@r"
foo :: Variable
bar :: Variable
",
);
}
fn matches(query: &str, symbol: &str) -> bool {
super::QueryPattern::fuzzy(query).is_match_symbol_name(symbol)
}
fn public_test(code: &str) -> PublicTest {
PublicTestBuilder::default().source("test.py", code).build()
}
struct PublicTest {
db: TestDb,
_insta_settings_guard: SettingsBindDropGuard,
}
impl PublicTest {
/// Returns the exports from `test.py`.
///
/// This is, conventionally, the default module file path used. For
/// example, it's used by the `public_test` convenience constructor.
fn exports(&self) -> String {
self.exports_for("test.py")
}
/// Returns the exports from the module at the given path.
///
/// The path given must have been written to this test's salsa DB.
fn exports_for(&self, path: impl AsRef<SystemPath>) -> String {
let file = system_path_to_file(&self.db, path.as_ref()).unwrap();
let symbols = symbols_for_file_global_only(&self.db, file);
symbols
.iter()
.map(|(_, symbol)| {
format!("{name} :: {kind:?}", name = symbol.name, kind = symbol.kind)
})
.collect::<Vec<String>>()
.join("\n")
}
}
#[derive(Default)]
struct PublicTestBuilder {
/// A list of source files, corresponding to the
/// file's path and its contents.
sources: Vec<Source>,
}
impl PublicTestBuilder {
pub(super) fn build(&self) -> PublicTest {
let mut db = TestDb::new(ProjectMetadata::new(
"test".into(),
SystemPathBuf::from("/"),
));
db.init_program().unwrap();
for Source { path, contents } in &self.sources {
db.write_file(path, contents)
.expect("write to memory file system to be successful");
// Add a root for the top-most component.
let top = path.components().find_map(|c| match c {
Utf8Component::Normal(c) => Some(c),
_ => None,
});
if let Some(top) = top {
let top = SystemPath::new(top);
if db.system().is_directory(top) {
db.files()
.try_add_root(&db, top, FileRootKind::LibrarySearchPath);
}
}
}
// N.B. We don't set anything custom yet, but we leave
// this here for when we invevitable add a filter.
let insta_settings = insta::Settings::clone_current();
let insta_settings_guard = insta_settings.bind_to_scope();
PublicTest {
db,
_insta_settings_guard: insta_settings_guard,
}
}
pub(super) fn source(
&mut self,
path: impl Into<SystemPathBuf>,
contents: impl AsRef<str>,
) -> &mut PublicTestBuilder {
let path = path.into();
let contents = dedent(contents.as_ref()).into_owned();
self.sources.push(Source { path, contents });
self
}
}
struct Source {
path: SystemPathBuf,
contents: String,
}
} }