[ty] Fix find-references for import aliases (#21736)

This commit is contained in:
Micha Reiser 2025-12-02 14:37:50 +01:00 committed by GitHub
parent 015ab9e576
commit 644096ea8a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 264 additions and 46 deletions

View File

@ -1722,4 +1722,152 @@ func<CURSOR>_alias()
|
"###);
}
#[test]
fn import_alias() {
let test = CursorTest::builder()
.source(
"main.py",
r#"
import warnings
import warnings as <CURSOR>abc
x = abc
y = warnings
"#,
)
.build();
assert_snapshot!(test.references(), @r"
info[references]: Reference 1
--> main.py:3:20
|
2 | import warnings
3 | import warnings as abc
| ^^^
4 |
5 | x = abc
|
info[references]: Reference 2
--> main.py:5:5
|
3 | import warnings as abc
4 |
5 | x = abc
| ^^^
6 | y = warnings
|
");
}
#[test]
fn import_alias_use() {
let test = CursorTest::builder()
.source(
"main.py",
r#"
import warnings
import warnings as abc
x = abc<CURSOR>
y = warnings
"#,
)
.build();
assert_snapshot!(test.references(), @r"
info[references]: Reference 1
--> main.py:3:20
|
2 | import warnings
3 | import warnings as abc
| ^^^
4 |
5 | x = abc
|
info[references]: Reference 2
--> main.py:5:5
|
3 | import warnings as abc
4 |
5 | x = abc
| ^^^
6 | y = warnings
|
");
}
#[test]
fn import_from_alias() {
let test = CursorTest::builder()
.source(
"main.py",
r#"
from warnings import deprecated as xyz<CURSOR>
from warnings import deprecated
y = xyz
z = deprecated
"#,
)
.build();
assert_snapshot!(test.references(), @r"
info[references]: Reference 1
--> main.py:2:36
|
2 | from warnings import deprecated as xyz
| ^^^
3 | from warnings import deprecated
|
info[references]: Reference 2
--> main.py:5:5
|
3 | from warnings import deprecated
4 |
5 | y = xyz
| ^^^
6 | z = deprecated
|
");
}
#[test]
fn import_from_alias_use() {
let test = CursorTest::builder()
.source(
"main.py",
r#"
from warnings import deprecated as xyz
from warnings import deprecated
y = xyz<CURSOR>
z = deprecated
"#,
)
.build();
assert_snapshot!(test.references(), @r"
info[references]: Reference 1
--> main.py:2:36
|
2 | from warnings import deprecated as xyz
| ^^^
3 | from warnings import deprecated
|
info[references]: Reference 2
--> main.py:5:5
|
3 | from warnings import deprecated
4 |
5 | y = xyz
| ^^^
6 | z = deprecated
|
");
}
}

View File

@ -396,6 +396,11 @@ impl GotoTarget<'_> {
GotoTarget::ImportSymbolAlias {
alias, import_from, ..
} => {
if let Some(asname) = alias.asname.as_ref()
&& alias_resolution == ImportAliasResolution::PreserveAliases
{
Some(definitions_for_name(model, asname.as_str(), asname.into()))
} else {
let symbol_name = alias.name.as_str();
Some(definitions_for_imported_symbol(
model,
@ -404,6 +409,7 @@ impl GotoTarget<'_> {
alias_resolution,
))
}
}
GotoTarget::ImportModuleComponent {
module_name,
@ -418,12 +424,12 @@ impl GotoTarget<'_> {
// Handle import aliases (offset within 'z' in "import x.y as z")
GotoTarget::ImportModuleAlias { alias } => {
if alias_resolution == ImportAliasResolution::ResolveAliases {
definitions_for_module(model, Some(alias.name.as_str()), 0)
if let Some(asname) = alias.asname.as_ref()
&& alias_resolution == ImportAliasResolution::PreserveAliases
{
Some(definitions_for_name(model, asname.as_str(), asname.into()))
} else {
alias.asname.as_ref().map(|name| {
definitions_for_name(model, name.as_str(), AnyNodeRef::Identifier(name))
})
definitions_for_module(model, Some(alias.name.as_str()), 0)
}
}

View File

@ -163,7 +163,7 @@ mod tests {
}
#[test]
fn test_prepare_rename_parameter() {
fn prepare_rename_parameter() {
let test = cursor_test(
"
def func(<CURSOR>value: int) -> int:
@ -178,7 +178,7 @@ value = 0
}
#[test]
fn test_rename_parameter() {
fn rename_parameter() {
let test = cursor_test(
"
def func(<CURSOR>value: int) -> int:
@ -207,7 +207,7 @@ func(value=42)
}
#[test]
fn test_rename_function() {
fn rename_function() {
let test = cursor_test(
"
def fu<CURSOR>nc():
@ -235,7 +235,7 @@ x = func
}
#[test]
fn test_rename_class() {
fn rename_class() {
let test = cursor_test(
"
class My<CURSOR>Class:
@ -265,7 +265,7 @@ cls = MyClass
}
#[test]
fn test_rename_invalid_name() {
fn rename_invalid_name() {
let test = cursor_test(
"
def fu<CURSOR>nc():
@ -286,7 +286,7 @@ def fu<CURSOR>nc():
}
#[test]
fn test_multi_file_function_rename() {
fn multi_file_function_rename() {
let test = CursorTest::builder()
.source(
"utils.py",
@ -880,7 +880,7 @@ class DataProcessor:
}
#[test]
fn test_cannot_rename_import_module_component() {
fn cannot_rename_import_module_component() {
// Test that we cannot rename parts of module names in import statements
let test = cursor_test(
"
@ -893,7 +893,7 @@ x = os.path.join('a', 'b')
}
#[test]
fn test_cannot_rename_from_import_module_component() {
fn cannot_rename_from_import_module_component() {
// Test that we cannot rename parts of module names in from import statements
let test = cursor_test(
"
@ -906,7 +906,7 @@ result = join('a', 'b')
}
#[test]
fn test_cannot_rename_external_file() {
fn cannot_rename_external_file() {
// This test verifies that we cannot rename a symbol when it's defined in a file
// that's outside the project (like a standard library function)
let test = cursor_test(
@ -920,7 +920,7 @@ x = <CURSOR>os.path.join('a', 'b')
}
#[test]
fn test_rename_alias_at_import_statement() {
fn rename_alias_at_import_statement() {
let test = CursorTest::builder()
.source(
"utils.py",
@ -931,8 +931,8 @@ def test(): pass
.source(
"main.py",
"
from utils import test as test_<CURSOR>alias
result = test_alias()
from utils import test as <CURSOR>alias
result = alias()
",
)
.build();
@ -941,16 +941,16 @@ result = test_alias()
info[rename]: Rename symbol (found 2 locations)
--> main.py:2:27
|
2 | from utils import test as test_alias
| ^^^^^^^^^^
3 | result = test_alias()
| ----------
2 | from utils import test as alias
| ^^^^^
3 | result = alias()
| -----
|
");
}
#[test]
fn test_rename_alias_at_usage_site() {
fn rename_alias_at_usage_site() {
// Test renaming an alias when the cursor is on the alias in the usage statement
let test = CursorTest::builder()
.source(
@ -962,8 +962,8 @@ def test(): pass
.source(
"main.py",
"
from utils import test as test_alias
result = test_<CURSOR>alias()
from utils import test as alias
result = <CURSOR>alias()
",
)
.build();
@ -972,16 +972,16 @@ result = test_<CURSOR>alias()
info[rename]: Rename symbol (found 2 locations)
--> main.py:2:27
|
2 | from utils import test as test_alias
| ^^^^^^^^^^
3 | result = test_alias()
| ----------
2 | from utils import test as alias
| ^^^^^
3 | result = alias()
| -----
|
");
}
#[test]
fn test_rename_across_import_chain_with_mixed_aliases() {
fn rename_across_import_chain_with_mixed_aliases() {
// Test renaming a symbol that's imported across multiple files with mixed alias patterns
// File 1 (source.py): defines the original function
// File 2 (middle.py): imports without alias from source.py
@ -1049,7 +1049,7 @@ value1 = func_alias()
}
#[test]
fn test_rename_alias_in_import_chain() {
fn rename_alias_in_import_chain() {
let test = CursorTest::builder()
.source(
"file1.py",
@ -1101,7 +1101,7 @@ class App:
}
#[test]
fn test_cannot_rename_keyword() {
fn cannot_rename_keyword() {
// Test that we cannot rename Python keywords like "None"
let test = cursor_test(
"
@ -1116,7 +1116,7 @@ def process_value(value):
}
#[test]
fn test_cannot_rename_builtin_type() {
fn cannot_rename_builtin_type() {
// Test that we cannot rename Python builtin types like "int"
let test = cursor_test(
"
@ -1129,7 +1129,7 @@ def convert_to_number(value):
}
#[test]
fn test_rename_keyword_argument() {
fn rename_keyword_argument() {
// Test renaming a keyword argument and its corresponding parameter
let test = cursor_test(
"
@ -1156,7 +1156,7 @@ result = func(10, <CURSOR>y=20)
}
#[test]
fn test_rename_parameter_with_keyword_argument() {
fn rename_parameter_with_keyword_argument() {
// Test renaming a parameter and its corresponding keyword argument
let test = cursor_test(
"
@ -1181,4 +1181,64 @@ result = func(10, y=20)
|
");
}
#[test]
fn import_alias() {
let test = CursorTest::builder()
.source(
"main.py",
r#"
import warnings
import warnings as <CURSOR>abc
x = abc
y = warnings
"#,
)
.build();
assert_snapshot!(test.rename("z"), @r"
info[rename]: Rename symbol (found 2 locations)
--> main.py:3:20
|
2 | import warnings
3 | import warnings as abc
| ^^^
4 |
5 | x = abc
| ---
6 | y = warnings
|
");
}
#[test]
fn import_alias_use() {
let test = CursorTest::builder()
.source(
"main.py",
r#"
import warnings
import warnings as abc
x = abc<CURSOR>
y = warnings
"#,
)
.build();
assert_snapshot!(test.rename("z"), @r"
info[rename]: Rename symbol (found 2 locations)
--> main.py:3:20
|
2 | import warnings
3 | import warnings as abc
| ^^^
4 |
5 | x = abc
| ---
6 | y = warnings
|
");
}
}

View File

@ -1478,6 +1478,8 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
}
let (symbol_name, is_reexported) = if let Some(asname) = &alias.asname {
self.scopes_by_expression
.record_expression(asname, self.current_scope());
(asname.id.clone(), asname.id == alias.name.id)
} else {
(Name::new(alias.name.id.split('.').next().unwrap()), false)
@ -1651,6 +1653,8 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
}
let (symbol_name, is_reexported) = if let Some(asname) = &alias.asname {
self.scopes_by_expression
.record_expression(asname, self.current_scope());
// It's re-exported if it's `from ... import x as x`
(&asname.id, asname.id == alias.name.id)
} else {