diff --git a/crates/ruff/src/checkers/ast/analyze/deferred_for_loops.rs b/crates/ruff/src/checkers/ast/analyze/deferred_for_loops.rs index 8898e1df69..bb1805ba77 100644 --- a/crates/ruff/src/checkers/ast/analyze/deferred_for_loops.rs +++ b/crates/ruff/src/checkers/ast/analyze/deferred_for_loops.rs @@ -1,8 +1,8 @@ -use ruff_python_ast::{self as ast, Stmt}; +use ruff_python_ast::Stmt; use crate::checkers::ast::Checker; use crate::codes::Rule; -use crate::rules::{flake8_bugbear, perflint}; +use crate::rules::{flake8_bugbear, perflint, pyupgrade}; /// Run lint rules over all deferred for-loops in the [`SemanticModel`]. pub(crate) fn deferred_for_loops(checker: &mut Checker) { @@ -11,18 +11,18 @@ pub(crate) fn deferred_for_loops(checker: &mut Checker) { for snapshot in for_loops { checker.semantic.restore(snapshot); - let Stmt::For(ast::StmtFor { - target, iter, body, .. - }) = checker.semantic.current_statement() - else { + let Stmt::For(stmt_for) = checker.semantic.current_statement() else { unreachable!("Expected Stmt::For"); }; if checker.enabled(Rule::UnusedLoopControlVariable) { - flake8_bugbear::rules::unused_loop_control_variable(checker, target, body); + flake8_bugbear::rules::unused_loop_control_variable(checker, stmt_for); } if checker.enabled(Rule::IncorrectDictIterator) { - perflint::rules::incorrect_dict_iterator(checker, target, iter); + perflint::rules::incorrect_dict_iterator(checker, stmt_for); + } + if checker.enabled(Rule::YieldInForLoop) { + pyupgrade::rules::yield_in_for_loop(checker, stmt_for); } } } diff --git a/crates/ruff/src/checkers/ast/analyze/statement.rs b/crates/ruff/src/checkers/ast/analyze/statement.rs index 7518c07523..5b6892f8a7 100644 --- a/crates/ruff/src/checkers/ast/analyze/statement.rs +++ b/crates/ruff/src/checkers/ast/analyze/statement.rs @@ -338,9 +338,6 @@ pub(crate) fn statement(stmt: &Stmt, checker: &mut Checker) { if checker.enabled(Rule::FStringDocstring) { flake8_bugbear::rules::f_string_docstring(checker, body); } - if 
checker.enabled(Rule::YieldInForLoop) { - pyupgrade::rules::yield_in_for_loop(checker, stmt); - } if let ScopeKind::Class(class_def) = checker.semantic.current_scope().kind { if checker.enabled(Rule::BuiltinAttributeShadowing) { flake8_builtins::rules::builtin_method_shadowing( @@ -1178,8 +1175,11 @@ pub(crate) fn statement(stmt: &Stmt, checker: &mut Checker) { orelse, .. }) => { - if checker.any_enabled(&[Rule::UnusedLoopControlVariable, Rule::IncorrectDictIterator]) - { + if checker.any_enabled(&[ + Rule::UnusedLoopControlVariable, + Rule::IncorrectDictIterator, + Rule::YieldInForLoop, + ]) { checker.deferred.for_loops.push(checker.semantic.snapshot()); } if checker.enabled(Rule::LoopVariableOverridesIterator) { diff --git a/crates/ruff/src/rules/flake8_bugbear/rules/unused_loop_control_variable.rs b/crates/ruff/src/rules/flake8_bugbear/rules/unused_loop_control_variable.rs index 295a9c18e4..d8b37dac1a 100644 --- a/crates/ruff/src/rules/flake8_bugbear/rules/unused_loop_control_variable.rs +++ b/crates/ruff/src/rules/flake8_bugbear/rules/unused_loop_control_variable.rs @@ -1,9 +1,9 @@ -use ruff_python_ast::{self as ast, Expr, Ranged, Stmt}; use rustc_hash::FxHashMap; use ruff_diagnostics::{AutofixKind, Diagnostic, Edit, Fix, Violation}; use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::visitor::Visitor; +use ruff_python_ast::{self as ast, Expr, Ranged}; use ruff_python_ast::{helpers, visitor}; use crate::checkers::ast::Checker; @@ -105,16 +105,16 @@ where } /// B007 -pub(crate) fn unused_loop_control_variable(checker: &mut Checker, target: &Expr, body: &[Stmt]) { +pub(crate) fn unused_loop_control_variable(checker: &mut Checker, stmt_for: &ast::StmtFor) { let control_names = { let mut finder = NameFinder::new(); - finder.visit_expr(target); + finder.visit_expr(stmt_for.target.as_ref()); finder.names }; let used_names = { let mut finder = NameFinder::new(); - for stmt in body { + for stmt in &stmt_for.body { finder.visit_stmt(stmt); } 
finder.names @@ -132,9 +132,10 @@ pub(crate) fn unused_loop_control_variable(checker: &mut Checker, target: &Expr, } // Avoid fixing any variables that _may_ be used, but undetectably so. - let certainty = Certainty::from(!helpers::uses_magic_variable_access(body, |id| { - checker.semantic().is_builtin(id) - })); + let certainty = + Certainty::from(!helpers::uses_magic_variable_access(&stmt_for.body, |id| { + checker.semantic().is_builtin(id) + })); // Attempt to rename the variable by prepending an underscore, but avoid // applying the fix if doing so wouldn't actually cause us to ignore the diff --git a/crates/ruff/src/rules/perflint/rules/incorrect_dict_iterator.rs b/crates/ruff/src/rules/perflint/rules/incorrect_dict_iterator.rs index 8dd641694a..8d25a2509d 100644 --- a/crates/ruff/src/rules/perflint/rules/incorrect_dict_iterator.rs +++ b/crates/ruff/src/rules/perflint/rules/incorrect_dict_iterator.rs @@ -1,11 +1,10 @@ use std::fmt; -use ruff_python_ast as ast; -use ruff_python_ast::Ranged; -use ruff_python_ast::{Arguments, Expr}; - use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic, Edit, Fix}; use ruff_macros::{derive_message_formats, violation}; +use ruff_python_ast as ast; +use ruff_python_ast::Ranged; +use ruff_python_ast::{Arguments, Expr}; use ruff_python_semantic::SemanticModel; use crate::checkers::ast::Checker; @@ -58,8 +57,8 @@ impl AlwaysAutofixableViolation for IncorrectDictIterator { } /// PERF102 -pub(crate) fn incorrect_dict_iterator(checker: &mut Checker, target: &Expr, iter: &Expr) { - let Expr::Tuple(ast::ExprTuple { elts, .. }) = target else { +pub(crate) fn incorrect_dict_iterator(checker: &mut Checker, stmt_for: &ast::StmtFor) { + let Expr::Tuple(ast::ExprTuple { elts, .. }) = stmt_for.target.as_ref() else { return; }; let [key, value] = elts.as_slice() else { @@ -69,7 +68,7 @@ pub(crate) fn incorrect_dict_iterator(checker: &mut Checker, target: &Expr, iter func, arguments: Arguments { args, .. }, .. 
- }) = iter + }) = stmt_for.iter.as_ref() else { return; }; @@ -105,7 +104,7 @@ pub(crate) fn incorrect_dict_iterator(checker: &mut Checker, target: &Expr, iter let replace_attribute = Edit::range_replacement("values".to_string(), attr.range()); let replace_target = Edit::range_replacement( checker.locator().slice(value.range()).to_string(), - target.range(), + stmt_for.target.range(), ); diagnostic.set_fix(Fix::suggested_edits(replace_attribute, [replace_target])); } @@ -123,7 +122,7 @@ pub(crate) fn incorrect_dict_iterator(checker: &mut Checker, target: &Expr, iter let replace_attribute = Edit::range_replacement("keys".to_string(), attr.range()); let replace_target = Edit::range_replacement( checker.locator().slice(key.range()).to_string(), - target.range(), + stmt_for.target.range(), ); diagnostic.set_fix(Fix::suggested_edits(replace_attribute, [replace_target])); } diff --git a/crates/ruff/src/rules/pyupgrade/rules/yield_in_for_loop.rs b/crates/ruff/src/rules/pyupgrade/rules/yield_in_for_loop.rs index acf96c59be..0662a0c040 100644 --- a/crates/ruff/src/rules/pyupgrade/rules/yield_in_for_loop.rs +++ b/crates/ruff/src/rules/pyupgrade/rules/yield_in_for_loop.rs @@ -1,12 +1,6 @@ -use rustc_hash::FxHashMap; - use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic, Edit, Fix}; use ruff_macros::{derive_message_formats, violation}; -use ruff_python_ast::statement_visitor::StatementVisitor; -use ruff_python_ast::visitor::Visitor; -use ruff_python_ast::{self as ast, Expr, ExprContext, Ranged, Stmt}; -use ruff_python_ast::{statement_visitor, visitor}; -use ruff_python_semantic::StatementKey; +use ruff_python_ast::{self as ast, Expr, Ranged, Stmt}; use crate::checkers::ast::Checker; use crate::registry::AsRule; @@ -46,162 +40,103 @@ impl AlwaysAutofixableViolation for YieldInForLoop { } } -/// Return `true` if the two expressions are equivalent, and consistent solely +/// UP028 +pub(crate) fn yield_in_for_loop(checker: &mut Checker, stmt_for: &ast::StmtFor) { + // 
Intentionally omit async contexts. + if checker.semantic().in_async_context() { + return; + } + + let ast::StmtFor { + target, + iter, + body, + orelse, + is_async: _, + range: _, + } = stmt_for; + + // If there is an else statement, don't rewrite. + if !orelse.is_empty() { + return; + } + + // If there's any logic besides a yield, don't rewrite. + let [body] = body.as_slice() else { + return; + }; + + // If the body is not a yield, don't rewrite. + let Stmt::Expr(ast::StmtExpr { value, range: _ }) = &body else { + return; + }; + let Expr::Yield(ast::ExprYield { + value: Some(value), + range: _, + }) = value.as_ref() + else { + return; + }; + + // If the target is not the same as the value, don't rewrite. For example, we should rewrite + // `for x in y: yield x` to `yield from y`, but not `for x in y: yield x + 1`. + if !is_same_expr(target, value) { + return; + } + + // If any of the bound names are used outside of the yield itself, don't rewrite. + if collect_names(value).any(|name| { + checker + .semantic() + .current_scope() + .get_all(name.id.as_str()) + .any(|binding_id| { + let binding = checker.semantic().binding(binding_id); + binding.references.iter().any(|reference_id| { + checker.semantic().reference(*reference_id).range() != name.range() + }) + }) + }) { + return; + } + + let mut diagnostic = Diagnostic::new(YieldInForLoop, stmt_for.range()); + if checker.patch(diagnostic.kind.rule()) { + let contents = checker.locator().slice(iter.range()); + let contents = format!("yield from {contents}"); + diagnostic.set_fix(Fix::suggested(Edit::range_replacement( + contents, + stmt_for.range(), + ))); + } + checker.diagnostics.push(diagnostic); +} + +/// Return `true` if the two expressions are equivalent, and both consist solely /// of tuples and names. -fn is_same_expr(a: &Expr, b: &Expr) -> bool { - match (&a, &b) { - (Expr::Name(ast::ExprName { id: a, .. }), Expr::Name(ast::ExprName { id: b, .. 
})) => { - a == b +fn is_same_expr(left: &Expr, right: &Expr) -> bool { + match (&left, &right) { + (Expr::Name(left), Expr::Name(right)) => left.id == right.id, + (Expr::Tuple(left), Expr::Tuple(right)) => { + left.elts.len() == right.elts.len() + && left + .elts + .iter() + .zip(right.elts.iter()) + .all(|(left, right)| is_same_expr(left, right)) } - ( - Expr::Tuple(ast::ExprTuple { elts: a, .. }), - Expr::Tuple(ast::ExprTuple { elts: b, .. }), - ) => a.len() == b.len() && a.iter().zip(b).all(|(a, b)| is_same_expr(a, b)), _ => false, } } /// Collect all named variables in an expression consisting solely of tuples and /// names. -fn collect_names(expr: &Expr) -> Vec<&str> { - match expr { - Expr::Name(ast::ExprName { id, .. }) => vec![id], - Expr::Tuple(ast::ExprTuple { elts, .. }) => elts.iter().flat_map(collect_names).collect(), - _ => panic!("Expected: Expr::Name | Expr::Tuple"), - } -} - -#[derive(Debug)] -struct YieldFrom<'a> { - stmt: &'a Stmt, - body: &'a Stmt, - iter: &'a Expr, - names: Vec<&'a str>, -} - -#[derive(Default)] -struct YieldFromVisitor<'a> { - yields: Vec<YieldFrom<'a>>, -} - -impl<'a> StatementVisitor<'a> for YieldFromVisitor<'a> { - fn visit_stmt(&mut self, stmt: &'a Stmt) { - match stmt { - Stmt::For(ast::StmtFor { - target, - body, - orelse, - iter, - .. - }) => { - // If there is an else statement, don't rewrite. - if !orelse.is_empty() { - return; - } - // If there's any logic besides a yield, don't rewrite. - let [body] = body.as_slice() else { - return; - }; - // If the body is not a yield, don't rewrite. - if let Stmt::Expr(ast::StmtExpr { value, range: _ }) = &body { - if let Expr::Yield(ast::ExprYield { - value: Some(value), - range: _, - }) = value.as_ref() - { - if is_same_expr(target, value) { - self.yields.push(YieldFrom { - stmt, - body, - iter, - names: collect_names(target), - }); - } - } - } - } - Stmt::FunctionDef(_) | Stmt::ClassDef(_) => { - // Don't recurse into anything that defines a new scope. 
- } - _ => statement_visitor::walk_stmt(self, stmt), - } - } -} - -#[derive(Default)] -struct ReferenceVisitor<'a> { - parent: Option<&'a Stmt>, - references: FxHashMap<StatementKey, Vec<&'a str>>, -} - -impl<'a> Visitor<'a> for ReferenceVisitor<'a> { - fn visit_stmt(&mut self, stmt: &'a Stmt) { - let prev_parent = self.parent; - self.parent = Some(stmt); - visitor::walk_stmt(self, stmt); - self.parent = prev_parent; - } - - fn visit_expr(&mut self, expr: &'a Expr) { - match expr { - Expr::Name(ast::ExprName { id, ctx, range: _ }) => { - if matches!(ctx, ExprContext::Load | ExprContext::Del) { - if let Some(parent) = self.parent { - self.references - .entry(StatementKey::from(parent)) - .or_default() - .push(id); - } - } - } - _ => visitor::walk_expr(self, expr), - } - } -} - -/// UP028 -pub(crate) fn yield_in_for_loop(checker: &mut Checker, stmt: &Stmt) { - // Intentionally omit async functions. - let Stmt::FunctionDef(ast::StmtFunctionDef { - is_async: false, - body, - .. - }) = stmt - else { - return; - }; - - let yields = { - let mut visitor = YieldFromVisitor::default(); - visitor.visit_body(body); - visitor.yields - }; - - let references = { - let mut visitor = ReferenceVisitor::default(); - visitor.visit_body(body); - visitor.references - }; - - for item in yields { - // If any of the bound names are used outside of the loop, don't rewrite. 
- if references.iter().any(|(statement, names)| { - *statement != StatementKey::from(item.stmt) - && *statement != StatementKey::from(item.body) - && item.names.iter().any(|name| names.contains(name)) - }) { - continue; - } - - let mut diagnostic = Diagnostic::new(YieldInForLoop, item.stmt.range()); - if checker.patch(diagnostic.kind.rule()) { - let contents = checker.locator().slice(item.iter.range()); - let contents = format!("yield from {contents}"); - diagnostic.set_fix(Fix::suggested(Edit::range_replacement( - contents, - item.stmt.range(), - ))); - } - checker.diagnostics.push(diagnostic); - } +fn collect_names<'a>(expr: &'a Expr) -> Box<dyn Iterator<Item = &'a ast::ExprName> + 'a> { + Box::new( + expr.as_name_expr().into_iter().chain( + expr.as_tuple_expr() + .into_iter() + .flat_map(|tuple| tuple.elts.iter().flat_map(collect_names)), + ), + ) }