diff --git a/crates/ty_ide/src/goto.rs b/crates/ty_ide/src/goto.rs index 359438bca3..217f8b420b 100644 --- a/crates/ty_ide/src/goto.rs +++ b/crates/ty_ide/src/goto.rs @@ -73,19 +73,29 @@ pub(crate) enum GotoTarget<'a> { /// ``` ImportModuleAlias { alias: &'a ast::Alias, + asname: &'a ast::Identifier, + }, + + /// In an import statement, the named under which the symbol is exported + /// in the imported file. + /// + /// ```py + /// from foo import bar as baz + /// ^^^ + /// ``` + ImportExportedName { + alias: &'a ast::Alias, + import_from: &'a ast::StmtImportFrom, }, /// Import alias in from import statement /// ```py /// from foo import bar as baz - /// ^^^ - /// from foo import bar as baz /// ^^^ /// ``` ImportSymbolAlias { alias: &'a ast::Alias, - range: TextRange, - import_from: &'a ast::StmtImportFrom, + asname: &'a ast::Identifier, }, /// Go to on the exception handler variable @@ -290,8 +300,9 @@ impl GotoTarget<'_> { GotoTarget::FunctionDef(function) => function.inferred_type(model), GotoTarget::ClassDef(class) => class.inferred_type(model), GotoTarget::Parameter(parameter) => parameter.inferred_type(model), - GotoTarget::ImportSymbolAlias { alias, .. } => alias.inferred_type(model), - GotoTarget::ImportModuleAlias { alias } => alias.inferred_type(model), + GotoTarget::ImportSymbolAlias { alias, .. } + | GotoTarget::ImportModuleAlias { alias, .. } + | GotoTarget::ImportExportedName { alias, .. } => alias.inferred_type(model), GotoTarget::ExceptVariable(except) => except.inferred_type(model), GotoTarget::KeywordArgument { keyword, .. } => keyword.value.inferred_type(model), // When asking the type of a callable, usually you want the callable itself? @@ -378,7 +389,9 @@ impl GotoTarget<'_> { alias_resolution: ImportAliasResolution, ) -> Option> { let definitions = match self { - GotoTarget::Expression(expression) => definitions_for_expression(model, *expression), + GotoTarget::Expression(expression) => { + definitions_for_expression(model, *expression, alias_resolution) + } // For already-defined symbols, they are their own definitions GotoTarget::FunctionDef(function) => Some(vec![ResolvedDefinition::Definition( function.definition(model), @@ -393,22 +406,21 @@ impl GotoTarget<'_> { )]), // For import aliases (offset within 'y' or 'z' in "from x import y as z") - GotoTarget::ImportSymbolAlias { - alias, import_from, .. - } => { - if let Some(asname) = alias.asname.as_ref() - && alias_resolution == ImportAliasResolution::PreserveAliases - { - Some(definitions_for_name(model, asname.as_str(), asname.into())) - } else { - let symbol_name = alias.name.as_str(); - Some(definitions_for_imported_symbol( - model, - import_from, - symbol_name, - alias_resolution, - )) - } + GotoTarget::ImportSymbolAlias { asname, .. } => Some(definitions_for_name( + model, + asname.as_str(), + AnyNodeRef::from(*asname), + alias_resolution, + )), + + GotoTarget::ImportExportedName { alias, import_from } => { + let symbol_name = alias.name.as_str(); + Some(definitions_for_imported_symbol( + model, + import_from, + symbol_name, + alias_resolution, + )) } GotoTarget::ImportModuleComponent { @@ -423,15 +435,12 @@ impl GotoTarget<'_> { } // Handle import aliases (offset within 'z' in "import x.y as z") - GotoTarget::ImportModuleAlias { alias } => { - if let Some(asname) = alias.asname.as_ref() - && alias_resolution == ImportAliasResolution::PreserveAliases - { - Some(definitions_for_name(model, asname.as_str(), asname.into())) - } else { - definitions_for_module(model, Some(alias.name.as_str()), 0) - } - } + GotoTarget::ImportModuleAlias { asname, .. } => Some(definitions_for_name( + model, + asname.as_str(), + AnyNodeRef::from(*asname), + alias_resolution, + )), // Handle keyword arguments in call expressions GotoTarget::KeywordArgument { @@ -454,12 +463,22 @@ impl GotoTarget<'_> { // because they're not expressions GotoTarget::PatternMatchRest(pattern_mapping) => { pattern_mapping.rest.as_ref().map(|name| { - definitions_for_name(model, name.as_str(), AnyNodeRef::Identifier(name)) + definitions_for_name( + model, + name.as_str(), + AnyNodeRef::Identifier(name), + alias_resolution, + ) }) } GotoTarget::PatternMatchAsName(pattern_as) => pattern_as.name.as_ref().map(|name| { - definitions_for_name(model, name.as_str(), AnyNodeRef::Identifier(name)) + definitions_for_name( + model, + name.as_str(), + AnyNodeRef::Identifier(name), + alias_resolution, + ) }), GotoTarget::PatternKeywordArgument(pattern_keyword) => { @@ -468,12 +487,18 @@ impl GotoTarget<'_> { model, name.as_str(), AnyNodeRef::Identifier(name), + alias_resolution, )) } GotoTarget::PatternMatchStarName(pattern_star) => { pattern_star.name.as_ref().map(|name| { - definitions_for_name(model, name.as_str(), AnyNodeRef::Identifier(name)) + definitions_for_name( + model, + name.as_str(), + AnyNodeRef::Identifier(name), + alias_resolution, + ) }) } @@ -481,9 +506,18 @@ impl GotoTarget<'_> { // // Prefer the function impl over the callable so that its docstrings win if defined. GotoTarget::Call { callable, call } => { - let mut definitions = definitions_for_callable(model, call); + let mut definitions = Vec::new(); + + // We prefer the specific overload for hover, go-to-def etc. However, + // `definitions_for_callable` always resolves import aliases. That's why we + // skip it in cases import alias resolution is turned of (rename, highlight references). + if alias_resolution == ImportAliasResolution::ResolveAliases { + definitions.extend(definitions_for_callable(model, call)); + } + let expr_definitions = - definitions_for_expression(model, *callable).unwrap_or_default(); + definitions_for_expression(model, *callable, alias_resolution) + .unwrap_or_default(); definitions.extend(expr_definitions); if definitions.is_empty() { @@ -517,7 +551,7 @@ impl GotoTarget<'_> { let subexpr = covering_node(subast.syntax().into(), *subrange) .node() .as_expr_ref()?; - definitions_for_expression(&submodel, subexpr) + definitions_for_expression(&submodel, subexpr, alias_resolution) } // nonlocal and global are essentially loads, but again they're statements, @@ -527,6 +561,7 @@ impl GotoTarget<'_> { model, identifier.as_str(), AnyNodeRef::Identifier(identifier), + alias_resolution, )) } @@ -537,6 +572,7 @@ impl GotoTarget<'_> { model, name.as_str(), AnyNodeRef::Identifier(name), + alias_resolution, )) } @@ -546,6 +582,7 @@ impl GotoTarget<'_> { model, name.as_str(), AnyNodeRef::Identifier(name), + alias_resolution, )) } @@ -555,6 +592,7 @@ impl GotoTarget<'_> { model, name.as_str(), AnyNodeRef::Identifier(name), + alias_resolution, )) } }; @@ -580,12 +618,9 @@ impl GotoTarget<'_> { GotoTarget::FunctionDef(function) => Some(Cow::Borrowed(function.name.as_str())), GotoTarget::ClassDef(class) => Some(Cow::Borrowed(class.name.as_str())), GotoTarget::Parameter(parameter) => Some(Cow::Borrowed(parameter.name.as_str())), - GotoTarget::ImportSymbolAlias { alias, .. } => { - if let Some(asname) = &alias.asname { - Some(Cow::Borrowed(asname.as_str())) - } else { - Some(Cow::Borrowed(alias.name.as_str())) - } + GotoTarget::ImportSymbolAlias { asname, .. } => Some(Cow::Borrowed(asname.as_str())), + GotoTarget::ImportExportedName { alias, .. } => { + Some(Cow::Borrowed(alias.name.as_str())) } GotoTarget::ImportModuleComponent { module_name, @@ -599,13 +634,7 @@ impl GotoTarget<'_> { Some(Cow::Borrowed(module_name)) } } - GotoTarget::ImportModuleAlias { alias } => { - if let Some(asname) = &alias.asname { - Some(Cow::Borrowed(asname.as_str())) - } else { - Some(Cow::Borrowed(alias.name.as_str())) - } - } + GotoTarget::ImportModuleAlias { asname, .. } => Some(Cow::Borrowed(asname.as_str())), GotoTarget::ExceptVariable(except) => { Some(Cow::Borrowed(except.name.as_ref()?.as_str())) } @@ -667,7 +696,7 @@ impl GotoTarget<'_> { // Is the offset within the alias name (asname) part? if let Some(asname) = &alias.asname { if asname.range.contains_inclusive(offset) { - return Some(GotoTarget::ImportModuleAlias { alias }); + return Some(GotoTarget::ImportModuleAlias { alias, asname }); } } @@ -699,21 +728,13 @@ impl GotoTarget<'_> { // Is the offset within the alias name (asname) part? if let Some(asname) = &alias.asname { if asname.range.contains_inclusive(offset) { - return Some(GotoTarget::ImportSymbolAlias { - alias, - range: asname.range, - import_from, - }); + return Some(GotoTarget::ImportSymbolAlias { alias, asname }); } } // Is the offset in the original name part? if alias.name.range.contains_inclusive(offset) { - return Some(GotoTarget::ImportSymbolAlias { - alias, - range: alias.name.range, - import_from, - }); + return Some(GotoTarget::ImportExportedName { alias, import_from }); } None @@ -893,12 +914,13 @@ impl Ranged for GotoTarget<'_> { GotoTarget::FunctionDef(function) => function.name.range, GotoTarget::ClassDef(class) => class.name.range, GotoTarget::Parameter(parameter) => parameter.name.range, - GotoTarget::ImportSymbolAlias { range, .. } => *range, + GotoTarget::ImportSymbolAlias { asname, .. } => asname.range, + Self::ImportExportedName { alias, .. } => alias.name.range, GotoTarget::ImportModuleComponent { component_range, .. } => *component_range, GotoTarget::StringAnnotationSubexpr { subrange, .. } => *subrange, - GotoTarget::ImportModuleAlias { alias } => alias.asname.as_ref().unwrap().range, + GotoTarget::ImportModuleAlias { asname, .. } => asname.range, GotoTarget::ExceptVariable(except) => except.name.as_ref().unwrap().range, GotoTarget::KeywordArgument { keyword, .. } => keyword.arg.as_ref().unwrap().range, GotoTarget::PatternMatchRest(rest) => rest.rest.as_ref().unwrap().range, @@ -955,12 +977,14 @@ fn convert_resolved_definitions_to_targets<'db>( fn definitions_for_expression<'db>( model: &SemanticModel<'db>, expression: ruff_python_ast::ExprRef<'_>, + alias_resolution: ImportAliasResolution, ) -> Option>> { match expression { ast::ExprRef::Name(name) => Some(definitions_for_name( model, name.id.as_str(), expression.into(), + alias_resolution, )), ast::ExprRef::Attribute(attribute) => Some(ty_python_semantic::definitions_for_attribute( model, attribute, diff --git a/crates/ty_ide/src/references.rs b/crates/ty_ide/src/references.rs index d759b1daed..5bec65a356 100644 --- a/crates/ty_ide/src/references.rs +++ b/crates/ty_ide/src/references.rs @@ -37,6 +37,38 @@ pub enum ReferencesMode { DocumentHighlights, } +impl ReferencesMode { + pub(super) fn to_import_alias_resolution(self) -> ImportAliasResolution { + match self { + // Resolve import aliases for find references: + // ```py + // from warnings import deprecated as my_deprecated + // + // @my_deprecated + // def foo + // ``` + // + // When finding references on `my_deprecated`, we want to find all usages of `deprecated` across the entire + // project. + Self::References | Self::ReferencesSkipDeclaration => { + ImportAliasResolution::ResolveAliases + } + // For rename, don't resolve import aliases. + // + // ```py + // from warnings import deprecated as my_deprecated + // + // @my_deprecated + // def foo + // ``` + // When renaming `my_deprecated`, only rename the alias, but not the original definition in `warnings`. + Self::Rename | Self::RenameMultiFile | Self::DocumentHighlights => { + ImportAliasResolution::PreserveAliases + } + } + } +} + /// Find all references to a symbol at the given position. /// Search for references across all files in the project. pub(crate) fn references( @@ -45,12 +77,9 @@ pub(crate) fn references( goto_target: &GotoTarget, mode: ReferencesMode, ) -> Option> { - // Get the definitions for the symbol at the cursor position - - // When finding references, do not resolve any local aliases. let model = SemanticModel::new(db, file); let target_definitions = goto_target - .get_definition_targets(&model, ImportAliasResolution::PreserveAliases)? + .get_definition_targets(&model, mode.to_import_alias_resolution())? .declaration_targets(db)?; // Extract the target text from the goto target for fast comparison @@ -318,7 +347,7 @@ impl LocalReferencesFinder<'_> { { // Get the definitions for this goto target if let Some(current_definitions) = goto_target - .get_definition_targets(self.model, ImportAliasResolution::PreserveAliases) + .get_definition_targets(self.model, self.mode.to_import_alias_resolution()) .and_then(|definitions| definitions.declaration_targets(self.model.db())) { // Check if any of the current definitions match our target definitions diff --git a/crates/ty_ide/src/rename.rs b/crates/ty_ide/src/rename.rs index a8e91ebdcd..fe51f06615 100644 --- a/crates/ty_ide/src/rename.rs +++ b/crates/ty_ide/src/rename.rs @@ -3,7 +3,7 @@ use crate::references::{ReferencesMode, references}; use crate::{Db, ReferenceTarget}; use ruff_db::files::File; use ruff_text_size::{Ranged, TextSize}; -use ty_python_semantic::{ImportAliasResolution, SemanticModel}; +use ty_python_semantic::SemanticModel; /// Returns the range of the symbol if it can be renamed, None if not. pub fn can_rename(db: &dyn Db, file: File, offset: TextSize) -> Option { @@ -24,26 +24,22 @@ pub fn can_rename(db: &dyn Db, file: File, offset: TextSize) -> Option main.py:3:20 + | + 2 | import warnings + 3 | import warnings as abc + | ^^^ + 4 | + 5 | x = abc + | --- + 6 | y = warnings + | + "); + } + + #[test] + fn import_alias_to_first_party_definition() { + let test = CursorTest::builder() + .source("lib.py", "def deprecated(): pass") + .source( + "main.py", + r#" + import lib as lib2 + + x = lib2 + "#, + ) + .build(); + + assert_snapshot!(test.rename("z"), @r" + info[rename]: Rename symbol (found 2 locations) + --> main.py:2:15 + | + 2 | import lib as lib2 + | ^^^^ + 3 | + 4 | x = lib2 + | ---- + | + "); + } + + #[test] + fn imported_first_party_definition() { + let test = CursorTest::builder() + .source("lib.py", "def deprecated(): pass") + .source( + "main.py", + r#" + from lib import deprecated + + x = deprecated + "#, + ) + .build(); + + assert_snapshot!(test.rename("z"), @r" + info[rename]: Rename symbol (found 3 locations) + --> main.py:2:17 + | + 2 | from lib import deprecated + | ^^^^^^^^^^ + 3 | + 4 | x = deprecated + | ---------- + | + ::: lib.py:1:5 + | + 1 | def deprecated(): pass + | ---------- + | + "); } - // TODO Should rename the alias #[test] fn import_alias_use() { let test = CursorTest::builder() @@ -1221,7 +1286,19 @@ result = func(10, y=20) ) .build(); - assert_snapshot!(test.rename("z"), @"Cannot rename"); + assert_snapshot!(test.rename("z"), @r" + info[rename]: Rename symbol (found 2 locations) + --> main.py:3:20 + | + 2 | import warnings + 3 | import warnings as abc + | ^^^ + 4 | + 5 | x = abc + | --- + 6 | y = warnings + | + "); } #[test] diff --git a/crates/ty_ide/src/semantic_tokens.rs b/crates/ty_ide/src/semantic_tokens.rs index 88e48d1470..5667d1506f 100644 --- a/crates/ty_ide/src/semantic_tokens.rs +++ b/crates/ty_ide/src/semantic_tokens.rs @@ -259,7 +259,11 @@ impl<'db> SemanticTokenVisitor<'db> { fn classify_name(&self, name: &ast::ExprName) -> (SemanticTokenType, SemanticTokenModifier) { // First try to classify the token based on its definition kind. - let definition = definition_for_name(self.model, name); + let definition = definition_for_name( + self.model, + name, + ty_python_semantic::ImportAliasResolution::ResolveAliases, + ); if let Some(definition) = definition { let name_str = name.id.as_str(); diff --git a/crates/ty_python_semantic/src/types/ide_support.rs b/crates/ty_python_semantic/src/types/ide_support.rs index 111edcabc5..a74cc82f7e 100644 --- a/crates/ty_python_semantic/src/types/ide_support.rs +++ b/crates/ty_python_semantic/src/types/ide_support.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; +use crate::FxIndexSet; use crate::place::builtins_module_scope; use crate::semantic_index::definition::Definition; use crate::semantic_index::definition::DefinitionKind; @@ -24,8 +25,9 @@ use resolve_definition::{find_symbol_in_scope, resolve_definition}; pub fn definition_for_name<'db>( model: &SemanticModel<'db>, name: &ast::ExprName, + alias_resolution: ImportAliasResolution, ) -> Option> { - let definitions = definitions_for_name(model, name.id.as_str(), name.into()); + let definitions = definitions_for_name(model, name.id.as_str(), name.into(), alias_resolution); // Find the first valid definition and return its kind for declaration in definitions { @@ -43,6 +45,7 @@ pub fn definitions_for_name<'db>( model: &SemanticModel<'db>, name_str: &str, node: AnyNodeRef<'_>, + alias_resolution: ImportAliasResolution, ) -> Vec> { let db = model.db(); let file = model.file(); @@ -53,7 +56,7 @@ pub fn definitions_for_name<'db>( return vec![]; }; - let mut all_definitions = Vec::new(); + let mut all_definitions = FxIndexSet::default(); // Search through the scope hierarchy: start from the current scope and // traverse up through parent scopes to find definitions @@ -89,13 +92,13 @@ pub fn definitions_for_name<'db>( for binding in global_bindings { if let Some(def) = binding.binding.definition() { - all_definitions.push(def); + all_definitions.insert(def); } } for declaration in global_declarations { if let Some(def) = declaration.declaration.definition() { - all_definitions.push(def); + all_definitions.insert(def); } } } @@ -116,13 +119,13 @@ pub fn definitions_for_name<'db>( for binding in bindings { if let Some(def) = binding.binding.definition() { - all_definitions.push(def); + all_definitions.insert(def); } } for declaration in declarations { if let Some(def) = declaration.declaration.definition() { - all_definitions.push(def); + all_definitions.insert(def); } } @@ -136,21 +139,14 @@ pub fn definitions_for_name<'db>( let mut resolved_definitions = Vec::new(); for definition in &all_definitions { - let resolved = resolve_definition( - db, - *definition, - Some(name_str), - ImportAliasResolution::ResolveAliases, - ); + let resolved = resolve_definition(db, *definition, Some(name_str), alias_resolution); resolved_definitions.extend(resolved); } // If we didn't find any definitions in scopes, fallback to builtins - if resolved_definitions.is_empty() { - let Some(builtins_scope) = builtins_module_scope(db) else { - return resolved_definitions; - }; - + if resolved_definitions.is_empty() + && let Some(builtins_scope) = builtins_module_scope(db) + { // Special cases for `float` and `complex` in type annotation positions. // We don't know whether we're in a type annotation position, so we'll just ask `Name`'s type, // which resolves to `int | float` or `int | float | complex` if `float` or `complex` is used in @@ -932,6 +928,12 @@ mod resolve_definition { let module = parsed_module(db, file).load(db); let alias = import_def.alias(&module); + if alias.asname.is_some() + && alias_resolution == ImportAliasResolution::PreserveAliases + { + return vec![ResolvedDefinition::Definition(definition)]; + } + // Get the full module name being imported let Some(module_name) = ModuleName::new(&alias.name) else { return Vec::new(); // Invalid module name, return empty list @@ -955,7 +957,13 @@ mod resolve_definition { let file = definition.file(db); let module = parsed_module(db, file).load(db); let import_node = import_from_def.import(&module); - let name = &import_from_def.alias(&module).name; + let alias = import_from_def.alias(&module); + + if alias.asname.is_some() + && alias_resolution == ImportAliasResolution::PreserveAliases + { + return vec![ResolvedDefinition::Definition(definition)]; + } // For `ImportFrom`, we need to resolve the original imported symbol name // (alias.name), not the local alias (symbol_name) @@ -963,7 +971,7 @@ mod resolve_definition { db, file, import_node, - name, + &alias.name, visited, alias_resolution, )