[ty] Better class def completions (#22571)

## Summary

Prefer completions of type `ClassLiteral` inside class-def context.

Fixes https://github.com/astral-sh/ty/issues/1776

## Test Plan
New completion-eval test that has incorrect ranking on main but correct
after this patch.
This commit is contained in:
RasmusNygren
2026-01-15 14:31:04 +01:00
committed by GitHub
parent 9c67b2acd9
commit 2a29ce3e41
3 changed files with 72 additions and 35 deletions

View File

@@ -4,6 +4,7 @@ auto-import-includes-modules,main.py,1,2
auto-import-includes-modules,main.py,2,1
auto-import-skips-current-module,main.py,0,1
class-arg-completion,main.py,0,1
class-arg-completion,main.py,1,1
exact-over-fuzzy,main.py,0,1
fstring-completions,main.py,0,1
higher-level-symbols-preferred,main.py,0,
1 name file index rank
4 auto-import-includes-modules main.py 2 1
5 auto-import-skips-current-module main.py 0 1
6 class-arg-completion main.py 0 1
7 class-arg-completion main.py 1 1
8 exact-over-fuzzy main.py 0 1
9 fstring-completions main.py 0 1
10 higher-level-symbols-preferred main.py 0

View File

@@ -1 +1,4 @@
# Prefer keyword arguments and inheritable classes
# (E.g. `NotImplementedError` over `NotImplemented`).
class Foo(m<CURSOR: metaclass>)
class Foo(NotImplem<CURSOR: NotImplementedError>)

View File

@@ -13,7 +13,7 @@ use ruff_text_size::{Ranged, TextRange, TextSize};
use rustc_hash::FxHashSet;
use ty_module_resolver::{KnownModule, Module, ModuleName};
use ty_python_semantic::HasType;
use ty_python_semantic::types::UnionType;
use ty_python_semantic::types::{SpecialFormType, UnionType};
use ty_python_semantic::{
Completion as SemanticCompletion, NameKind, SemanticModel,
types::{CycleDetector, KnownClass, Type},
@@ -378,20 +378,30 @@ impl<'db> CompletionBuilder<'db> {
ctx: &CollectionContext<'db>,
query: &UserQuery,
) -> Completion<'db> {
// Tags completions with context-specific if they are
// known to be usable in a `raise` context and we have
// determined a raisable type `raisable_ty`.
//
// It's possible that some completions are usable in a `raise`
// but aren't marked here. That is, false negatives are
// possible but false positives are not.
if let Some(raisable_ty) = ctx.raisable_ty {
if let Some(ty) = self.ty {
self.is_context_specific |= ty.is_assignable_to(db, raisable_ty);
}
}
if let Some(ty) = self.ty {
self.is_type_check_only = ty.is_type_check_only(db);
// Tags completions with context-specific if they are
// known to be usable in a `raise` context and we have
// determined a raisable type `raisable_ty`.
//
// It's possible that some completions are usable in a `raise`
// but aren't marked here. That is, false negatives are
// possible but false positives are not.
if let Some(raisable_ty) = ctx.raisable_ty {
self.is_context_specific |= ty.is_assignable_to(db, raisable_ty);
}
if ctx.is_in_class_def {
self.is_context_specific |= ty.is_class_literal()
|| matches!(
ty,
Type::SpecialForm(
SpecialFormType::Protocol
| SpecialFormType::Generic
| SpecialFormType::TypedDict
| SpecialFormType::NamedTuple
)
);
}
}
let kind = self
.kind
@@ -596,6 +606,7 @@ impl<'m> Context<'m> {
)
}),
is_raising_exception,
is_in_class_def: self.cursor.is_in_class_def(),
valid_keywords: self.cursor.valid_keywords(),
}
}
@@ -628,6 +639,8 @@ struct ContextCursor<'m> {
range: TextRange,
/// The tokens that appear before the cursor.
tokens_before: &'m [Token],
/// The covering node based on `parsed` and `range`.
covering_node: CoveringNode<'m>,
}
impl<'m> ContextCursor<'m> {
@@ -639,13 +652,16 @@ impl<'m> ContextCursor<'m> {
) -> ContextCursor<'m> {
let tokens_before = tokens_start_before(parsed.tokens(), offset);
let Some(range) = ContextCursor::find_typed_text_range(tokens_before, offset) else {
let range = TextRange::empty(offset);
let covering_node = covering_node(parsed.syntax().into(), range);
return ContextCursor {
parsed,
source,
typed: None,
offset,
range: TextRange::empty(offset),
range,
tokens_before,
covering_node,
};
};
@@ -654,6 +670,8 @@ impl<'m> ContextCursor<'m> {
!text.is_empty(),
"expected typed text, when found, to be non-empty"
);
let covering_node = covering_node(parsed.syntax().into(), range);
ContextCursor {
parsed,
source,
@@ -661,6 +679,7 @@ impl<'m> ContextCursor<'m> {
offset,
range,
tokens_before,
covering_node,
}
}
@@ -761,8 +780,7 @@ impl<'m> ContextCursor<'m> {
/// Returns true when the cursor sits on a binding statement.
/// E.g. naming a parameter, type parameter, or `for` <name>).
fn is_in_variable_binding(&self) -> bool {
let covering = self.covering_node(self.range);
covering.ancestors().any(|node| match node {
self.covering_node.ancestors().any(|node| match node {
ast::AnyNodeRef::Parameter(param) => param.name.range.contains_range(self.range),
ast::AnyNodeRef::TypeParamTypeVar(type_param) => {
type_param.name.range.contains_range(self.range)
@@ -816,17 +834,35 @@ impl<'m> ContextCursor<'m> {
false
}
/// Returns true when the curser is within the
/// arguments node of a class definition.
///
/// E.g. `class Foo(Bar<CURSOR>)`
fn is_in_class_def(&self) -> bool {
for node in self.covering_node.ancestors() {
if let ast::AnyNodeRef::StmtClassDef(class_def) = node {
return class_def
.arguments
.as_ref()
.is_some_and(|args| args.range.contains_range(self.range));
}
if node.is_statement() {
return false;
}
}
false
}
/// Returns a set of keywords that are valid at
/// the current cursor position.
///
/// Returns None if no context-based exclusions can
/// be identified. Meaning that all keywords are valid.
fn valid_keywords(&self) -> Option<FxHashSet<&'static str>> {
let covering_node = self.covering_node(self.range);
// Check if the cursor is within the naming
// part of a decorator node.
if covering_node
if self
.covering_node
.ancestors()
// We bail if we're specifying arguments as we don't
// want to suppress suggestions there.
@@ -837,7 +873,7 @@ impl<'m> ContextCursor<'m> {
{
return Some(FxHashSet::from_iter(["lambda"]));
}
covering_node.ancestors().find_map(|node| {
self.covering_node.ancestors().find_map(|node| {
self.is_in_for_statement_iterable(node)
.then(|| FxHashSet::from_iter(["yield", "lambda", "await"]))
.or_else(|| {
@@ -1008,6 +1044,8 @@ struct CollectionContext<'db> {
raisable_ty: Option<Type<'db>>,
/// Whether we're in a `raise <EXPR>` context or not.
is_raising_exception: bool,
/// Whether we're in a class definition context or not.
is_in_class_def: bool,
/// When set, the context dictates that only *these* keywords
/// are acceptable in this context.
valid_keywords: Option<FxHashSet<&'static str>>,
@@ -1271,7 +1309,7 @@ fn add_argument_completions<'db>(
cursor: &ContextCursor<'_>,
completions: &mut Completions<'db>,
) {
for node in cursor.covering_node(cursor.range).ancestors() {
for node in cursor.covering_node.ancestors() {
match node {
ast::AnyNodeRef::ExprCall(call) => {
if call.arguments.range().contains_range(cursor.range) {
@@ -1347,7 +1385,7 @@ fn add_function_arg_completions<'db>(
) {
debug_assert!(
cursor
.covering_node(cursor.range)
.covering_node
.ancestors()
.take_while(|node| !node.is_statement())
.any(|node| node.is_arguments()),
@@ -1394,9 +1432,8 @@ fn add_function_arg_completions<'db>(
/// If the parent node is not an arguments node, the return value
/// is an empty Vec.
fn detect_set_function_args<'m>(cursor: &ContextCursor<'m>) -> FxHashSet<&'m str> {
let range = TextRange::empty(cursor.offset);
cursor
.covering_node(range)
.covering_node
.parent()
.and_then(|node| match node {
ast::AnyNodeRef::Arguments(args) => Some(args),
@@ -1679,13 +1716,9 @@ impl<'t> CompletionTargetTokens<'t> {
let node = cursor.covering_node(token.range()).node();
Some(CompletionTargetAst::Scoped(ScopedTarget { node }))
}
CompletionTargetTokens::Unknown => {
let range = TextRange::empty(cursor.offset);
let covering_node = cursor.covering_node(range);
Some(CompletionTargetAst::Scoped(ScopedTarget {
node: covering_node.node(),
}))
}
CompletionTargetTokens::Unknown => Some(CompletionTargetAst::Scoped(ScopedTarget {
node: cursor.covering_node.node(),
})),
}
}
}
@@ -3344,9 +3377,9 @@ class Foo(<CURSOR>):
);
assert_snapshot!(builder.skip_keywords().skip_builtins().build().snapshot(), @"
metaclass=
Bar
Foo
metaclass=
");
}
@@ -3362,9 +3395,9 @@ class Bar: ...
);
assert_snapshot!(builder.skip_keywords().skip_builtins().build().snapshot(), @"
metaclass=
Bar
Foo
metaclass=
");
}
@@ -3380,9 +3413,9 @@ class Bar: ...
);
assert_snapshot!(builder.skip_keywords().skip_builtins().build().snapshot(), @"
metaclass=
Bar
Foo
metaclass=
");
}
@@ -3396,9 +3429,9 @@ class Foo(<CURSOR>",
);
assert_snapshot!(builder.skip_keywords().skip_builtins().build().snapshot(), @"
metaclass=
Bar
Foo
metaclass=
");
}