diff --git a/crates/ty_ide/src/importer.rs b/crates/ty_ide/src/importer.rs index 5ff46a1ae1..1dff46bcaf 100644 --- a/crates/ty_ide/src/importer.rs +++ b/crates/ty_ide/src/importer.rs @@ -145,7 +145,7 @@ impl<'a> Importer<'a> { members: &MembersInScope, ) -> ImportAction { let request = request.avoid_conflicts(self.db, self.file, members); - let mut symbol_text: Box = request.member.into(); + let mut symbol_text: Box = request.member.unwrap_or(request.module).into(); let Some(response) = self.find(&request, members.at) else { let insertion = if let Some(future) = self.find_last_future_import(members.at) { Insertion::end_of_statement(future.stmt, self.source, self.stylist) @@ -157,14 +157,27 @@ impl<'a> Importer<'a> { Insertion::start_of_file(self.parsed.suite(), self.source, self.stylist, range) }; let import = insertion.into_edit(&request.to_string()); - if matches!(request.style, ImportStyle::Import) { - symbol_text = format!("{}.{}", request.module, request.member).into(); + if let Some(member) = request.member + && matches!(request.style, ImportStyle::Import) + { + symbol_text = format!("{}.{}", request.module, member).into(); } return ImportAction { import: Some(import), symbol_text, }; }; + + // When we just have a request to import a module (and not + // any members from that module), then the only way we can be + // here is if we found a pre-existing import that definitively + // satisfies the request. So we're done. + let Some(member) = request.member else { + return ImportAction { + import: None, + symbol_text, + }; + }; match response.kind { ImportResponseKind::Unqualified { ast, alias } => { let member = alias.asname.as_ref().unwrap_or(&alias.name).as_str(); @@ -189,13 +202,10 @@ impl<'a> Importer<'a> { let import = if let Some(insertion) = Insertion::existing_import(response.import.stmt, self.tokens) { - insertion.into_edit(request.member) + insertion.into_edit(member) } else { Insertion::end_of_statement(response.import.stmt, self.source, self.stylist) - .into_edit(&format!( - "from {} import {}", - request.module, request.member - )) + .into_edit(&format!("from {} import {member}", request.module)) }; ImportAction { import: Some(import), @@ -481,6 +491,17 @@ impl<'ast> AstImportKind<'ast> { Some(ImportResponseKind::Qualified { ast, alias }) } AstImportKind::ImportFrom(ast) => { + // If the request is for a module itself, then we + // assume that it can never be satisfies by a + // `from ... import ...` statement. For example, a + // `request for collections.abc` needs an + // `import collections.abc`. Now, there could be a + // `from collections import abc`, and we could + // plausibly consider that a match and return a + // symbol text of `abc`. But it's not clear if that's + // the right choice or not. + let member = request.member?; + if request.force_style && !matches!(request.style, ImportStyle::ImportFrom) { return None; } @@ -492,9 +513,7 @@ impl<'ast> AstImportKind<'ast> { let kind = ast .names .iter() - .find(|alias| { - alias.name.as_str() == "*" || alias.name.as_str() == request.member - }) + .find(|alias| alias.name.as_str() == "*" || alias.name.as_str() == member) .map(|alias| ImportResponseKind::Unqualified { ast, alias }) .unwrap_or_else(|| ImportResponseKind::Partial(ast)); Some(kind) @@ -510,7 +529,10 @@ pub(crate) struct ImportRequest<'a> { /// `foo`, in `from foo import bar`). module: &'a str, /// The member to import (e.g., `bar`, in `from foo import bar`). - member: &'a str, + /// + /// When `member` is absent, then this request reflects an import + /// of the module itself. i.e., `import module`. + member: Option<&'a str>, /// The preferred style to use when importing the symbol (e.g., /// `import foo` or `from foo import bar`). /// @@ -532,7 +554,7 @@ impl<'a> ImportRequest<'a> { pub(crate) fn import(module: &'a str, member: &'a str) -> Self { Self { module, - member, + member: Some(member), style: ImportStyle::Import, force_style: false, } @@ -545,12 +567,26 @@ impl<'a> ImportRequest<'a> { pub(crate) fn import_from(module: &'a str, member: &'a str) -> Self { Self { module, - member, + member: Some(member), style: ImportStyle::ImportFrom, force_style: false, } } + /// Create a new [`ImportRequest`] for bringing the given module + /// into scope. + /// + /// This is for just importing the module itself, always via an + /// `import` statement. + pub(crate) fn module(module: &'a str) -> Self { + Self { + module, + member: None, + style: ImportStyle::Import, + force_style: false, + } + } + /// Causes this request to become a command. This will force the /// requested import style, even if another style would be more /// appropriate generally. @@ -565,7 +601,13 @@ impl<'a> ImportRequest<'a> { /// of an import conflict are minimized (although not always reduced /// to zero). fn avoid_conflicts(self, db: &dyn Db, importing_file: File, members: &MembersInScope) -> Self { - match (members.map.get(self.module), members.map.get(self.member)) { + let Some(member) = self.member else { + return Self { + style: ImportStyle::Import, + ..self + }; + }; + match (members.map.get(self.module), members.map.get(member)) { // Neither symbol exists, so we can just proceed as // normal. (None, None) => self, @@ -630,7 +672,10 @@ impl std::fmt::Display for ImportRequest<'_> { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self.style { ImportStyle::Import => write!(f, "import {}", self.module), - ImportStyle::ImportFrom => write!(f, "from {} import {}", self.module, self.member), + ImportStyle::ImportFrom => match self.member { + None => write!(f, "import {}", self.module), + Some(member) => write!(f, "from {} import {member}", self.module), + }, } } } @@ -843,6 +888,10 @@ mod tests { self.add(ImportRequest::import_from(module, member)) } + fn module(&self, module: &str) -> String { + self.add(ImportRequest::module(module)) + } + fn add(&self, request: ImportRequest<'_>) -> String { let node = covering_node( self.cursor.parsed.syntax().into(), @@ -2156,4 +2205,73 @@ except ImportError: (bar.MAGIC) "); } + + #[test] + fn import_module_blank() { + let test = cursor_test( + "\ + + ", + ); + assert_snapshot!( + test.module("collections"), @r" + import collections + collections + "); + } + + #[test] + fn import_module_exists() { + let test = cursor_test( + "\ +import collections + + ", + ); + assert_snapshot!( + test.module("collections"), @r" + import collections + collections + "); + } + + #[test] + fn import_module_from_exists() { + let test = cursor_test( + "\ +from collections import defaultdict + + ", + ); + assert_snapshot!( + test.module("collections"), @r" + import collections + from collections import defaultdict + collections + "); + } + + // This test is working as intended. That is, + // `abc` is already in scope, so requesting an + // import for `collections.abc` could feasibly + // reuse the import and rewrite the symbol text + // to just `abc`. But for now it seems better + // to respect what has been written and add the + // `import collections.abc`. This behavior could + // plausibly be changed. + #[test] + fn import_module_from_via_member_exists() { + let test = cursor_test( + "\ +from collections import abc + + ", + ); + assert_snapshot!( + test.module("collections.abc"), @r" + import collections.abc + from collections import abc + collections.abc + "); + } }