From 5da45f8ec78f119ae58cd04b8634b6831a423efa Mon Sep 17 00:00:00 2001 From: Andrew Gallant Date: Tue, 25 Nov 2025 14:02:34 -0500 Subject: [PATCH] [ty] Simplify auto-import AST visitor slightly and add tests This simplifies the existing visitor by DRYing it up slightly. We also add tests for the existing functionality. In particular, we want to add support for re-export conventions, and that warrants more careful testing. --- crates/ty_ide/src/symbols.rs | 266 +++++++++++++++++++++++++++++------ 1 file changed, 221 insertions(+), 45 deletions(-) diff --git a/crates/ty_ide/src/symbols.rs b/crates/ty_ide/src/symbols.rs index ebb6637ae3..a49ac318fc 100644 --- a/crates/ty_ide/src/symbols.rs +++ b/crates/ty_ide/src/symbols.rs @@ -436,6 +436,28 @@ impl SymbolVisitor { symbol_id } + fn add_assignment(&mut self, stmt: &ast::Stmt, name: &ast::ExprName) -> SymbolId { + let kind = if Self::is_constant_name(name.id.as_str()) { + SymbolKind::Constant + } else if self + .iter_symbol_stack() + .any(|s| s.kind == SymbolKind::Class) + { + SymbolKind::Field + } else { + SymbolKind::Variable + }; + + let symbol = SymbolTree { + parent: None, + name: name.id.to_string(), + kind, + name_range: name.range(), + full_range: stmt.range(), + }; + self.add_symbol(symbol) + } + fn push_symbol(&mut self, symbol: SymbolTree) { let symbol_id = self.add_symbol(symbol); self.symbol_stack.push(symbol_id); @@ -501,7 +523,6 @@ impl SourceOrderVisitor<'_> for SymbolVisitor { self.pop_symbol(); } - ast::Stmt::ClassDef(class_def) => { let symbol = SymbolTree { parent: None, @@ -521,7 +542,6 @@ impl SourceOrderVisitor<'_> for SymbolVisitor { source_order::walk_stmt(self, stmt); self.pop_symbol(); } - ast::Stmt::Assign(assign) => { // Include assignments only when we're in global or class scope if self.in_function { @@ -531,28 +551,9 @@ impl SourceOrderVisitor<'_> for SymbolVisitor { let ast::Expr::Name(name) = target else { continue; }; - let kind = if Self::is_constant_name(name.id.as_str()) { - SymbolKind::Constant - } else if self - .iter_symbol_stack() - .any(|s| s.kind == SymbolKind::Class) - { - SymbolKind::Field - } else { - SymbolKind::Variable - }; - - let symbol = SymbolTree { - parent: None, - name: name.id.to_string(), - kind, - name_range: name.range(), - full_range: stmt.range(), - }; - self.add_symbol(symbol); + self.add_assignment(stmt, name); } } - ast::Stmt::AnnAssign(ann_assign) => { // Include assignments only when we're in global or class scope if self.in_function { @@ -561,27 +562,8 @@ impl SourceOrderVisitor<'_> for SymbolVisitor { let ast::Expr::Name(name) = &*ann_assign.target else { return; }; - let kind = if Self::is_constant_name(name.id.as_str()) { - SymbolKind::Constant - } else if self - .iter_symbol_stack() - .any(|s| s.kind == SymbolKind::Class) - { - SymbolKind::Field - } else { - SymbolKind::Variable - }; - - let symbol = SymbolTree { - parent: None, - name: name.id.to_string(), - kind, - name_range: name.range(), - full_range: stmt.range(), - }; - self.add_symbol(symbol); + self.add_assignment(stmt, name); } - _ => { source_order::walk_stmt(self, stmt); } @@ -591,9 +573,16 @@ impl SourceOrderVisitor<'_> for SymbolVisitor { #[cfg(test)] mod tests { - fn matches(query: &str, symbol: &str) -> bool { - super::QueryPattern::fuzzy(query).is_match_symbol_name(symbol) - } + use camino::Utf8Component; + use insta::internals::SettingsBindDropGuard; + + use ruff_db::Db; + use ruff_db::files::{FileRootKind, system_path_to_file}; + use ruff_db::system::{DbWithWritableSystem, SystemPath, SystemPathBuf}; + use ruff_python_trivia::textwrap::dedent; + use ty_project::{ProjectMetadata, TestDb}; + + use super::symbols_for_file_global_only; #[test] fn various_yes() { @@ -625,4 +614,191 @@ mod tests { assert!(!matches("abcd", "abc")); assert!(!matches("δΘπ", "θΔΠ")); } + + #[test] + fn exports_simple() { + insta::assert_snapshot!( + public_test("\ +FOO = 1 +foo = 1 +frob: int = 1 +class Foo: + BAR = 1 +def quux(): + baz = 1 +").exports(), + @r" + FOO :: Constant + foo :: Variable + frob :: Variable + Foo :: Class + quux :: Function + ", + ); + } + + #[test] + fn exports_conditional_true() { + insta::assert_snapshot!( + public_test("\ +foo = 1 +if True: + bar = 1 +").exports(), + @r" + foo :: Variable + bar :: Variable + ", + ); + } + + #[test] + fn exports_conditional_false() { + // FIXME: This shouldn't include `bar`. + insta::assert_snapshot!( + public_test("\ +foo = 1 +if False: + bar = 1 +").exports(), + @r" + foo :: Variable + bar :: Variable + ", + ); + } + + #[test] + fn exports_conditional_sys_version() { + // FIXME: This shouldn't include `bar`. + insta::assert_snapshot!( + public_test("\ +import sys + +foo = 1 +if sys.version < (3, 5): + bar = 1 +").exports(), + @r" + foo :: Variable + bar :: Variable + ", + ); + } + + #[test] + fn exports_type_checking() { + insta::assert_snapshot!( + public_test("\ +from typing import TYPE_CHECKING + +foo = 1 +if TYPE_CHECKING: + bar = 1 +").exports(), + @r" + foo :: Variable + bar :: Variable + ", + ); + } + + fn matches(query: &str, symbol: &str) -> bool { + super::QueryPattern::fuzzy(query).is_match_symbol_name(symbol) + } + + fn public_test(code: &str) -> PublicTest { + PublicTestBuilder::default().source("test.py", code).build() + } + + struct PublicTest { + db: TestDb, + _insta_settings_guard: SettingsBindDropGuard, + } + + impl PublicTest { + /// Returns the exports from `test.py`. + /// + /// This is, conventionally, the default module file path used. For + /// example, it's used by the `public_test` convenience constructor. + fn exports(&self) -> String { + self.exports_for("test.py") + } + + /// Returns the exports from the module at the given path. + /// + /// The path given must have been written to this test's salsa DB. + fn exports_for(&self, path: impl AsRef) -> String { + let file = system_path_to_file(&self.db, path.as_ref()).unwrap(); + let symbols = symbols_for_file_global_only(&self.db, file); + symbols + .iter() + .map(|(_, symbol)| { + format!("{name} :: {kind:?}", name = symbol.name, kind = symbol.kind) + }) + .collect::>() + .join("\n") + } + } + + #[derive(Default)] + struct PublicTestBuilder { + /// A list of source files, corresponding to the + /// file's path and its contents. + sources: Vec, + } + + impl PublicTestBuilder { + pub(super) fn build(&self) -> PublicTest { + let mut db = TestDb::new(ProjectMetadata::new( + "test".into(), + SystemPathBuf::from("/"), + )); + + db.init_program().unwrap(); + + for Source { path, contents } in &self.sources { + db.write_file(path, contents) + .expect("write to memory file system to be successful"); + + // Add a root for the top-most component. + let top = path.components().find_map(|c| match c { + Utf8Component::Normal(c) => Some(c), + _ => None, + }); + if let Some(top) = top { + let top = SystemPath::new(top); + if db.system().is_directory(top) { + db.files() + .try_add_root(&db, top, FileRootKind::LibrarySearchPath); + } + } + } + + // N.B. We don't set anything custom yet, but we leave + // this here for when we invevitable add a filter. + let insta_settings = insta::Settings::clone_current(); + let insta_settings_guard = insta_settings.bind_to_scope(); + PublicTest { + db, + _insta_settings_guard: insta_settings_guard, + } + } + + pub(super) fn source( + &mut self, + path: impl Into, + contents: impl AsRef, + ) -> &mut PublicTestBuilder { + let path = path.into(); + let contents = dedent(contents.as_ref()).into_owned(); + self.sources.push(Source { path, contents }); + self + } + } + + struct Source { + path: SystemPathBuf, + contents: String, + } }