diff --git a/crates/ruff_linter/resources/test/fixtures/pyupgrade/UP043.py b/crates/ruff_linter/resources/test/fixtures/pyupgrade/UP043.py index 3968a16a54..5613053d81 100644 --- a/crates/ruff_linter/resources/test/fixtures/pyupgrade/UP043.py +++ b/crates/ruff_linter/resources/test/fixtures/pyupgrade/UP043.py @@ -50,3 +50,10 @@ def func() -> Generator[str, None, None]: async def func() -> AsyncGenerator[str, None]: yield "hello" + + +async def func() -> AsyncGenerator[ # type: ignore + str, + None +]: + yield "hello" diff --git a/crates/ruff_linter/src/rules/pyupgrade/rules/unnecessary_default_type_args.rs b/crates/ruff_linter/src/rules/pyupgrade/rules/unnecessary_default_type_args.rs index 1e26ada05b..7a9d36249b 100644 --- a/crates/ruff_linter/src/rules/pyupgrade/rules/unnecessary_default_type_args.rs +++ b/crates/ruff_linter/src/rules/pyupgrade/rules/unnecessary_default_type_args.rs @@ -1,4 +1,4 @@ -use ruff_diagnostics::{AlwaysFixableViolation, Diagnostic, Edit, Fix}; +use ruff_diagnostics::{AlwaysFixableViolation, Applicability, Diagnostic, Edit, Fix}; use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::{self as ast, Expr}; use ruff_text_size::{Ranged, TextRange}; @@ -44,6 +44,9 @@ use crate::checkers::ast::Checker; /// yield 42 /// ``` /// +/// ## Fix safety +/// This rule's fix is marked as safe, unless the type annotation contains comments. +/// /// ## References /// - [PEP 696 – Type Defaults for Type Parameters](https://peps.python.org/pep-0696/) /// - [Annotating generators and coroutines](https://docs.python.org/3/library/typing.html#annotating-generators-and-coroutines) @@ -93,26 +96,39 @@ pub(crate) fn unnecessary_default_type_args(checker: &mut Checker, expr: &Expr) } let mut diagnostic = Diagnostic::new(UnnecessaryDefaultTypeArgs, expr.range()); - diagnostic.set_fix(Fix::safe_edit(Edit::range_replacement( - checker - .generator() - .expr(&Expr::Subscript(ast::ExprSubscript { - value: value.clone(), - slice: Box::new(if let [elt] = valid_elts.as_slice() { - elt.clone() - } else { - Expr::Tuple(ast::ExprTuple { - elts: valid_elts, - ctx: ast::ExprContext::Load, - range: TextRange::default(), - parenthesized: true, - }) - }), - ctx: ast::ExprContext::Load, - range: TextRange::default(), - })), - expr.range(), - ))); + + let applicability = if checker + .comment_ranges() + .has_comments(expr, checker.source()) + { + Applicability::Unsafe + } else { + Applicability::Safe + }; + + diagnostic.set_fix(Fix::applicable_edit( + Edit::range_replacement( + checker + .generator() + .expr(&Expr::Subscript(ast::ExprSubscript { + value: value.clone(), + slice: Box::new(if let [elt] = valid_elts.as_slice() { + elt.clone() + } else { + Expr::Tuple(ast::ExprTuple { + elts: valid_elts, + ctx: ast::ExprContext::Load, + range: TextRange::default(), + parenthesized: true, + }) + }), + ctx: ast::ExprContext::Load, + range: TextRange::default(), + })), + expr.range(), + ), + applicability, + )); checker.diagnostics.push(diagnostic); } diff --git a/crates/ruff_linter/src/rules/pyupgrade/snapshots/ruff_linter__rules__pyupgrade__tests__UP043.py.snap b/crates/ruff_linter/src/rules/pyupgrade/snapshots/ruff_linter__rules__pyupgrade__tests__UP043.py.snap index 3c64e318f1..dda6bd47ee 100644 --- a/crates/ruff_linter/src/rules/pyupgrade/snapshots/ruff_linter__rules__pyupgrade__tests__UP043.py.snap +++ b/crates/ruff_linter/src/rules/pyupgrade/snapshots/ruff_linter__rules__pyupgrade__tests__UP043.py.snap @@ -108,3 +108,28 @@ UP043.py:51:21: UP043 [*] Unnecessary default type arguments 51 |-async def func() -> AsyncGenerator[str, None]: 51 |+async def func() -> AsyncGenerator[str]: 52 52 | yield "hello" +53 53 | +54 54 | + +UP043.py:55:21: UP043 [*] Unnecessary default type arguments + | +55 | async def func() -> AsyncGenerator[ # type: ignore + | _____________________^ +56 | | str, +57 | | None +58 | | ]: + | |_^ UP043 +59 | yield "hello" + | + = help: Remove default type arguments + +ℹ Unsafe fix +52 52 | yield "hello" +53 53 | +54 54 | +55 |-async def func() -> AsyncGenerator[ # type: ignore +56 |- str, +57 |- None +58 |-]: + 55 |+async def func() -> AsyncGenerator[str]: +59 56 | yield "hello"