From ea4eca38f19a0bfa06b96030798897bcb5b502a5 Mon Sep 17 00:00:00 2001 From: Rasmus Nygren Date: Mon, 15 Dec 2025 21:53:06 +0100 Subject: [PATCH] [ty] Don't suggest keyword statements when only expressions are valid There are cases where the python grammar enforces expressions after certain statements. In such cases we want to suppress irrelevant keywords from the auto-complete suggestions. E.g. `with a`, suggesting `raise` here never makes sense because it is not valid by the grammar. --- crates/ty_ide/src/completion.rs | 393 ++++++++++++++++++++++++++++++-- 1 file changed, 377 insertions(+), 16 deletions(-) diff --git a/crates/ty_ide/src/completion.rs b/crates/ty_ide/src/completion.rs index be1e0bfa4f..0359d52979 100644 --- a/crates/ty_ide/src/completion.rs +++ b/crates/ty_ide/src/completion.rs @@ -490,15 +490,13 @@ pub fn completion<'db>( !ty.is_notimplemented(db) }); } - if is_specifying_for_statement_iterable(&parsed, offset, typed.as_deref()) { - // Remove all keywords that doesn't make sense given the context, - // even if they are syntatically valid, e.g. `None`. + if let Some(keywords) = context_valid_keywords(&parsed, offset, typed.as_deref(), tokens) { completions.retain(|item| { let Some(kind) = item.kind else { return true }; if kind != CompletionKind::Keyword { return true; } - matches!(item.name.as_str(), "await" | "lambda" | "yield") + keywords.contains(item.name.as_str()) }); } completions.into_completions() @@ -1673,23 +1671,186 @@ fn is_raising_exception(tokens: &[Token]) -> bool { false } -/// Returns true when the cursor is after the `in` keyword in a -/// `for x in ` statement. -fn is_specifying_for_statement_iterable( +// Returns a set of keywords that are valid at +// the current cursor position. +// +// Returns None if no context-based exclusions can +// be identified. Meaning that all keywords are valid. +fn context_valid_keywords( parsed: &ParsedModuleRef, offset: TextSize, typed: Option<&str>, -) -> bool { + tokens: &[Token], +) -> Option> { let range = typed_text_range(typed, offset); + if is_in_decorator_expression(tokens) { + return Some(["lambda"].into_iter().collect()); + } let covering = covering_node(parsed.syntax().into(), range); - covering.parent().is_some_and(|node| { - matches!( - node, ast::AnyNodeRef::StmtFor(stmt_for) if stmt_for.iter.range().contains_range(range) - ) + + covering.ancestors().find_map(|node| { + is_in_for_statement_iterable(node, offset, typed) + .then_some(["yield", "lambda", "await"].into_iter().collect()) + .or_else(|| { + is_expecting_expression(node, range).then_some( + [ + "await", "lambda", "yield", "for", "if", "else", "and", "or", "not", "in", + "is", "True", "False", "None", + ] + .into_iter() + .collect(), + ) + }) }) } +/// Returns true if the cursor is after an `@` token +/// that corresponds to a decorator declaration +/// +/// `@` can also be used as an operator, this distinguishes +/// between the two usages and only looks for the decorator case. +fn is_in_decorator_expression(tokens: &[Token]) -> bool { + const LIMIT: usize = 10; + enum S { + Start, + At, + } + let mut state = S::Start; + for token in tokens.iter().rev().take(LIMIT) { + // Matches lines that starts with `@` as + // heuristic for decorators. When decorators + // are constructed they are often not identified + // as decorators yet by the AST, hence we use + // token matching for the decorator case. + // + // As the grammar also allows @ to be used as an operator, + // we want to distinguish between whether it looks + // like it's being used as an operator or a + // decorator. + // + // TODO: This doesn't handle decorators + // that start at the very top of the file. + state = match (state, token.kind()) { + (S::Start, TokenKind::Newline | TokenKind::Indent | TokenKind::Dedent) => break, + (S::Start, TokenKind::At) => S::At, + (S::Start, _) => S::Start, + ( + S::At, + TokenKind::Newline + | TokenKind::NonLogicalNewline + | TokenKind::Indent + | TokenKind::Dedent, + ) => { + return true; + } + _ => break, + } + } + false +} + +/// Returns true when only an expression is valid after the cursor +/// according to the python grammar. +fn is_expecting_expression(node: ast::AnyNodeRef, range: TextRange) -> bool { + let contains = |expr: &ast::Expr| expr.range().contains_range(range); + + match node { + // All checks here are intended to find cases where + // the python grammar disallows anything but expressions. + + // if_stmt := 'if' named_expression ':' block elif_stmt + ast::AnyNodeRef::StmtIf(stmt) => { + contains(&stmt.test) + || stmt + .elif_else_clauses + .iter() + .any(|clause| clause.test.as_ref().is_some_and(contains)) + } + // while_stmt := 'while' named_expression ':' block [else_block] + ast::AnyNodeRef::StmtWhile(stmt) => contains(&stmt.test), + + // for_stmt := 'for' star_targets 'in' ~ star_expressions ':' [TYPE_COMMENT] block [else_block] + ast::AnyNodeRef::StmtFor(stmt) => contains(&stmt.iter), + // with_item := expression + ast::AnyNodeRef::StmtWith(stmt) => { + stmt.items.iter().any(|item| contains(&item.context_expr)) + } + // match_stmt := "match" subject_expr ':' NEWLINE INDENT case_block+ DEDENT + ast::AnyNodeRef::StmtMatch(stmt) => contains(&stmt.subject), + // case_guard := 'if' named_expression + ast::AnyNodeRef::MatchCase(case) => case.guard.as_deref().is_some_and(contains), + // assert_stmt := 'assert' expression [',' expression ] + ast::AnyNodeRef::StmtAssert(stmt) => { + contains(&stmt.test) || stmt.msg.as_deref().is_some_and(contains) + } + // raise_stmt := 'raise' expression ['from' expression ] + ast::AnyNodeRef::StmtRaise(stmt) => { + stmt.exc.as_deref().is_some_and(contains) || stmt.cause.as_deref().is_some_and(contains) + } + // return_stmt := 'return' [star_expressions] + ast::AnyNodeRef::StmtReturn(stmt) => stmt.value.as_deref().is_some_and(contains), + + ast::AnyNodeRef::StmtAssign(stmt) => contains(&stmt.value), + ast::AnyNodeRef::StmtAugAssign(stmt) => contains(&stmt.value), + ast::AnyNodeRef::StmtAnnAssign(stmt) => { + contains(&stmt.annotation) || stmt.value.as_deref().is_some_and(contains) + } + // type_alias := "type" NAME [type_params] '=' expression + ast::AnyNodeRef::StmtTypeAlias(stmt) => contains(&stmt.value), + // except_clause := 'except' expression ':' block + ast::AnyNodeRef::ExceptHandlerExceptHandler(handler) => { + handler.type_.as_deref().is_some_and(contains) + } + + ast::AnyNodeRef::ExprList(expr) => expr.elts.iter().any(contains), + ast::AnyNodeRef::ExprTuple(expr) => expr.elts.iter().any(contains), + ast::AnyNodeRef::ExprSet(expr) => expr.elts.iter().any(contains), + ast::AnyNodeRef::ExprDict(expr) => expr + .items + .iter() + .any(|item| item.range().contains_range(range)), + + // arguments := (positional arguments | keyword arguments | "*" expression | "**" expression)* + ast::AnyNodeRef::Arguments(args) => { + args.args.iter().any(contains) || args.keywords.iter().any(|kw| contains(&kw.value)) + } + // with_item := expression + ast::AnyNodeRef::WithItem(item) => contains(&item.context_expr), + + // lambdef := 'lambda' [lambda_params] ':' expression + ast::AnyNodeRef::ExprLambda(expr) => contains(&expr.body), + + ast::AnyNodeRef::Parameter(param) => param.annotation.as_deref().is_some_and(contains), + ast::AnyNodeRef::ParameterWithDefault(param) => { + param.default.as_deref().is_some_and(contains) + || param.parameter.annotation.as_deref().is_some_and(contains) + } + _ => false, + } +} + +/// Returns true when the cursor is after the `in` keyword in a +/// `for x in ` statement. +fn is_in_for_statement_iterable(node: AnyNodeRef, offset: TextSize, typed: Option<&str>) -> bool { + let range = typed_text_range(typed, offset); + + match node { + ast::AnyNodeRef::StmtFor(stmt_for) => stmt_for.iter.range().contains_range(range), + // Detects `for x in ` statements inside comprehensions. + // E.g. `[for x in ]` + ast::AnyNodeRef::Comprehension(comprehension) => { + comprehension.target.range().contains_range(range) + || comprehension.iter.range().contains_range(range) + || comprehension + .ifs + .iter() + .any(|expr| expr.range().contains_range(range)) + } + _ => false, + } +} + /// Returns the `TextRange` of the `typed` text. /// /// `typed` should be the text immediately before the @@ -5902,7 +6063,7 @@ def foo(param: s) } #[test] - fn no_statement_keywords_in_for_statement_simple1() { + fn iterable_only_keywords_in_for_statement_simple1() { completion_test_builder( "\ for x in a @@ -5916,7 +6077,7 @@ for x in a } #[test] - fn no_statement_keywords_in_for_statement_simple2() { + fn iterable_only_keywords_in_for_statement_simple2() { completion_test_builder( "\ for x, y, _ in a @@ -5930,7 +6091,7 @@ for x, y, _ in a } #[test] - fn no_statement_keywords_in_for_statement_simple3() { + fn iterable_only_keywords_in_for_statement_simple3() { completion_test_builder( "\ for i, (x, y, z) in a @@ -5944,7 +6105,7 @@ for i, (x, y, z) in a } #[test] - fn no_statement_keywords_in_for_statement_complex() { + fn iterable_only_keywords_in_for_statement_complex() { completion_test_builder( "\ for i, (obj.x, (a[0], b['k']), _), *rest in a @@ -5957,6 +6118,206 @@ for i, (obj.x, (a[0], b['k']), _), *rest in a .not_contains("False"); } + #[test] + fn no_statement_keywords_in_if_condition() { + completion_test_builder( + "\ +if a: + pass +", + ) + .build() + .contains("lambda") + .contains("await") + .contains("False") + .not_contains("raise"); + } + + // FIXME: Should not contain False as we're expecting an iterable + // or a container. + #[test] + fn no_statement_keywords_in_if_x_in() { + completion_test_builder( + "\ +if x in a: + pass +", + ) + .build() + .contains("lambda") + .contains("await") + .contains("False") + .not_contains("raise"); + } + + #[test] + fn no_statement_keywords_in_return_value() { + completion_test_builder( + "\ +def func(): + return a +", + ) + .build() + .contains("lambda") + .contains("await") + .contains("False") + .not_contains("raise"); + } + + #[test] + fn no_statement_keywords_in_match() { + completion_test_builder( + "\ +match a: + case _: + pass +", + ) + .build() + .contains("lambda") + .contains("await") + .contains("False") + .not_contains("raise"); + } + + // FIXME: Suggesting False doesn't make much sense here. + #[test] + fn no_statement_keywords_in_with() { + completion_test_builder( + "\ +with a +", + ) + .build() + .contains("lambda") + .contains("await") + .contains("False") + .not_contains("raise"); + } + + #[test] + fn no_statement_keywords_in_while() { + completion_test_builder( + "\ +while a +", + ) + .build() + .contains("lambda") + .contains("await") + .contains("False") + .not_contains("raise"); + } + + #[test] + fn no_statement_keywords_in_lambda() { + completion_test_builder( + "\ +lambda foo: a +", + ) + .build() + .contains("lambda") + .contains("await") + .contains("False") + .not_contains("raise"); + } + + #[test] + fn only_lambda_keyword_in_decorator() { + // The decorator check currently doesn't work at + // the very start of the file, therefore the code + // in this test explicitly has a variable + // defined before it + completion_test_builder( + "\ +foo = 123 + +@a +def func(): + ... +", + ) + .build() + .contains("lambda") + .not_contains("await") + .not_contains("raise") + .not_contains("False"); + } + + #[test] + fn statement_keywords_in_if_body() { + completion_test_builder( + "\ +foo = 123 +if foo: + a +", + ) + .build() + .contains("lambda") + .contains("await") + .contains("raise") + .contains("False"); + } + + #[test] + fn no_statement_keywords_in_tuple() { + completion_test_builder( + "\ +(a,) +", + ) + .build() + .contains("lambda") + .contains("await") + .contains("False") + .not_contains("raise"); + } + + #[test] + fn no_statement_keywords_in_set_literal() { + completion_test_builder( + "\ +{a} +", + ) + .build() + .contains("lambda") + .contains("await") + .contains("False") + .not_contains("raise"); + } + + #[test] + fn iterable_only_in_comprehension() { + completion_test_builder( + "\ +[x for x in a] +", + ) + .build() + .contains("lambda") + .contains("await") + .not_contains("False") + .not_contains("raise"); + } + + #[test] + fn no_statement_keywords_in_dict_literal() { + completion_test_builder( + "\ +{a: 1} +", + ) + .build() + .contains("lambda") + .contains("await") + .contains("False") + .not_contains("raise"); + } + #[test] fn favour_symbols_currently_imported() { let snapshot = CursorTest::builder()