diff --git a/crates/ruff/resources/test/fixtures/flake8_pie/PIE802.py b/crates/ruff/resources/test/fixtures/flake8_pie/PIE802.py index 52ed90aa21..21ae8a5d28 100644 --- a/crates/ruff/resources/test/fixtures/flake8_pie/PIE802.py +++ b/crates/ruff/resources/test/fixtures/flake8_pie/PIE802.py @@ -8,3 +8,9 @@ any({x.id for x in bar}) # PIE 802 any([x.id for x in bar]) all([x.id for x in bar]) +any( # first comment + [x.id for x in bar], # second comment +) # third comment +all( # first comment + [x.id for x in bar], # second comment +) # third comment diff --git a/crates/ruff/src/rules/flake8_pie/fixes.rs b/crates/ruff/src/rules/flake8_pie/fixes.rs index 53918a7fc8..48ecc1ff26 100644 --- a/crates/ruff/src/rules/flake8_pie/fixes.rs +++ b/crates/ruff/src/rules/flake8_pie/fixes.rs @@ -4,7 +4,7 @@ use libcst_native::{Codegen, CodegenState, Expression, GeneratorExp}; use ruff_python_ast::source_code::{Locator, Stylist}; use ruff_python_ast::types::Range; -use crate::cst::matchers::{match_expr, match_module}; +use crate::cst::matchers::{match_call, match_expression}; use crate::fix::Fix; /// (PIE802) Convert `[i for i in a]` into `i for i in a` @@ -14,23 +14,28 @@ pub fn fix_unnecessary_comprehension_any_all( expr: &rustpython_parser::ast::Expr, ) -> Result { // Expr(ListComp) -> Expr(GeneratorExp) - let module_text = locator.slice(Range::from_located(expr)); - let mut tree = match_module(module_text)?; - let mut body = match_expr(&mut tree)?; + let expression_text = locator.slice(Range::from_located(expr)); + let mut tree = match_expression(expression_text)?; + let call = match_call(&mut tree)?; - let Expression::ListComp(list_comp) = &body.value else { + let Expression::ListComp(list_comp) = &call.args[0].value else { bail!( "Expected Expression::ListComp" ); }; - body.value = Expression::GeneratorExp(Box::new(GeneratorExp { + call.args[0].value = Expression::GeneratorExp(Box::new(GeneratorExp { elt: list_comp.elt.clone(), for_in: list_comp.for_in.clone(), lpar: list_comp.lpar.clone(), rpar: list_comp.rpar.clone(), })); + if let Some(comma) = &call.args[0].comma { + call.args[0].whitespace_after_arg = comma.whitespace_after.clone(); + call.args[0].comma = None; + } + let mut state = CodegenState { default_newline: stylist.line_ending(), default_indent: stylist.indentation(), diff --git a/crates/ruff/src/rules/flake8_pie/rules.rs b/crates/ruff/src/rules/flake8_pie/rules.rs index 1876b776b5..655892f276 100644 --- a/crates/ruff/src/rules/flake8_pie/rules.rs +++ b/crates/ruff/src/rules/flake8_pie/rules.rs @@ -340,13 +340,15 @@ pub fn unnecessary_comprehension_any_all( return; } if let ExprKind::ListComp { .. } = args[0].node { - let mut diagnostic = - Diagnostic::new(UnnecessaryComprehensionAnyAll, Range::from_located(expr)); + let mut diagnostic = Diagnostic::new( + UnnecessaryComprehensionAnyAll, + Range::from_located(&args[0]), + ); if checker.patch(diagnostic.kind.rule()) { match fixes::fix_unnecessary_comprehension_any_all( checker.locator, checker.stylist, - &args[0], + expr, ) { Ok(fix) => { diagnostic.amend(fix); diff --git a/crates/ruff/src/rules/flake8_pie/snapshots/ruff__rules__flake8_pie__tests__PIE802_PIE802.py.snap b/crates/ruff/src/rules/flake8_pie/snapshots/ruff__rules__flake8_pie__tests__PIE802_PIE802.py.snap index af293532d6..62363bcf13 100644 --- a/crates/ruff/src/rules/flake8_pie/snapshots/ruff__rules__flake8_pie__tests__PIE802_PIE802.py.snap +++ b/crates/ruff/src/rules/flake8_pie/snapshots/ruff__rules__flake8_pie__tests__PIE802_PIE802.py.snap @@ -6,34 +6,68 @@ expression: diagnostics UnnecessaryComprehensionAnyAll: ~ location: row: 9 - column: 0 + column: 4 end_location: row: 9 - column: 24 + column: 23 fix: - content: x.id for x in bar + content: any(x.id for x in bar) location: row: 9 - column: 4 + column: 0 end_location: row: 9 - column: 23 + column: 24 parent: ~ - kind: UnnecessaryComprehensionAnyAll: ~ location: row: 10 - column: 0 + column: 4 end_location: row: 10 - column: 24 + column: 23 fix: - content: x.id for x in bar + content: all(x.id for x in bar) location: row: 10 - column: 4 + column: 0 end_location: row: 10 - column: 23 + column: 24 + parent: ~ +- kind: + UnnecessaryComprehensionAnyAll: ~ + location: + row: 12 + column: 4 + end_location: + row: 12 + column: 23 + fix: + content: "any( # first comment\n x.id for x in bar # second comment\n)" + location: + row: 11 + column: 0 + end_location: + row: 13 + column: 1 + parent: ~ +- kind: + UnnecessaryComprehensionAnyAll: ~ + location: + row: 15 + column: 4 + end_location: + row: 15 + column: 23 + fix: + content: "all( # first comment\n x.id for x in bar # second comment\n)" + location: + row: 14 + column: 0 + end_location: + row: 16 + column: 1 parent: ~