diff --git a/crates/ty_ide/src/completion.rs b/crates/ty_ide/src/completion.rs index be1e0bfa4f..0359d52979 100644 --- a/crates/ty_ide/src/completion.rs +++ b/crates/ty_ide/src/completion.rs @@ -490,15 +490,13 @@ pub fn completion<'db>( !ty.is_notimplemented(db) }); } - if is_specifying_for_statement_iterable(&parsed, offset, typed.as_deref()) { - // Remove all keywords that doesn't make sense given the context, - // even if they are syntatically valid, e.g. `None`. + if let Some(keywords) = context_valid_keywords(&parsed, offset, typed.as_deref(), tokens) { completions.retain(|item| { let Some(kind) = item.kind else { return true }; if kind != CompletionKind::Keyword { return true; } - matches!(item.name.as_str(), "await" | "lambda" | "yield") + keywords.contains(item.name.as_str()) }); } completions.into_completions() @@ -1673,23 +1671,186 @@ fn is_raising_exception(tokens: &[Token]) -> bool { false } -/// Returns true when the cursor is after the `in` keyword in a -/// `for x in ` statement. -fn is_specifying_for_statement_iterable( +// Returns a set of keywords that are valid at +// the current cursor position. +// +// Returns None if no context-based exclusions can +// be identified. Meaning that all keywords are valid. +fn context_valid_keywords( parsed: &ParsedModuleRef, offset: TextSize, typed: Option<&str>, -) -> bool { + tokens: &[Token], +) -> Option> { let range = typed_text_range(typed, offset); + if is_in_decorator_expression(tokens) { + return Some(["lambda"].into_iter().collect()); + } let covering = covering_node(parsed.syntax().into(), range); - covering.parent().is_some_and(|node| { - matches!( - node, ast::AnyNodeRef::StmtFor(stmt_for) if stmt_for.iter.range().contains_range(range) - ) + + covering.ancestors().find_map(|node| { + is_in_for_statement_iterable(node, offset, typed) + .then_some(["yield", "lambda", "await"].into_iter().collect()) + .or_else(|| { + is_expecting_expression(node, range).then_some( + [ + "await", "lambda", "yield", "for", "if", "else", "and", "or", "not", "in", + "is", "True", "False", "None", + ] + .into_iter() + .collect(), + ) + }) }) } +/// Returns true if the cursor is after an `@` token +/// that corresponds to a decorator declaration +/// +/// `@` can also be used as an operator, this distinguishes +/// between the two usages and only looks for the decorator case. +fn is_in_decorator_expression(tokens: &[Token]) -> bool { + const LIMIT: usize = 10; + enum S { + Start, + At, + } + let mut state = S::Start; + for token in tokens.iter().rev().take(LIMIT) { + // Matches lines that starts with `@` as + // heuristic for decorators. When decorators + // are constructed they are often not identified + // as decorators yet by the AST, hence we use + // token matching for the decorator case. + // + // As the grammar also allows @ to be used as an operator, + // we want to distinguish between whether it looks + // like it's being used as an operator or a + // decorator. + // + // TODO: This doesn't handle decorators + // that start at the very top of the file. + state = match (state, token.kind()) { + (S::Start, TokenKind::Newline | TokenKind::Indent | TokenKind::Dedent) => break, + (S::Start, TokenKind::At) => S::At, + (S::Start, _) => S::Start, + ( + S::At, + TokenKind::Newline + | TokenKind::NonLogicalNewline + | TokenKind::Indent + | TokenKind::Dedent, + ) => { + return true; + } + _ => break, + } + } + false +} + +/// Returns true when only an expression is valid after the cursor +/// according to the python grammar. +fn is_expecting_expression(node: ast::AnyNodeRef, range: TextRange) -> bool { + let contains = |expr: &ast::Expr| expr.range().contains_range(range); + + match node { + // All checks here are intended to find cases where + // the python grammar disallows anything but expressions. + + // if_stmt := 'if' named_expression ':' block elif_stmt + ast::AnyNodeRef::StmtIf(stmt) => { + contains(&stmt.test) + || stmt + .elif_else_clauses + .iter() + .any(|clause| clause.test.as_ref().is_some_and(contains)) + } + // while_stmt := 'while' named_expression ':' block [else_block] + ast::AnyNodeRef::StmtWhile(stmt) => contains(&stmt.test), + + // for_stmt := 'for' star_targets 'in' ~ star_expressions ':' [TYPE_COMMENT] block [else_block] + ast::AnyNodeRef::StmtFor(stmt) => contains(&stmt.iter), + // with_item := expression + ast::AnyNodeRef::StmtWith(stmt) => { + stmt.items.iter().any(|item| contains(&item.context_expr)) + } + // match_stmt := "match" subject_expr ':' NEWLINE INDENT case_block+ DEDENT + ast::AnyNodeRef::StmtMatch(stmt) => contains(&stmt.subject), + // case_guard := 'if' named_expression + ast::AnyNodeRef::MatchCase(case) => case.guard.as_deref().is_some_and(contains), + // assert_stmt := 'assert' expression [',' expression ] + ast::AnyNodeRef::StmtAssert(stmt) => { + contains(&stmt.test) || stmt.msg.as_deref().is_some_and(contains) + } + // raise_stmt := 'raise' expression ['from' expression ] + ast::AnyNodeRef::StmtRaise(stmt) => { + stmt.exc.as_deref().is_some_and(contains) || stmt.cause.as_deref().is_some_and(contains) + } + // return_stmt := 'return' [star_expressions] + ast::AnyNodeRef::StmtReturn(stmt) => stmt.value.as_deref().is_some_and(contains), + + ast::AnyNodeRef::StmtAssign(stmt) => contains(&stmt.value), + ast::AnyNodeRef::StmtAugAssign(stmt) => contains(&stmt.value), + ast::AnyNodeRef::StmtAnnAssign(stmt) => { + contains(&stmt.annotation) || stmt.value.as_deref().is_some_and(contains) + } + // type_alias := "type" NAME [type_params] '=' expression + ast::AnyNodeRef::StmtTypeAlias(stmt) => contains(&stmt.value), + // except_clause := 'except' expression ':' block + ast::AnyNodeRef::ExceptHandlerExceptHandler(handler) => { + handler.type_.as_deref().is_some_and(contains) + } + + ast::AnyNodeRef::ExprList(expr) => expr.elts.iter().any(contains), + ast::AnyNodeRef::ExprTuple(expr) => expr.elts.iter().any(contains), + ast::AnyNodeRef::ExprSet(expr) => expr.elts.iter().any(contains), + ast::AnyNodeRef::ExprDict(expr) => expr + .items + .iter() + .any(|item| item.range().contains_range(range)), + + // arguments := (positional arguments | keyword arguments | "*" expression | "**" expression)* + ast::AnyNodeRef::Arguments(args) => { + args.args.iter().any(contains) || args.keywords.iter().any(|kw| contains(&kw.value)) + } + // with_item := expression + ast::AnyNodeRef::WithItem(item) => contains(&item.context_expr), + + // lambdef := 'lambda' [lambda_params] ':' expression + ast::AnyNodeRef::ExprLambda(expr) => contains(&expr.body), + + ast::AnyNodeRef::Parameter(param) => param.annotation.as_deref().is_some_and(contains), + ast::AnyNodeRef::ParameterWithDefault(param) => { + param.default.as_deref().is_some_and(contains) + || param.parameter.annotation.as_deref().is_some_and(contains) + } + _ => false, + } +} + +/// Returns true when the cursor is after the `in` keyword in a +/// `for x in ` statement. +fn is_in_for_statement_iterable(node: AnyNodeRef, offset: TextSize, typed: Option<&str>) -> bool { + let range = typed_text_range(typed, offset); + + match node { + ast::AnyNodeRef::StmtFor(stmt_for) => stmt_for.iter.range().contains_range(range), + // Detects `for x in ` statements inside comprehensions. + // E.g. `[for x in ]` + ast::AnyNodeRef::Comprehension(comprehension) => { + comprehension.target.range().contains_range(range) + || comprehension.iter.range().contains_range(range) + || comprehension + .ifs + .iter() + .any(|expr| expr.range().contains_range(range)) + } + _ => false, + } +} + /// Returns the `TextRange` of the `typed` text. /// /// `typed` should be the text immediately before the @@ -5902,7 +6063,7 @@ def foo(param: s) } #[test] - fn no_statement_keywords_in_for_statement_simple1() { + fn iterable_only_keywords_in_for_statement_simple1() { completion_test_builder( "\ for x in a @@ -5916,7 +6077,7 @@ for x in a } #[test] - fn no_statement_keywords_in_for_statement_simple2() { + fn iterable_only_keywords_in_for_statement_simple2() { completion_test_builder( "\ for x, y, _ in a @@ -5930,7 +6091,7 @@ for x, y, _ in a } #[test] - fn no_statement_keywords_in_for_statement_simple3() { + fn iterable_only_keywords_in_for_statement_simple3() { completion_test_builder( "\ for i, (x, y, z) in a @@ -5944,7 +6105,7 @@ for i, (x, y, z) in a } #[test] - fn no_statement_keywords_in_for_statement_complex() { + fn iterable_only_keywords_in_for_statement_complex() { completion_test_builder( "\ for i, (obj.x, (a[0], b['k']), _), *rest in a @@ -5957,6 +6118,206 @@ for i, (obj.x, (a[0], b['k']), _), *rest in a .not_contains("False"); } + #[test] + fn no_statement_keywords_in_if_condition() { + completion_test_builder( + "\ +if a: + pass +", + ) + .build() + .contains("lambda") + .contains("await") + .contains("False") + .not_contains("raise"); + } + + // FIXME: Should not contain False as we're expecting an iterable + // or a container. + #[test] + fn no_statement_keywords_in_if_x_in() { + completion_test_builder( + "\ +if x in a: + pass +", + ) + .build() + .contains("lambda") + .contains("await") + .contains("False") + .not_contains("raise"); + } + + #[test] + fn no_statement_keywords_in_return_value() { + completion_test_builder( + "\ +def func(): + return a +", + ) + .build() + .contains("lambda") + .contains("await") + .contains("False") + .not_contains("raise"); + } + + #[test] + fn no_statement_keywords_in_match() { + completion_test_builder( + "\ +match a: + case _: + pass +", + ) + .build() + .contains("lambda") + .contains("await") + .contains("False") + .not_contains("raise"); + } + + // FIXME: Suggesting False doesn't make much sense here. + #[test] + fn no_statement_keywords_in_with() { + completion_test_builder( + "\ +with a +", + ) + .build() + .contains("lambda") + .contains("await") + .contains("False") + .not_contains("raise"); + } + + #[test] + fn no_statement_keywords_in_while() { + completion_test_builder( + "\ +while a +", + ) + .build() + .contains("lambda") + .contains("await") + .contains("False") + .not_contains("raise"); + } + + #[test] + fn no_statement_keywords_in_lambda() { + completion_test_builder( + "\ +lambda foo: a +", + ) + .build() + .contains("lambda") + .contains("await") + .contains("False") + .not_contains("raise"); + } + + #[test] + fn only_lambda_keyword_in_decorator() { + // The decorator check currently doesn't work at + // the very start of the file, therefore the code + // in this test explicitly has a variable + // defined before it + completion_test_builder( + "\ +foo = 123 + +@a +def func(): + ... +", + ) + .build() + .contains("lambda") + .not_contains("await") + .not_contains("raise") + .not_contains("False"); + } + + #[test] + fn statement_keywords_in_if_body() { + completion_test_builder( + "\ +foo = 123 +if foo: + a +", + ) + .build() + .contains("lambda") + .contains("await") + .contains("raise") + .contains("False"); + } + + #[test] + fn no_statement_keywords_in_tuple() { + completion_test_builder( + "\ +(a,) +", + ) + .build() + .contains("lambda") + .contains("await") + .contains("False") + .not_contains("raise"); + } + + #[test] + fn no_statement_keywords_in_set_literal() { + completion_test_builder( + "\ +{a} +", + ) + .build() + .contains("lambda") + .contains("await") + .contains("False") + .not_contains("raise"); + } + + #[test] + fn iterable_only_in_comprehension() { + completion_test_builder( + "\ +[x for x in a] +", + ) + .build() + .contains("lambda") + .contains("await") + .not_contains("False") + .not_contains("raise"); + } + + #[test] + fn no_statement_keywords_in_dict_literal() { + completion_test_builder( + "\ +{a: 1} +", + ) + .build() + .contains("lambda") + .contains("await") + .contains("False") + .not_contains("raise"); + } + #[test] fn favour_symbols_currently_imported() { let snapshot = CursorTest::builder()