From 644096ea8a1e8ccc4d3d1300bcfb0785ec671947 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Tue, 2 Dec 2025 14:37:50 +0100 Subject: [PATCH] [ty] Fix find-references for import aliases (#21736) --- crates/ty_ide/src/find_references.rs | 148 ++++++++++++++++++ crates/ty_ide/src/goto.rs | 30 ++-- crates/ty_ide/src/rename.rs | 128 +++++++++++---- .../src/semantic_index/builder.rs | 4 + 4 files changed, 264 insertions(+), 46 deletions(-) diff --git a/crates/ty_ide/src/find_references.rs b/crates/ty_ide/src/find_references.rs index a6b60bc6c2..d281dcaf92 100644 --- a/crates/ty_ide/src/find_references.rs +++ b/crates/ty_ide/src/find_references.rs @@ -1722,4 +1722,152 @@ func_alias() | "###); } + + #[test] + fn import_alias() { + let test = CursorTest::builder() + .source( + "main.py", + r#" + import warnings + import warnings as abc + + x = abc + y = warnings + "#, + ) + .build(); + + assert_snapshot!(test.references(), @r" + info[references]: Reference 1 + --> main.py:3:20 + | + 2 | import warnings + 3 | import warnings as abc + | ^^^ + 4 | + 5 | x = abc + | + + info[references]: Reference 2 + --> main.py:5:5 + | + 3 | import warnings as abc + 4 | + 5 | x = abc + | ^^^ + 6 | y = warnings + | + "); + } + + #[test] + fn import_alias_use() { + let test = CursorTest::builder() + .source( + "main.py", + r#" + import warnings + import warnings as abc + + x = abc + y = warnings + "#, + ) + .build(); + + assert_snapshot!(test.references(), @r" + info[references]: Reference 1 + --> main.py:3:20 + | + 2 | import warnings + 3 | import warnings as abc + | ^^^ + 4 | + 5 | x = abc + | + + info[references]: Reference 2 + --> main.py:5:5 + | + 3 | import warnings as abc + 4 | + 5 | x = abc + | ^^^ + 6 | y = warnings + | + "); + } + + #[test] + fn import_from_alias() { + let test = CursorTest::builder() + .source( + "main.py", + r#" + from warnings import deprecated as xyz + from warnings import deprecated + + y = xyz + z = deprecated + "#, + ) + .build(); + + assert_snapshot!(test.references(), @r" + info[references]: Reference 1 + --> main.py:2:36 + | + 2 | from warnings import deprecated as xyz + | ^^^ + 3 | from warnings import deprecated + | + + info[references]: Reference 2 + --> main.py:5:5 + | + 3 | from warnings import deprecated + 4 | + 5 | y = xyz + | ^^^ + 6 | z = deprecated + | + "); + } + + #[test] + fn import_from_alias_use() { + let test = CursorTest::builder() + .source( + "main.py", + r#" + from warnings import deprecated as xyz + from warnings import deprecated + + y = xyz + z = deprecated + "#, + ) + .build(); + + assert_snapshot!(test.references(), @r" + info[references]: Reference 1 + --> main.py:2:36 + | + 2 | from warnings import deprecated as xyz + | ^^^ + 3 | from warnings import deprecated + | + + info[references]: Reference 2 + --> main.py:5:5 + | + 3 | from warnings import deprecated + 4 | + 5 | y = xyz + | ^^^ + 6 | z = deprecated + | + "); + } } diff --git a/crates/ty_ide/src/goto.rs b/crates/ty_ide/src/goto.rs index 3b086b91fd..17df9f11d9 100644 --- a/crates/ty_ide/src/goto.rs +++ b/crates/ty_ide/src/goto.rs @@ -396,13 +396,19 @@ impl GotoTarget<'_> { GotoTarget::ImportSymbolAlias { alias, import_from, .. } => { - let symbol_name = alias.name.as_str(); - Some(definitions_for_imported_symbol( - model, - import_from, - symbol_name, - alias_resolution, - )) + if let Some(asname) = alias.asname.as_ref() + && alias_resolution == ImportAliasResolution::PreserveAliases + { + Some(definitions_for_name(model, asname.as_str(), asname.into())) + } else { + let symbol_name = alias.name.as_str(); + Some(definitions_for_imported_symbol( + model, + import_from, + symbol_name, + alias_resolution, + )) + } } GotoTarget::ImportModuleComponent { @@ -418,12 +424,12 @@ impl GotoTarget<'_> { // Handle import aliases (offset within 'z' in "import x.y as z") GotoTarget::ImportModuleAlias { alias } => { - if alias_resolution == ImportAliasResolution::ResolveAliases { - definitions_for_module(model, Some(alias.name.as_str()), 0) + if let Some(asname) = alias.asname.as_ref() + && alias_resolution == ImportAliasResolution::PreserveAliases + { + Some(definitions_for_name(model, asname.as_str(), asname.into())) } else { - alias.asname.as_ref().map(|name| { - definitions_for_name(model, name.as_str(), AnyNodeRef::Identifier(name)) - }) + definitions_for_module(model, Some(alias.name.as_str()), 0) } } diff --git a/crates/ty_ide/src/rename.rs b/crates/ty_ide/src/rename.rs index 156f38fee4..3ecc474d6d 100644 --- a/crates/ty_ide/src/rename.rs +++ b/crates/ty_ide/src/rename.rs @@ -163,7 +163,7 @@ mod tests { } #[test] - fn test_prepare_rename_parameter() { + fn prepare_rename_parameter() { let test = cursor_test( " def func(value: int) -> int: @@ -178,7 +178,7 @@ value = 0 } #[test] - fn test_rename_parameter() { + fn rename_parameter() { let test = cursor_test( " def func(value: int) -> int: @@ -207,7 +207,7 @@ func(value=42) } #[test] - fn test_rename_function() { + fn rename_function() { let test = cursor_test( " def func(): @@ -235,7 +235,7 @@ x = func } #[test] - fn test_rename_class() { + fn rename_class() { let test = cursor_test( " class MyClass: @@ -265,7 +265,7 @@ cls = MyClass } #[test] - fn test_rename_invalid_name() { + fn rename_invalid_name() { let test = cursor_test( " def func(): @@ -286,7 +286,7 @@ def func(): } #[test] - fn test_multi_file_function_rename() { + fn multi_file_function_rename() { let test = CursorTest::builder() .source( "utils.py", @@ -312,7 +312,7 @@ from utils import helper_function class DataProcessor: def __init__(self): self.multiplier = helper_function - + def process(self, value): return helper_function(value) ", @@ -654,7 +654,7 @@ class DataProcessor: def __init__(self, pos, btn): self.position: int = pos self.button: str = btn - + def my_func(event: Click): match event: case Click(x, button=ab): @@ -685,7 +685,7 @@ class DataProcessor: def __init__(self, pos, btn): self.position: int = pos self.button: str = btn - + def my_func(event: Click): match event: case Click(x, button=ab): @@ -716,7 +716,7 @@ class DataProcessor: def __init__(self, pos, btn): self.position: int = pos self.button: str = btn - + def my_func(event: Click): match event: case Click(x, button=ab): @@ -756,7 +756,7 @@ class DataProcessor: def __init__(self, pos, btn): self.position: int = pos self.button: str = btn - + def my_func(event: Click): match event: case Click(x, button=ab): @@ -880,7 +880,7 @@ class DataProcessor: } #[test] - fn test_cannot_rename_import_module_component() { + fn cannot_rename_import_module_component() { // Test that we cannot rename parts of module names in import statements let test = cursor_test( " @@ -893,7 +893,7 @@ x = os.path.join('a', 'b') } #[test] - fn test_cannot_rename_from_import_module_component() { + fn cannot_rename_from_import_module_component() { // Test that we cannot rename parts of module names in from import statements let test = cursor_test( " @@ -906,7 +906,7 @@ result = join('a', 'b') } #[test] - fn test_cannot_rename_external_file() { + fn cannot_rename_external_file() { // This test verifies that we cannot rename a symbol when it's defined in a file // that's outside the project (like a standard library function) let test = cursor_test( @@ -920,7 +920,7 @@ x = os.path.join('a', 'b') } #[test] - fn test_rename_alias_at_import_statement() { + fn rename_alias_at_import_statement() { let test = CursorTest::builder() .source( "utils.py", @@ -931,8 +931,8 @@ def test(): pass .source( "main.py", " -from utils import test as test_alias -result = test_alias() +from utils import test as alias +result = alias() ", ) .build(); @@ -941,16 +941,16 @@ result = test_alias() info[rename]: Rename symbol (found 2 locations) --> main.py:2:27 | - 2 | from utils import test as test_alias - | ^^^^^^^^^^ - 3 | result = test_alias() - | ---------- + 2 | from utils import test as alias + | ^^^^^ + 3 | result = alias() + | ----- | "); } #[test] - fn test_rename_alias_at_usage_site() { + fn rename_alias_at_usage_site() { // Test renaming an alias when the cursor is on the alias in the usage statement let test = CursorTest::builder() .source( @@ -962,8 +962,8 @@ def test(): pass .source( "main.py", " -from utils import test as test_alias -result = test_alias() +from utils import test as alias +result = alias() ", ) .build(); @@ -972,16 +972,16 @@ result = test_alias() info[rename]: Rename symbol (found 2 locations) --> main.py:2:27 | - 2 | from utils import test as test_alias - | ^^^^^^^^^^ - 3 | result = test_alias() - | ---------- + 2 | from utils import test as alias + | ^^^^^ + 3 | result = alias() + | ----- | "); } #[test] - fn test_rename_across_import_chain_with_mixed_aliases() { + fn rename_across_import_chain_with_mixed_aliases() { // Test renaming a symbol that's imported across multiple files with mixed alias patterns // File 1 (source.py): defines the original function // File 2 (middle.py): imports without alias from source.py @@ -1049,7 +1049,7 @@ value1 = func_alias() } #[test] - fn test_rename_alias_in_import_chain() { + fn rename_alias_in_import_chain() { let test = CursorTest::builder() .source( "file1.py", @@ -1101,7 +1101,7 @@ class App: } #[test] - fn test_cannot_rename_keyword() { + fn cannot_rename_keyword() { // Test that we cannot rename Python keywords like "None" let test = cursor_test( " @@ -1116,7 +1116,7 @@ def process_value(value): } #[test] - fn test_cannot_rename_builtin_type() { + fn cannot_rename_builtin_type() { // Test that we cannot rename Python builtin types like "int" let test = cursor_test( " @@ -1129,7 +1129,7 @@ def convert_to_number(value): } #[test] - fn test_rename_keyword_argument() { + fn rename_keyword_argument() { // Test renaming a keyword argument and its corresponding parameter let test = cursor_test( " @@ -1156,7 +1156,7 @@ result = func(10, y=20) } #[test] - fn test_rename_parameter_with_keyword_argument() { + fn rename_parameter_with_keyword_argument() { // Test renaming a parameter and its corresponding keyword argument let test = cursor_test( " @@ -1181,4 +1181,64 @@ result = func(10, y=20) | "); } + + #[test] + fn import_alias() { + let test = CursorTest::builder() + .source( + "main.py", + r#" + import warnings + import warnings as abc + + x = abc + y = warnings + "#, + ) + .build(); + + assert_snapshot!(test.rename("z"), @r" + info[rename]: Rename symbol (found 2 locations) + --> main.py:3:20 + | + 2 | import warnings + 3 | import warnings as abc + | ^^^ + 4 | + 5 | x = abc + | --- + 6 | y = warnings + | + "); + } + + #[test] + fn import_alias_use() { + let test = CursorTest::builder() + .source( + "main.py", + r#" + import warnings + import warnings as abc + + x = abc + y = warnings + "#, + ) + .build(); + + assert_snapshot!(test.rename("z"), @r" + info[rename]: Rename symbol (found 2 locations) + --> main.py:3:20 + | + 2 | import warnings + 3 | import warnings as abc + | ^^^ + 4 | + 5 | x = abc + | --- + 6 | y = warnings + | + "); + } } diff --git a/crates/ty_python_semantic/src/semantic_index/builder.rs b/crates/ty_python_semantic/src/semantic_index/builder.rs index f7b6da1a0f..b729862f2b 100644 --- a/crates/ty_python_semantic/src/semantic_index/builder.rs +++ b/crates/ty_python_semantic/src/semantic_index/builder.rs @@ -1478,6 +1478,8 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { } let (symbol_name, is_reexported) = if let Some(asname) = &alias.asname { + self.scopes_by_expression + .record_expression(asname, self.current_scope()); (asname.id.clone(), asname.id == alias.name.id) } else { (Name::new(alias.name.id.split('.').next().unwrap()), false) @@ -1651,6 +1653,8 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { } let (symbol_name, is_reexported) = if let Some(asname) = &alias.asname { + self.scopes_by_expression + .record_expression(asname, self.current_scope()); // It's re-exported if it's `from ... import x as x` (&asname.id, asname.id == alias.name.id) } else {