diff --git a/crates/ty_ide/src/completion.rs b/crates/ty_ide/src/completion.rs index 18af1da727..d3f7a9ac27 100644 --- a/crates/ty_ide/src/completion.rs +++ b/crates/ty_ide/src/completion.rs @@ -4058,6 +4058,85 @@ def f[T](x: T): test.build().contains("__repr__"); } + #[test] + fn reexport_simple_import_noauto() { + let snapshot = CursorTest::builder() + .source( + "main.py", + r#" +import foo +foo.ZQ +"#, + ) + .source("foo.py", r#"from bar import ZQZQ"#) + .source("bar.py", r#"ZQZQ = 1"#) + .completion_test_builder() + .module_names() + .build() + .snapshot(); + assert_snapshot!(snapshot, @"ZQZQ :: Current module"); + } + + #[test] + fn reexport_simple_import_auto() { + let snapshot = CursorTest::builder() + .source( + "main.py", + r#" +ZQ +"#, + ) + .source("foo.py", r#"from bar import ZQZQ"#) + .source("bar.py", r#"ZQZQ = 1"#) + .completion_test_builder() + .auto_import() + .module_names() + .build() + .snapshot(); + assert_snapshot!(snapshot, @"ZQZQ :: bar"); + } + + #[test] + fn reexport_redundant_convention_import_noauto() { + let snapshot = CursorTest::builder() + .source( + "main.py", + r#" +import foo +foo.ZQ +"#, + ) + .source("foo.py", r#"from bar import ZQZQ as ZQZQ"#) + .source("bar.py", r#"ZQZQ = 1"#) + .completion_test_builder() + .module_names() + .build() + .snapshot(); + assert_snapshot!(snapshot, @"ZQZQ :: Current module"); + } + + #[test] + fn reexport_redundant_convention_import_auto() { + let snapshot = CursorTest::builder() + .source( + "main.py", + r#" +ZQ +"#, + ) + .source("foo.py", r#"from bar import ZQZQ as ZQZQ"#) + .source("bar.py", r#"ZQZQ = 1"#) + .completion_test_builder() + .auto_import() + .module_names() + .build() + .snapshot(); + assert_snapshot!(snapshot, @r" + ZQZQ :: bar + ZQZQ :: foo + "); + } + /// A way to create a simple single-file (named `main.py`) completion test /// builder. /// diff --git a/crates/ty_ide/src/symbols.rs b/crates/ty_ide/src/symbols.rs index 86a1b83e53..4bea3577d3 100644 --- a/crates/ty_ide/src/symbols.rs +++ b/crates/ty_ide/src/symbols.rs @@ -334,7 +334,7 @@ pub(crate) fn symbols_for_file(db: &dyn Db, file: File) -> FlatSymbols { }; visitor.visit_body(&module.syntax().body); FlatSymbols { - symbols: visitor.symbols, + symbols: visitor.into_symbols(), } } @@ -356,7 +356,7 @@ pub(crate) fn symbols_for_file_global_only(db: &dyn Db, file: File) -> FlatSymbo }; visitor.visit_body(&module.syntax().body); FlatSymbols { - symbols: visitor.symbols, + symbols: visitor.into_symbols(), } } @@ -367,6 +367,13 @@ struct SymbolTree { kind: SymbolKind, name_range: TextRange, full_range: TextRange, + re_export: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq, get_size2::GetSize)] +enum ReExport { + Normal, + RedundantAlias, } /// A visitor over all symbols in a single file. @@ -382,6 +389,44 @@ struct SymbolVisitor { } impl SymbolVisitor { + fn into_symbols(self) -> IndexVec { + // We want to filter out some of the symbols we collected. + // But, we always assigned IDs to each symbol based on + // their position in a sequence. So when we filter some + // out, we need to remap the identifiers. + // + // N.B. This can be skipped when `global_only` is true, + // since in that case, none of the symbols have a parent + // ID by construction. + let mut remap = IndexVec::with_capacity(self.symbols.len()); + let mut new = IndexVec::with_capacity(self.symbols.len()); + for mut symbol in self.symbols { + if symbol.re_export == Some(ReExport::Normal) { + remap.push(None); + continue; + } + if let Some(ref mut parent) = symbol.parent { + // OK because the visitor guarantees that + // all parents have IDs less than their + // children. So its ID has already been + // remapped. + if let Some(new_parent) = remap[*parent] { + *parent = new_parent; + } else { + // The parent symbol was dropped, so + // all of its children should be as + // well. + remap.push(None); + continue; + } + } + let new_id = new.next_index(); + remap.push(Some(new_id)); + new.push(symbol); + } + new + } + fn visit_body(&mut self, body: &[ast::Stmt]) { for stmt in body { self.visit_stmt(stmt); @@ -401,6 +446,30 @@ impl SymbolVisitor { symbol_id } + fn add_import_alias(&mut self, stmt: &ast::Stmt, alias: &ast::Alias) -> SymbolId { + let name = alias.asname.as_ref().unwrap_or(&alias.name); + let kind = if Self::is_constant_name(name.as_str()) { + SymbolKind::Constant + } else { + SymbolKind::Variable + }; + let re_export = Some( + if alias.asname.as_ref().map(ast::Identifier::as_str) == Some(alias.name.as_str()) { + ReExport::RedundantAlias + } else { + ReExport::Normal + }, + ); + self.add_symbol(SymbolTree { + parent: None, + name: name.id.to_string(), + kind, + name_range: name.range(), + full_range: stmt.range(), + re_export, + }) + } + fn push_symbol(&mut self, symbol: SymbolTree) { let symbol_id = self.add_symbol(symbol); self.symbol_stack.push(symbol_id); @@ -445,6 +514,7 @@ impl SourceOrderVisitor<'_> for SymbolVisitor { kind, name_range: func_def.name.range(), full_range: stmt.range(), + re_export: None, }; if self.global_only { @@ -474,6 +544,7 @@ impl SourceOrderVisitor<'_> for SymbolVisitor { kind: SymbolKind::Class, name_range: class_def.name.range(), full_range: stmt.range(), + re_export: None, }; if self.global_only { @@ -513,6 +584,7 @@ impl SourceOrderVisitor<'_> for SymbolVisitor { kind, name_range: name.range(), full_range: stmt.range(), + re_export: None, }; self.add_symbol(symbol); } @@ -543,10 +615,31 @@ impl SourceOrderVisitor<'_> for SymbolVisitor { kind, name_range: name.range(), full_range: stmt.range(), + re_export: None, }; self.add_symbol(symbol); } + ast::Stmt::Import(import) => { + // We only consider imports in global scope. + if self.in_function { + return; + } + for alias in &import.names { + self.add_import_alias(stmt, alias); + } + } + + ast::Stmt::ImportFrom(import_from) => { + // We only consider imports in global scope. + if self.in_function { + return; + } + for alias in &import_from.names { + self.add_import_alias(stmt, alias); + } + } + _ => { source_order::walk_stmt(self, stmt); }