diff --git a/crates/ty_completion_eval/completion-evaluation-tasks.csv b/crates/ty_completion_eval/completion-evaluation-tasks.csv index e7398195f5..adcee83d42 100644 --- a/crates/ty_completion_eval/completion-evaluation-tasks.csv +++ b/crates/ty_completion_eval/completion-evaluation-tasks.csv @@ -4,6 +4,7 @@ auto-import-includes-modules,main.py,1,2 auto-import-includes-modules,main.py,2,1 auto-import-skips-current-module,main.py,0,1 class-arg-completion,main.py,0,1 +class-arg-completion,main.py,1,1 exact-over-fuzzy,main.py,0,1 fstring-completions,main.py,0,1 higher-level-symbols-preferred,main.py,0, diff --git a/crates/ty_completion_eval/truth/class-arg-completion/main.py b/crates/ty_completion_eval/truth/class-arg-completion/main.py index ab3cb7a74b..9183578c47 100644 --- a/crates/ty_completion_eval/truth/class-arg-completion/main.py +++ b/crates/ty_completion_eval/truth/class-arg-completion/main.py @@ -1 +1,4 @@ +# Prefer keyword arguments and inheritable classes +# (E.g. `NotImplementedError` over `NotImplemented`). class Foo(m) +class Foo(NotImplem) diff --git a/crates/ty_ide/src/completion.rs b/crates/ty_ide/src/completion.rs index 14604c9c66..217900bda1 100644 --- a/crates/ty_ide/src/completion.rs +++ b/crates/ty_ide/src/completion.rs @@ -13,7 +13,7 @@ use ruff_text_size::{Ranged, TextRange, TextSize}; use rustc_hash::FxHashSet; use ty_module_resolver::{KnownModule, Module, ModuleName}; use ty_python_semantic::HasType; -use ty_python_semantic::types::UnionType; +use ty_python_semantic::types::{SpecialFormType, UnionType}; use ty_python_semantic::{ Completion as SemanticCompletion, NameKind, SemanticModel, types::{CycleDetector, KnownClass, Type}, @@ -378,20 +378,30 @@ impl<'db> CompletionBuilder<'db> { ctx: &CollectionContext<'db>, query: &UserQuery, ) -> Completion<'db> { - // Tags completions with context-specific if they are - // known to be usable in a `raise` context and we have - // determined a raisable type `raisable_ty`. - // - // It's possible that some completions are usable in a `raise` - // but aren't marked here. That is, false negatives are - // possible but false positives are not. - if let Some(raisable_ty) = ctx.raisable_ty { - if let Some(ty) = self.ty { - self.is_context_specific |= ty.is_assignable_to(db, raisable_ty); - } - } if let Some(ty) = self.ty { self.is_type_check_only = ty.is_type_check_only(db); + // Tags completions with context-specific if they are + // known to be usable in a `raise` context and we have + // determined a raisable type `raisable_ty`. + // + // It's possible that some completions are usable in a `raise` + // but aren't marked here. That is, false negatives are + // possible but false positives are not. + if let Some(raisable_ty) = ctx.raisable_ty { + self.is_context_specific |= ty.is_assignable_to(db, raisable_ty); + } + if ctx.is_in_class_def { + self.is_context_specific |= ty.is_class_literal() + || matches!( + ty, + Type::SpecialForm( + SpecialFormType::Protocol + | SpecialFormType::Generic + | SpecialFormType::TypedDict + | SpecialFormType::NamedTuple + ) + ); + } } let kind = self .kind @@ -596,6 +606,7 @@ impl<'m> Context<'m> { ) }), is_raising_exception, + is_in_class_def: self.cursor.is_in_class_def(), valid_keywords: self.cursor.valid_keywords(), } } @@ -628,6 +639,8 @@ struct ContextCursor<'m> { range: TextRange, /// The tokens that appear before the cursor. tokens_before: &'m [Token], + /// The covering node based on `parsed` and `range`. + covering_node: CoveringNode<'m>, } impl<'m> ContextCursor<'m> { @@ -639,13 +652,16 @@ impl<'m> ContextCursor<'m> { ) -> ContextCursor<'m> { let tokens_before = tokens_start_before(parsed.tokens(), offset); let Some(range) = ContextCursor::find_typed_text_range(tokens_before, offset) else { + let range = TextRange::empty(offset); + let covering_node = covering_node(parsed.syntax().into(), range); return ContextCursor { parsed, source, typed: None, offset, - range: TextRange::empty(offset), + range, tokens_before, + covering_node, }; }; @@ -654,6 +670,8 @@ impl<'m> ContextCursor<'m> { !text.is_empty(), "expected typed text, when found, to be non-empty" ); + + let covering_node = covering_node(parsed.syntax().into(), range); ContextCursor { parsed, source, @@ -661,6 +679,7 @@ impl<'m> ContextCursor<'m> { offset, range, tokens_before, + covering_node, } } @@ -761,8 +780,7 @@ impl<'m> ContextCursor<'m> { /// Returns true when the cursor sits on a binding statement. /// E.g. naming a parameter, type parameter, or `for` ). fn is_in_variable_binding(&self) -> bool { - let covering = self.covering_node(self.range); - covering.ancestors().any(|node| match node { + self.covering_node.ancestors().any(|node| match node { ast::AnyNodeRef::Parameter(param) => param.name.range.contains_range(self.range), ast::AnyNodeRef::TypeParamTypeVar(type_param) => { type_param.name.range.contains_range(self.range) @@ -816,17 +834,35 @@ impl<'m> ContextCursor<'m> { false } + /// Returns true when the curser is within the + /// arguments node of a class definition. + /// + /// E.g. `class Foo(Bar)` + fn is_in_class_def(&self) -> bool { + for node in self.covering_node.ancestors() { + if let ast::AnyNodeRef::StmtClassDef(class_def) = node { + return class_def + .arguments + .as_ref() + .is_some_and(|args| args.range.contains_range(self.range)); + } + if node.is_statement() { + return false; + } + } + false + } + /// Returns a set of keywords that are valid at /// the current cursor position. /// /// Returns None if no context-based exclusions can /// be identified. Meaning that all keywords are valid. fn valid_keywords(&self) -> Option> { - let covering_node = self.covering_node(self.range); - // Check if the cursor is within the naming // part of a decorator node. - if covering_node + if self + .covering_node .ancestors() // We bail if we're specifying arguments as we don't // want to suppress suggestions there. @@ -837,7 +873,7 @@ impl<'m> ContextCursor<'m> { { return Some(FxHashSet::from_iter(["lambda"])); } - covering_node.ancestors().find_map(|node| { + self.covering_node.ancestors().find_map(|node| { self.is_in_for_statement_iterable(node) .then(|| FxHashSet::from_iter(["yield", "lambda", "await"])) .or_else(|| { @@ -1008,6 +1044,8 @@ struct CollectionContext<'db> { raisable_ty: Option>, /// Whether we're in a `raise ` context or not. is_raising_exception: bool, + /// Whether we're in a class definition context or not. + is_in_class_def: bool, /// When set, the context dictates that only *these* keywords /// are acceptable in this context. valid_keywords: Option>, @@ -1271,7 +1309,7 @@ fn add_argument_completions<'db>( cursor: &ContextCursor<'_>, completions: &mut Completions<'db>, ) { - for node in cursor.covering_node(cursor.range).ancestors() { + for node in cursor.covering_node.ancestors() { match node { ast::AnyNodeRef::ExprCall(call) => { if call.arguments.range().contains_range(cursor.range) { @@ -1347,7 +1385,7 @@ fn add_function_arg_completions<'db>( ) { debug_assert!( cursor - .covering_node(cursor.range) + .covering_node .ancestors() .take_while(|node| !node.is_statement()) .any(|node| node.is_arguments()), @@ -1394,9 +1432,8 @@ fn add_function_arg_completions<'db>( /// If the parent node is not an arguments node, the return value /// is an empty Vec. fn detect_set_function_args<'m>(cursor: &ContextCursor<'m>) -> FxHashSet<&'m str> { - let range = TextRange::empty(cursor.offset); cursor - .covering_node(range) + .covering_node .parent() .and_then(|node| match node { ast::AnyNodeRef::Arguments(args) => Some(args), @@ -1679,13 +1716,9 @@ impl<'t> CompletionTargetTokens<'t> { let node = cursor.covering_node(token.range()).node(); Some(CompletionTargetAst::Scoped(ScopedTarget { node })) } - CompletionTargetTokens::Unknown => { - let range = TextRange::empty(cursor.offset); - let covering_node = cursor.covering_node(range); - Some(CompletionTargetAst::Scoped(ScopedTarget { - node: covering_node.node(), - })) - } + CompletionTargetTokens::Unknown => Some(CompletionTargetAst::Scoped(ScopedTarget { + node: cursor.covering_node.node(), + })), } } } @@ -3344,9 +3377,9 @@ class Foo(): ); assert_snapshot!(builder.skip_keywords().skip_builtins().build().snapshot(), @" - metaclass= Bar Foo + metaclass= "); } @@ -3362,9 +3395,9 @@ class Bar: ... ); assert_snapshot!(builder.skip_keywords().skip_builtins().build().snapshot(), @" - metaclass= Bar Foo + metaclass= "); } @@ -3380,9 +3413,9 @@ class Bar: ... ); assert_snapshot!(builder.skip_keywords().skip_builtins().build().snapshot(), @" - metaclass= Bar Foo + metaclass= "); } @@ -3396,9 +3429,9 @@ class Foo(", ); assert_snapshot!(builder.skip_keywords().skip_builtins().build().snapshot(), @" - metaclass= Bar Foo + metaclass= "); }