[ty] Don't suggest keyword statements when only expressions are valid

There are cases where the python grammar enforces expressions
after certain statements. In such cases we want to suppress
irrelevant keywords from the auto-complete suggestions.

E.g. `with a<CURSOR>`, suggesting `raise` here never makes sense
because it is not valid by the grammar.
This commit is contained in:
Rasmus Nygren 2025-12-15 21:53:06 +01:00
parent 5c942119f8
commit ea4eca38f1
1 changed files with 377 additions and 16 deletions

View File

@ -490,15 +490,13 @@ pub fn completion<'db>(
!ty.is_notimplemented(db)
});
}
if is_specifying_for_statement_iterable(&parsed, offset, typed.as_deref()) {
// Remove all keywords that doesn't make sense given the context,
// even if they are syntatically valid, e.g. `None`.
if let Some(keywords) = context_valid_keywords(&parsed, offset, typed.as_deref(), tokens) {
completions.retain(|item| {
let Some(kind) = item.kind else { return true };
if kind != CompletionKind::Keyword {
return true;
}
matches!(item.name.as_str(), "await" | "lambda" | "yield")
keywords.contains(item.name.as_str())
});
}
completions.into_completions()
@ -1673,21 +1671,184 @@ fn is_raising_exception(tokens: &[Token]) -> bool {
false
}
/// Returns true when the cursor is after the `in` keyword in a
/// `for x in <CURSOR>` statement.
fn is_specifying_for_statement_iterable(
// Returns a set of keywords that are valid at
// the current cursor position.
//
// Returns None if no context-based exclusions can
// be identified. Meaning that all keywords are valid.
fn context_valid_keywords(
parsed: &ParsedModuleRef,
offset: TextSize,
typed: Option<&str>,
) -> bool {
tokens: &[Token],
) -> Option<FxHashSet<&'static str>> {
let range = typed_text_range(typed, offset);
if is_in_decorator_expression(tokens) {
return Some(["lambda"].into_iter().collect());
}
let covering = covering_node(parsed.syntax().into(), range);
covering.parent().is_some_and(|node| {
matches!(
node, ast::AnyNodeRef::StmtFor(stmt_for) if stmt_for.iter.range().contains_range(range)
covering.ancestors().find_map(|node| {
is_in_for_statement_iterable(node, offset, typed)
.then_some(["yield", "lambda", "await"].into_iter().collect())
.or_else(|| {
is_expecting_expression(node, range).then_some(
[
"await", "lambda", "yield", "for", "if", "else", "and", "or", "not", "in",
"is", "True", "False", "None",
]
.into_iter()
.collect(),
)
})
})
}
/// Returns true if the cursor is after an `@` token
/// that corresponds to a decorator declaration
///
/// `@` can also be used as an operator, this distinguishes
/// between the two usages and only looks for the decorator case.
fn is_in_decorator_expression(tokens: &[Token]) -> bool {
const LIMIT: usize = 10;
enum S {
Start,
At,
}
let mut state = S::Start;
for token in tokens.iter().rev().take(LIMIT) {
// Matches lines that starts with `@` as
// heuristic for decorators. When decorators
// are constructed they are often not identified
// as decorators yet by the AST, hence we use
// token matching for the decorator case.
//
// As the grammar also allows @ to be used as an operator,
// we want to distinguish between whether it looks
// like it's being used as an operator or a
// decorator.
//
// TODO: This doesn't handle decorators
// that start at the very top of the file.
state = match (state, token.kind()) {
(S::Start, TokenKind::Newline | TokenKind::Indent | TokenKind::Dedent) => break,
(S::Start, TokenKind::At) => S::At,
(S::Start, _) => S::Start,
(
S::At,
TokenKind::Newline
| TokenKind::NonLogicalNewline
| TokenKind::Indent
| TokenKind::Dedent,
) => {
return true;
}
_ => break,
}
}
false
}
/// Returns true when only an expression is valid after the cursor
/// according to the python grammar.
fn is_expecting_expression(node: ast::AnyNodeRef, range: TextRange) -> bool {
let contains = |expr: &ast::Expr| expr.range().contains_range(range);
match node {
// All checks here are intended to find cases where
// the python grammar disallows anything but expressions.
// if_stmt := 'if' named_expression ':' block elif_stmt
ast::AnyNodeRef::StmtIf(stmt) => {
contains(&stmt.test)
|| stmt
.elif_else_clauses
.iter()
.any(|clause| clause.test.as_ref().is_some_and(contains))
}
// while_stmt := 'while' named_expression ':' block [else_block]
ast::AnyNodeRef::StmtWhile(stmt) => contains(&stmt.test),
// for_stmt := 'for' star_targets 'in' ~ star_expressions ':' [TYPE_COMMENT] block [else_block]
ast::AnyNodeRef::StmtFor(stmt) => contains(&stmt.iter),
// with_item := expression
ast::AnyNodeRef::StmtWith(stmt) => {
stmt.items.iter().any(|item| contains(&item.context_expr))
}
// match_stmt := "match" subject_expr ':' NEWLINE INDENT case_block+ DEDENT
ast::AnyNodeRef::StmtMatch(stmt) => contains(&stmt.subject),
// case_guard := 'if' named_expression
ast::AnyNodeRef::MatchCase(case) => case.guard.as_deref().is_some_and(contains),
// assert_stmt := 'assert' expression [',' expression ]
ast::AnyNodeRef::StmtAssert(stmt) => {
contains(&stmt.test) || stmt.msg.as_deref().is_some_and(contains)
}
// raise_stmt := 'raise' expression ['from' expression ]
ast::AnyNodeRef::StmtRaise(stmt) => {
stmt.exc.as_deref().is_some_and(contains) || stmt.cause.as_deref().is_some_and(contains)
}
// return_stmt := 'return' [star_expressions]
ast::AnyNodeRef::StmtReturn(stmt) => stmt.value.as_deref().is_some_and(contains),
ast::AnyNodeRef::StmtAssign(stmt) => contains(&stmt.value),
ast::AnyNodeRef::StmtAugAssign(stmt) => contains(&stmt.value),
ast::AnyNodeRef::StmtAnnAssign(stmt) => {
contains(&stmt.annotation) || stmt.value.as_deref().is_some_and(contains)
}
// type_alias := "type" NAME [type_params] '=' expression
ast::AnyNodeRef::StmtTypeAlias(stmt) => contains(&stmt.value),
// except_clause := 'except' expression ':' block
ast::AnyNodeRef::ExceptHandlerExceptHandler(handler) => {
handler.type_.as_deref().is_some_and(contains)
}
ast::AnyNodeRef::ExprList(expr) => expr.elts.iter().any(contains),
ast::AnyNodeRef::ExprTuple(expr) => expr.elts.iter().any(contains),
ast::AnyNodeRef::ExprSet(expr) => expr.elts.iter().any(contains),
ast::AnyNodeRef::ExprDict(expr) => expr
.items
.iter()
.any(|item| item.range().contains_range(range)),
// arguments := (positional arguments | keyword arguments | "*" expression | "**" expression)*
ast::AnyNodeRef::Arguments(args) => {
args.args.iter().any(contains) || args.keywords.iter().any(|kw| contains(&kw.value))
}
// with_item := expression
ast::AnyNodeRef::WithItem(item) => contains(&item.context_expr),
// lambdef := 'lambda' [lambda_params] ':' expression
ast::AnyNodeRef::ExprLambda(expr) => contains(&expr.body),
ast::AnyNodeRef::Parameter(param) => param.annotation.as_deref().is_some_and(contains),
ast::AnyNodeRef::ParameterWithDefault(param) => {
param.default.as_deref().is_some_and(contains)
|| param.parameter.annotation.as_deref().is_some_and(contains)
}
_ => false,
}
}
/// Returns true when the cursor is after the `in` keyword in a
/// `for x in <CURSOR>` statement.
fn is_in_for_statement_iterable(node: AnyNodeRef, offset: TextSize, typed: Option<&str>) -> bool {
let range = typed_text_range(typed, offset);
match node {
ast::AnyNodeRef::StmtFor(stmt_for) => stmt_for.iter.range().contains_range(range),
// Detects `for x in <CURSOR>` statements inside comprehensions.
// E.g. `[for x in <CURSOR>]`
ast::AnyNodeRef::Comprehension(comprehension) => {
comprehension.target.range().contains_range(range)
|| comprehension.iter.range().contains_range(range)
|| comprehension
.ifs
.iter()
.any(|expr| expr.range().contains_range(range))
}
_ => false,
}
}
/// Returns the `TextRange` of the `typed` text.
@ -5902,7 +6063,7 @@ def foo(param: s<CURSOR>)
}
#[test]
fn no_statement_keywords_in_for_statement_simple1() {
fn iterable_only_keywords_in_for_statement_simple1() {
completion_test_builder(
"\
for x in a<CURSOR>
@ -5916,7 +6077,7 @@ for x in a<CURSOR>
}
#[test]
fn no_statement_keywords_in_for_statement_simple2() {
fn iterable_only_keywords_in_for_statement_simple2() {
completion_test_builder(
"\
for x, y, _ in a<CURSOR>
@ -5930,7 +6091,7 @@ for x, y, _ in a<CURSOR>
}
#[test]
fn no_statement_keywords_in_for_statement_simple3() {
fn iterable_only_keywords_in_for_statement_simple3() {
completion_test_builder(
"\
for i, (x, y, z) in a<CURSOR>
@ -5944,7 +6105,7 @@ for i, (x, y, z) in a<CURSOR>
}
#[test]
fn no_statement_keywords_in_for_statement_complex() {
fn iterable_only_keywords_in_for_statement_complex() {
completion_test_builder(
"\
for i, (obj.x, (a[0], b['k']), _), *rest in a<CURSOR>
@ -5957,6 +6118,206 @@ for i, (obj.x, (a[0], b['k']), _), *rest in a<CURSOR>
.not_contains("False");
}
#[test]
fn no_statement_keywords_in_if_condition() {
completion_test_builder(
"\
if a<CURSOR>:
pass
",
)
.build()
.contains("lambda")
.contains("await")
.contains("False")
.not_contains("raise");
}
// FIXME: Should not contain False as we're expecting an iterable
// or a container.
#[test]
fn no_statement_keywords_in_if_x_in() {
completion_test_builder(
"\
if x in a<CURSOR>:
pass
",
)
.build()
.contains("lambda")
.contains("await")
.contains("False")
.not_contains("raise");
}
#[test]
fn no_statement_keywords_in_return_value() {
completion_test_builder(
"\
def func():
return a<CURSOR>
",
)
.build()
.contains("lambda")
.contains("await")
.contains("False")
.not_contains("raise");
}
#[test]
fn no_statement_keywords_in_match() {
completion_test_builder(
"\
match a<CURSOR>:
case _:
pass
",
)
.build()
.contains("lambda")
.contains("await")
.contains("False")
.not_contains("raise");
}
// FIXME: Suggesting False doesn't make much sense here.
#[test]
fn no_statement_keywords_in_with() {
completion_test_builder(
"\
with a<CURSOR>
",
)
.build()
.contains("lambda")
.contains("await")
.contains("False")
.not_contains("raise");
}
#[test]
fn no_statement_keywords_in_while() {
completion_test_builder(
"\
while a<CURSOR>
",
)
.build()
.contains("lambda")
.contains("await")
.contains("False")
.not_contains("raise");
}
#[test]
fn no_statement_keywords_in_lambda() {
completion_test_builder(
"\
lambda foo: a<CURSOR>
",
)
.build()
.contains("lambda")
.contains("await")
.contains("False")
.not_contains("raise");
}
#[test]
fn only_lambda_keyword_in_decorator() {
// The decorator check currently doesn't work at
// the very start of the file, therefore the code
// in this test explicitly has a variable
// defined before it
completion_test_builder(
"\
foo = 123
@a<CURSOR>
def func():
...
",
)
.build()
.contains("lambda")
.not_contains("await")
.not_contains("raise")
.not_contains("False");
}
#[test]
fn statement_keywords_in_if_body() {
completion_test_builder(
"\
foo = 123
if foo:
a<CURSOR>
",
)
.build()
.contains("lambda")
.contains("await")
.contains("raise")
.contains("False");
}
#[test]
fn no_statement_keywords_in_tuple() {
completion_test_builder(
"\
(a<CURSOR>,)
",
)
.build()
.contains("lambda")
.contains("await")
.contains("False")
.not_contains("raise");
}
#[test]
fn no_statement_keywords_in_set_literal() {
completion_test_builder(
"\
{a<CURSOR>}
",
)
.build()
.contains("lambda")
.contains("await")
.contains("False")
.not_contains("raise");
}
#[test]
fn iterable_only_in_comprehension() {
completion_test_builder(
"\
[x for x in a<CURSOR>]
",
)
.build()
.contains("lambda")
.contains("await")
.not_contains("False")
.not_contains("raise");
}
#[test]
fn no_statement_keywords_in_dict_literal() {
completion_test_builder(
"\
{a<CURSOR>: 1}
",
)
.build()
.contains("lambda")
.contains("await")
.contains("False")
.not_contains("raise");
}
#[test]
fn favour_symbols_currently_imported() {
let snapshot = CursorTest::builder()