diff --git a/crates/ty_ide/src/symbols.rs b/crates/ty_ide/src/symbols.rs index ebb6637ae3..a49ac318fc 100644 --- a/crates/ty_ide/src/symbols.rs +++ b/crates/ty_ide/src/symbols.rs @@ -436,6 +436,28 @@ impl SymbolVisitor { symbol_id } + fn add_assignment(&mut self, stmt: &ast::Stmt, name: &ast::ExprName) -> SymbolId { + let kind = if Self::is_constant_name(name.id.as_str()) { + SymbolKind::Constant + } else if self + .iter_symbol_stack() + .any(|s| s.kind == SymbolKind::Class) + { + SymbolKind::Field + } else { + SymbolKind::Variable + }; + + let symbol = SymbolTree { + parent: None, + name: name.id.to_string(), + kind, + name_range: name.range(), + full_range: stmt.range(), + }; + self.add_symbol(symbol) + } + fn push_symbol(&mut self, symbol: SymbolTree) { let symbol_id = self.add_symbol(symbol); self.symbol_stack.push(symbol_id); @@ -501,7 +523,6 @@ impl SourceOrderVisitor<'_> for SymbolVisitor { self.pop_symbol(); } - ast::Stmt::ClassDef(class_def) => { let symbol = SymbolTree { parent: None, @@ -521,7 +542,6 @@ impl SourceOrderVisitor<'_> for SymbolVisitor { source_order::walk_stmt(self, stmt); self.pop_symbol(); } - ast::Stmt::Assign(assign) => { // Include assignments only when we're in global or class scope if self.in_function { @@ -531,28 +551,9 @@ impl SourceOrderVisitor<'_> for SymbolVisitor { let ast::Expr::Name(name) = target else { continue; }; - let kind = if Self::is_constant_name(name.id.as_str()) { - SymbolKind::Constant - } else if self - .iter_symbol_stack() - .any(|s| s.kind == SymbolKind::Class) - { - SymbolKind::Field - } else { - SymbolKind::Variable - }; - - let symbol = SymbolTree { - parent: None, - name: name.id.to_string(), - kind, - name_range: name.range(), - full_range: stmt.range(), - }; - self.add_symbol(symbol); + self.add_assignment(stmt, name); } } - ast::Stmt::AnnAssign(ann_assign) => { // Include assignments only when we're in global or class scope if self.in_function { @@ -561,27 +562,8 @@ impl SourceOrderVisitor<'_> for SymbolVisitor { let ast::Expr::Name(name) = &*ann_assign.target else { return; }; - let kind = if Self::is_constant_name(name.id.as_str()) { - SymbolKind::Constant - } else if self - .iter_symbol_stack() - .any(|s| s.kind == SymbolKind::Class) - { - SymbolKind::Field - } else { - SymbolKind::Variable - }; - - let symbol = SymbolTree { - parent: None, - name: name.id.to_string(), - kind, - name_range: name.range(), - full_range: stmt.range(), - }; - self.add_symbol(symbol); + self.add_assignment(stmt, name); } - _ => { source_order::walk_stmt(self, stmt); } @@ -591,9 +573,16 @@ impl SourceOrderVisitor<'_> for SymbolVisitor { #[cfg(test)] mod tests { - fn matches(query: &str, symbol: &str) -> bool { - super::QueryPattern::fuzzy(query).is_match_symbol_name(symbol) - } + use camino::Utf8Component; + use insta::internals::SettingsBindDropGuard; + + use ruff_db::Db; + use ruff_db::files::{FileRootKind, system_path_to_file}; + use ruff_db::system::{DbWithWritableSystem, SystemPath, SystemPathBuf}; + use ruff_python_trivia::textwrap::dedent; + use ty_project::{ProjectMetadata, TestDb}; + + use super::symbols_for_file_global_only; #[test] fn various_yes() { @@ -625,4 +614,191 @@ mod tests { assert!(!matches("abcd", "abc")); assert!(!matches("δΘπ", "θΔΠ")); } + + #[test] + fn exports_simple() { + insta::assert_snapshot!( + public_test("\ +FOO = 1 +foo = 1 +frob: int = 1 +class Foo: + BAR = 1 +def quux(): + baz = 1 +").exports(), + @r" + FOO :: Constant + foo :: Variable + frob :: Variable + Foo :: Class + quux :: Function + ", + ); + } + + #[test] + fn exports_conditional_true() { + insta::assert_snapshot!( + public_test("\ +foo = 1 +if True: + bar = 1 +").exports(), + @r" + foo :: Variable + bar :: Variable + ", + ); + } + + #[test] + fn exports_conditional_false() { + // FIXME: This shouldn't include `bar`. + insta::assert_snapshot!( + public_test("\ +foo = 1 +if False: + bar = 1 +").exports(), + @r" + foo :: Variable + bar :: Variable + ", + ); + } + + #[test] + fn exports_conditional_sys_version() { + // FIXME: This shouldn't include `bar`. + insta::assert_snapshot!( + public_test("\ +import sys + +foo = 1 +if sys.version < (3, 5): + bar = 1 +").exports(), + @r" + foo :: Variable + bar :: Variable + ", + ); + } + + #[test] + fn exports_type_checking() { + insta::assert_snapshot!( + public_test("\ +from typing import TYPE_CHECKING + +foo = 1 +if TYPE_CHECKING: + bar = 1 +").exports(), + @r" + foo :: Variable + bar :: Variable + ", + ); + } + + fn matches(query: &str, symbol: &str) -> bool { + super::QueryPattern::fuzzy(query).is_match_symbol_name(symbol) + } + + fn public_test(code: &str) -> PublicTest { + PublicTestBuilder::default().source("test.py", code).build() + } + + struct PublicTest { + db: TestDb, + _insta_settings_guard: SettingsBindDropGuard, + } + + impl PublicTest { + /// Returns the exports from `test.py`. + /// + /// This is, conventionally, the default module file path used. For + /// example, it's used by the `public_test` convenience constructor. + fn exports(&self) -> String { + self.exports_for("test.py") + } + + /// Returns the exports from the module at the given path. + /// + /// The path given must have been written to this test's salsa DB. + fn exports_for(&self, path: impl AsRef) -> String { + let file = system_path_to_file(&self.db, path.as_ref()).unwrap(); + let symbols = symbols_for_file_global_only(&self.db, file); + symbols + .iter() + .map(|(_, symbol)| { + format!("{name} :: {kind:?}", name = symbol.name, kind = symbol.kind) + }) + .collect::>() + .join("\n") + } + } + + #[derive(Default)] + struct PublicTestBuilder { + /// A list of source files, corresponding to the + /// file's path and its contents. + sources: Vec, + } + + impl PublicTestBuilder { + pub(super) fn build(&self) -> PublicTest { + let mut db = TestDb::new(ProjectMetadata::new( + "test".into(), + SystemPathBuf::from("/"), + )); + + db.init_program().unwrap(); + + for Source { path, contents } in &self.sources { + db.write_file(path, contents) + .expect("write to memory file system to be successful"); + + // Add a root for the top-most component. + let top = path.components().find_map(|c| match c { + Utf8Component::Normal(c) => Some(c), + _ => None, + }); + if let Some(top) = top { + let top = SystemPath::new(top); + if db.system().is_directory(top) { + db.files() + .try_add_root(&db, top, FileRootKind::LibrarySearchPath); + } + } + } + + // N.B. We don't set anything custom yet, but we leave + // this here for when we invevitable add a filter. + let insta_settings = insta::Settings::clone_current(); + let insta_settings_guard = insta_settings.bind_to_scope(); + PublicTest { + db, + _insta_settings_guard: insta_settings_guard, + } + } + + pub(super) fn source( + &mut self, + path: impl Into, + contents: impl AsRef, + ) -> &mut PublicTestBuilder { + let path = path.into(); + let contents = dedent(contents.as_ref()).into_owned(); + self.sources.push(Source { path, contents }); + self + } + } + + struct Source { + path: SystemPathBuf, + contents: String, + } }