diff --git a/crates/ty_ide/src/completion.rs b/crates/ty_ide/src/completion.rs
index ae0f5d5900..eeef8324d0 100644
--- a/crates/ty_ide/src/completion.rs
+++ b/crates/ty_ide/src/completion.rs
@@ -26,6 +26,7 @@ pub fn completion(db: &dyn Db, file: File, offset: TextSize) -> Vec<Name>
     let model = ty_python_semantic::SemanticModel::new(db.upcast(), file);
     let mut completions = match target {
         CompletionTargetAst::ObjectDot { expr } => model.attribute_completions(expr),
+        CompletionTargetAst::ImportFrom { import, name } => model.import_completions(import, name),
         CompletionTargetAst::Scoped { node } => model.scoped_completions(node),
     };
     completions.sort_by(|name1, name2| compare_suggestions(name1, name2));
@@ -62,6 +63,12 @@ enum CompletionTargetTokens<'t> {
         #[expect(dead_code)]
         attribute: Option<&'t Token>,
     },
+    /// A `from module import attribute` token form was found, where
+    /// `attribute` may be empty.
+    ImportFrom {
+        /// The module being imported from.
+        module: &'t Token,
+    },
     /// A token was found under the cursor, but it didn't
     /// match any of our anticipated token patterns.
     Generic { token: &'t Token },
@@ -102,6 +109,8 @@ impl<'t> CompletionTargetTokens<'t> {
                 object,
                 attribute: Some(attribute),
             }
+        } else if let Some(module) = import_from_tokens(before) {
+            CompletionTargetTokens::ImportFrom { module }
        } else if let Some([_]) = token_suffix_by_kinds(before, [TokenKind::Float]) {
             // If we're writing a `float`, then we should
             // specifically not offer completions. This wouldn't
@@ -153,6 +162,15 @@ impl<'t> CompletionTargetTokens<'t> {
                     _ => None,
                 }
             }
+            CompletionTargetTokens::ImportFrom { module, .. } => {
+                let covering_node = covering_node(parsed.syntax().into(), module.range())
+                    .find_first(|node| node.is_stmt_import_from())
+                    .ok()?;
+                let ast::AnyNodeRef::StmtImportFrom(import) = covering_node.node() else {
+                    return None;
+                };
+                Some(CompletionTargetAst::ImportFrom { import, name: None })
+            }
             CompletionTargetTokens::Generic { token } => {
                 let covering_node = covering_node(parsed.syntax().into(), token.range());
                 Some(CompletionTargetAst::Scoped {
@@ -176,6 +194,15 @@ enum CompletionTargetAst<'t> {
     /// A `object.attribute` scenario, where we want to
     /// list attributes on `object` for completions.
     ObjectDot { expr: &'t ast::ExprAttribute },
+    /// A `from module import attribute` scenario, where we want to
+    /// list attributes on `module` for completions.
+    ImportFrom {
+        /// The import statement.
+        import: &'t ast::StmtImportFrom,
+        /// An index into `import.names` if relevant. When this is
+        /// set, the index is guaranteed to be valid.
+        name: Option<usize>,
+    },
     /// A scoped scenario, where we want to list all items available in
     /// the most narrow scope containing the giving AST node.
     Scoped { node: ast::AnyNodeRef<'t> },
@@ -205,6 +232,97 @@ fn token_suffix_by_kinds(
     }))
 }
 
+/// Looks for the start of a `from module import <CURSOR>` statement.
+///
+/// If found, one arbitrary token forming `module` is returned.
+fn import_from_tokens(tokens: &[Token]) -> Option<&Token> {
+    use TokenKind as TK;
+
+    /// The number of tokens we're willing to consume backwards from
+    /// the cursor's position until we give up looking for a `from
+    /// module import <CURSOR>` pattern. The state machine below has
+    /// lots of opportunities to bail way earlier than this, but if
+    /// there's, e.g., a long list of name tokens for something that
+    /// isn't an import, then we could end up doing a lot of wasted
+    /// work here. Probably humans aren't often working with single
+    /// import statements over 1,000 tokens long.
+    ///
+    /// The other thing to consider here is that, by the time we get to
+    /// this point, ty has already done some work proportional to the
+    /// length of `tokens` anyway. The unit of work we do below is very
+    /// small.
+    const LIMIT: usize = 1_000;
+
+    /// A state used to "parse" the tokens preceding the user's cursor,
+    /// in reverse, to detect a "from import" statement.
+    enum S {
+        Start,
+        Names,
+        Module,
+    }
+
+    let mut state = S::Start;
+    let mut module_token: Option<&Token> = None;
+    // Move backward through the tokens until we get to
+    // the `from` token.
+    for token in tokens.iter().rev().take(LIMIT) {
+        state = match (state, token.kind()) {
+            // It's okay to pop off a newline token here initially,
+            // since it may occur when the name being imported is
+            // empty.
+            (S::Start, TK::Newline) => S::Names,
+            // Munch through tokens that can make up an alias.
+            // N.B. We could also consider taking any token here
+            // *except* some limited set of tokens (like `Newline`).
+            // That might work well if it turns out that listing
+            // all possible allowable tokens is too brittle.
+            (
+                S::Start | S::Names,
+                TK::Name
+                | TK::Comma
+                | TK::As
+                | TK::Case
+                | TK::Match
+                | TK::Type
+                | TK::Star
+                | TK::Lpar
+                | TK::Rpar
+                | TK::NonLogicalNewline
+                // It's not totally clear the conditions under
+                // which this occurs (I haven't read our tokenizer),
+                // but it appears in code like this, where this is
+                // the entire file contents:
+                //
+                //     from sys import (
+                //         abiflags,
+                //
+                //
+                // It seems harmless to just allow this "unknown"
+                // token here to make the above work.
+                | TK::Unknown,
+            ) => S::Names,
+            (S::Start | S::Names, TK::Import) => S::Module,
+            // Munch through tokens that can make up a module.
+            (
+                S::Module,
+                TK::Name | TK::Dot | TK::Ellipsis | TK::Case | TK::Match | TK::Type | TK::Unknown,
+            ) => {
+                // It's okay if there are multiple module
+                // tokens here. Just taking the last one
+                // (which is the one appearing first in
+                // the source code) is fine. We only need
+                // this to find the corresponding AST node,
+                // so any of the tokens should work fine.
+                module_token = Some(token);
+                S::Module
+            }
+            (S::Module, TK::From) => return module_token,
+            _ => return None,
+        };
+    }
+    None
+}
+
 /// Order completions lexicographically, with these exceptions:
 ///
 /// 1) A `_[^_]` prefix sorts last and
@@ -1850,6 +1968,205 @@ def test_point(p2: Point):
         test.assert_completions_include("orthogonal_direction");
     }
 
+    #[test]
+    fn from_import1() {
+        let test = cursor_test(
+            "\
+from sys import <CURSOR>
+",
+        );
+        test.assert_completions_include("getsizeof");
+    }
+
+    #[test]
+    fn from_import2() {
+        let test = cursor_test(
+            "\
+from sys import abiflags, <CURSOR>
+",
+        );
+        test.assert_completions_include("getsizeof");
+    }
+
+    #[test]
+    fn from_import3() {
+        let test = cursor_test(
+            "\
+from sys import <CURSOR>, abiflags
+",
+        );
+        test.assert_completions_include("getsizeof");
+    }
+
+    #[test]
+    fn from_import4() {
+        let test = cursor_test(
+            "\
+from sys import abiflags, \
+<CURSOR>
+",
+        );
+        test.assert_completions_include("getsizeof");
+    }
+
+    #[test]
+    fn from_import5() {
+        let test = cursor_test(
+            "\
+from sys import abiflags as foo, <CURSOR>
+",
+        );
+        test.assert_completions_include("getsizeof");
+    }
+
+    #[test]
+    fn from_import6() {
+        let test = cursor_test(
+            "\
+from sys import abiflags as foo, g<CURSOR>
+",
+        );
+        test.assert_completions_include("getsizeof");
+    }
+
+    #[test]
+    fn from_import7() {
+        let test = cursor_test(
+            "\
+from sys import abiflags as foo, \
+<CURSOR>
+",
+        );
+        test.assert_completions_include("getsizeof");
+    }
+
+    #[test]
+    fn from_import8() {
+        let test = cursor_test(
+            "\
+from sys import abiflags as foo, \
+    g<CURSOR>
+",
+        );
+        test.assert_completions_include("getsizeof");
+    }
+
+    #[test]
+    fn from_import9() {
+        let test = cursor_test(
+            "\
+from sys import (
+    abiflags,
+    <CURSOR>
+",
+        );
+        test.assert_completions_include("getsizeof");
+    }
+
+    #[test]
+    fn from_import10() {
+        let test = cursor_test(
+            "\
+from sys import (
+    abiflags,
+    <CURSOR>
+)
+",
+        );
+        test.assert_completions_include("getsizeof");
+    }
+
+    #[test]
+    fn from_import11() {
+        let test = cursor_test(
+            "\
+from sys import (
+    <CURSOR>
+)
+",
+        );
+        test.assert_completions_include("getsizeof");
+    }
+
+    #[test]
+    fn from_import_unknown_in_module() {
+        let test = cursor_test(
+            "\
+foo = 1
+from ? import <CURSOR>
+",
+        );
+        assert_snapshot!(test.completions(), @r"");
+    }
+
+    #[test]
+    fn from_import_unknown_in_import_names1() {
+        let test = cursor_test(
+            "\
+from sys import ?, <CURSOR>
+",
+        );
+        test.assert_completions_include("getsizeof");
+    }
+
+    #[test]
+    fn from_import_unknown_in_import_names2() {
+        let test = cursor_test(
+            "\
+from sys import ??, <CURSOR>
+",
+        );
+        test.assert_completions_include("getsizeof");
+    }
+
+    #[test]
+    fn from_import_unknown_in_import_names3() {
+        let test = cursor_test(
+            "\
+from sys import ??, <CURSOR>, ??
+",
+        );
+        test.assert_completions_include("getsizeof");
+    }
+
+    #[test]
+    fn import_submodule_not_attribute1() {
+        let test = cursor_test(
+            "\
+import importlib
+importlib.<CURSOR>
+",
+        );
+        test.assert_completions_do_not_include("resources");
+    }
+
+    #[test]
+    fn import_submodule_not_attribute2() {
+        let test = cursor_test(
+            "\
+import importlib.resources
+importlib.<CURSOR>
+",
+        );
+        // TODO: This is wrong. Completions should include
+        // `resources` here.
+        test.assert_completions_do_not_include("resources");
+    }
+
+    #[test]
+    fn import_submodule_not_attribute3() {
+        let test = cursor_test(
+            "\
+import importlib
+import importlib.resources
+importlib.<CURSOR>
+",
+        );
+        // TODO: This is wrong. Completions should include
+        // `resources` here.
+        test.assert_completions_do_not_include("resources");
+    }
+
     #[test]
     fn regression_test_issue_642() {
         // Regression test for https://github.com/astral-sh/ty/issues/642
diff --git a/crates/ty_python_semantic/src/semantic_model.rs b/crates/ty_python_semantic/src/semantic_model.rs
index 9237e75ee2..d6795e6d92 100644
--- a/crates/ty_python_semantic/src/semantic_model.rs
+++ b/crates/ty_python_semantic/src/semantic_model.rs
@@ -41,6 +41,31 @@ impl<'db> SemanticModel<'db> {
         resolve_module(self.db, module_name)
     }
 
+    /// Returns completions for symbols available in a `from module import <CURSOR>` context.
+    pub fn import_completions(
+        &self,
+        import: &ast::StmtImportFrom,
+        _name: Option<usize>,
+    ) -> Vec<Name> {
+        let module_name = match ModuleName::from_import_statement(self.db, self.file, import) {
+            Ok(module_name) => module_name,
+            Err(err) => {
+                tracing::debug!(
+                    "Could not extract module name from `{module:?}` with level {level}: {err:?}",
+                    module = import.module,
+                    level = import.level,
+                );
+                return vec![];
+            }
+        };
+        let Some(module) = resolve_module(self.db, &module_name) else {
+            tracing::debug!("Could not resolve module from `{module_name:?}`");
+            return vec![];
+        };
+        let ty = Type::module_literal(self.db, self.file, &module);
+        crate::types::all_members(self.db, ty).into_iter().collect()
+    }
+
     /// Returns completions for symbols available in a `object.<CURSOR>` context.
     pub fn attribute_completions(&self, node: &ast::ExprAttribute) -> Vec<Name> {
         let ty = node.value.inferred_type(self);
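For readers skimming the patch, the following standalone sketch (not part of the diff above) illustrates the reverse token-scan idea that `import_from_tokens` implements. The `Tok` enum and `scan_from_import` function are invented for illustration only; the real code matches on the parser's `TokenKind`, accepts more token kinds (`as`, `*`, parentheses, soft keywords, unknown tokens), and caps the scan at 1,000 tokens.

// Simplified stand-in for the backward scan in `import_from_tokens`.
// `Tok` is a toy token kind; the real implementation works on parser tokens.
#[derive(Clone, Copy)]
enum Tok {
    From,
    Import,
    Name,
    Dot,
    Comma,
    Newline,
}

/// Walks the tokens *before the cursor* in reverse. If the suffix looks like
/// `from module import <CURSOR>`, returns the index of one token belonging to
/// the module (the earliest one in source order); otherwise returns `None`.
fn scan_from_import(before_cursor: &[Tok]) -> Option<usize> {
    enum State {
        Start,  // just behind the cursor
        Names,  // inside the imported-names list
        Module, // between `from` and `import`
    }

    let mut state = State::Start;
    let mut module_token = None;
    for (i, tok) in before_cursor.iter().enumerate().rev() {
        state = match (state, *tok) {
            // A trailing newline may sit at the cursor when no name was typed yet.
            (State::Start, Tok::Newline) => State::Names,
            // Names and commas make up the (possibly empty) alias list.
            (State::Start | State::Names, Tok::Name | Tok::Comma) => State::Names,
            (State::Start | State::Names, Tok::Import) => State::Module,
            // Keep overwriting so we end up with the first module token in source order.
            (State::Module, Tok::Name | Tok::Dot) => {
                module_token = Some(i);
                State::Module
            }
            (State::Module, Tok::From) => return module_token,
            _ => return None,
        };
    }
    None
}

fn main() {
    // `from os.path import join, <CURSOR>`: the scan walks right-to-left past the
    // names and comma, then `import`, then the module tokens, then `from`.
    let toks = [Tok::From, Tok::Name, Tok::Dot, Tok::Name, Tok::Import, Tok::Name, Tok::Comma];
    assert_eq!(scan_from_import(&toks), Some(1)); // index of `os`

    // `from sys import <CURSOR>` where the tokenizer already emitted the trailing
    // newline token: the leading newline is popped off first.
    let empty_names = [Tok::From, Tok::Name, Tok::Import, Tok::Newline];
    assert_eq!(scan_from_import(&empty_names), Some(1)); // index of `sys`

    // `x, y` is not a from-import, so the scan bails out.
    assert_eq!(scan_from_import(&[Tok::Name, Tok::Comma, Tok::Name]), None);
}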