Rewrite `yield-in-for-loop` to avoid recursing over body (#6692)

## Summary

This is much simpler and avoids (1) multiple passes over the entire
function body, (2) requiring the rule to do its own binding tracking (we
can just use the semantic model), and (3) a usage of `StatementKey`.

In general, where we can, we should try to remove these kinds of custom
visitors that track name references, and instead rely on the semantic
model.

## Test Plan

`cargo test`
This commit is contained in:
Charlie Marsh 2023-08-19 11:25:29 -04:00 committed by GitHub
parent 59e533047a
commit 3849fa0cf1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 121 additions and 186 deletions

View File

@ -1,8 +1,8 @@
use ruff_python_ast::{self as ast, Stmt};
use ruff_python_ast::Stmt;
use crate::checkers::ast::Checker;
use crate::codes::Rule;
use crate::rules::{flake8_bugbear, perflint};
use crate::rules::{flake8_bugbear, perflint, pyupgrade};
/// Run lint rules over all deferred for-loops in the [`SemanticModel`].
pub(crate) fn deferred_for_loops(checker: &mut Checker) {
@ -11,18 +11,18 @@ pub(crate) fn deferred_for_loops(checker: &mut Checker) {
for snapshot in for_loops {
checker.semantic.restore(snapshot);
let Stmt::For(ast::StmtFor {
target, iter, body, ..
}) = checker.semantic.current_statement()
else {
let Stmt::For(stmt_for) = checker.semantic.current_statement() else {
unreachable!("Expected Stmt::For");
};
if checker.enabled(Rule::UnusedLoopControlVariable) {
flake8_bugbear::rules::unused_loop_control_variable(checker, target, body);
flake8_bugbear::rules::unused_loop_control_variable(checker, stmt_for);
}
if checker.enabled(Rule::IncorrectDictIterator) {
perflint::rules::incorrect_dict_iterator(checker, target, iter);
perflint::rules::incorrect_dict_iterator(checker, stmt_for);
}
if checker.enabled(Rule::YieldInForLoop) {
pyupgrade::rules::yield_in_for_loop(checker, stmt_for);
}
}
}

View File

@ -338,9 +338,6 @@ pub(crate) fn statement(stmt: &Stmt, checker: &mut Checker) {
if checker.enabled(Rule::FStringDocstring) {
flake8_bugbear::rules::f_string_docstring(checker, body);
}
if checker.enabled(Rule::YieldInForLoop) {
pyupgrade::rules::yield_in_for_loop(checker, stmt);
}
if let ScopeKind::Class(class_def) = checker.semantic.current_scope().kind {
if checker.enabled(Rule::BuiltinAttributeShadowing) {
flake8_builtins::rules::builtin_method_shadowing(
@ -1178,8 +1175,11 @@ pub(crate) fn statement(stmt: &Stmt, checker: &mut Checker) {
orelse,
..
}) => {
if checker.any_enabled(&[Rule::UnusedLoopControlVariable, Rule::IncorrectDictIterator])
{
if checker.any_enabled(&[
Rule::UnusedLoopControlVariable,
Rule::IncorrectDictIterator,
Rule::YieldInForLoop,
]) {
checker.deferred.for_loops.push(checker.semantic.snapshot());
}
if checker.enabled(Rule::LoopVariableOverridesIterator) {

View File

@ -1,9 +1,9 @@
use ruff_python_ast::{self as ast, Expr, Ranged, Stmt};
use rustc_hash::FxHashMap;
use ruff_diagnostics::{AutofixKind, Diagnostic, Edit, Fix, Violation};
use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::visitor::Visitor;
use ruff_python_ast::{self as ast, Expr, Ranged};
use ruff_python_ast::{helpers, visitor};
use crate::checkers::ast::Checker;
@ -105,16 +105,16 @@ where
}
/// B007
pub(crate) fn unused_loop_control_variable(checker: &mut Checker, target: &Expr, body: &[Stmt]) {
pub(crate) fn unused_loop_control_variable(checker: &mut Checker, stmt_for: &ast::StmtFor) {
let control_names = {
let mut finder = NameFinder::new();
finder.visit_expr(target);
finder.visit_expr(stmt_for.target.as_ref());
finder.names
};
let used_names = {
let mut finder = NameFinder::new();
for stmt in body {
for stmt in &stmt_for.body {
finder.visit_stmt(stmt);
}
finder.names
@ -132,9 +132,10 @@ pub(crate) fn unused_loop_control_variable(checker: &mut Checker, target: &Expr,
}
// Avoid fixing any variables that _may_ be used, but undetectably so.
let certainty = Certainty::from(!helpers::uses_magic_variable_access(body, |id| {
checker.semantic().is_builtin(id)
}));
let certainty =
Certainty::from(!helpers::uses_magic_variable_access(&stmt_for.body, |id| {
checker.semantic().is_builtin(id)
}));
// Attempt to rename the variable by prepending an underscore, but avoid
// applying the fix if doing so wouldn't actually cause us to ignore the

View File

@ -1,11 +1,10 @@
use std::fmt;
use ruff_python_ast as ast;
use ruff_python_ast::Ranged;
use ruff_python_ast::{Arguments, Expr};
use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic, Edit, Fix};
use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast as ast;
use ruff_python_ast::Ranged;
use ruff_python_ast::{Arguments, Expr};
use ruff_python_semantic::SemanticModel;
use crate::checkers::ast::Checker;
@ -58,8 +57,8 @@ impl AlwaysAutofixableViolation for IncorrectDictIterator {
}
/// PERF102
pub(crate) fn incorrect_dict_iterator(checker: &mut Checker, target: &Expr, iter: &Expr) {
let Expr::Tuple(ast::ExprTuple { elts, .. }) = target else {
pub(crate) fn incorrect_dict_iterator(checker: &mut Checker, stmt_for: &ast::StmtFor) {
let Expr::Tuple(ast::ExprTuple { elts, .. }) = stmt_for.target.as_ref() else {
return;
};
let [key, value] = elts.as_slice() else {
@ -69,7 +68,7 @@ pub(crate) fn incorrect_dict_iterator(checker: &mut Checker, target: &Expr, iter
func,
arguments: Arguments { args, .. },
..
}) = iter
}) = stmt_for.iter.as_ref()
else {
return;
};
@ -105,7 +104,7 @@ pub(crate) fn incorrect_dict_iterator(checker: &mut Checker, target: &Expr, iter
let replace_attribute = Edit::range_replacement("values".to_string(), attr.range());
let replace_target = Edit::range_replacement(
checker.locator().slice(value.range()).to_string(),
target.range(),
stmt_for.target.range(),
);
diagnostic.set_fix(Fix::suggested_edits(replace_attribute, [replace_target]));
}
@ -123,7 +122,7 @@ pub(crate) fn incorrect_dict_iterator(checker: &mut Checker, target: &Expr, iter
let replace_attribute = Edit::range_replacement("keys".to_string(), attr.range());
let replace_target = Edit::range_replacement(
checker.locator().slice(key.range()).to_string(),
target.range(),
stmt_for.target.range(),
);
diagnostic.set_fix(Fix::suggested_edits(replace_attribute, [replace_target]));
}

View File

@ -1,12 +1,6 @@
use rustc_hash::FxHashMap;
use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic, Edit, Fix};
use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::statement_visitor::StatementVisitor;
use ruff_python_ast::visitor::Visitor;
use ruff_python_ast::{self as ast, Expr, ExprContext, Ranged, Stmt};
use ruff_python_ast::{statement_visitor, visitor};
use ruff_python_semantic::StatementKey;
use ruff_python_ast::{self as ast, Expr, Ranged, Stmt};
use crate::checkers::ast::Checker;
use crate::registry::AsRule;
@ -46,162 +40,103 @@ impl AlwaysAutofixableViolation for YieldInForLoop {
}
}
/// Return `true` if the two expressions are equivalent, and consistent solely
/// UP028
pub(crate) fn yield_in_for_loop(checker: &mut Checker, stmt_for: &ast::StmtFor) {
// Intentionally omit async contexts.
if checker.semantic().in_async_context() {
return;
}
let ast::StmtFor {
target,
iter,
body,
orelse,
is_async: _,
range: _,
} = stmt_for;
// If there is an else statement, don't rewrite.
if !orelse.is_empty() {
return;
}
// If there's any logic besides a yield, don't rewrite.
let [body] = body.as_slice() else {
return;
};
// If the body is not a yield, don't rewrite.
let Stmt::Expr(ast::StmtExpr { value, range: _ }) = &body else {
return;
};
let Expr::Yield(ast::ExprYield {
value: Some(value),
range: _,
}) = value.as_ref()
else {
return;
};
// If the target is not the same as the value, don't rewrite. For example, we should rewrite
// `for x in y: yield x` to `yield from y`, but not `for x in y: yield x + 1`.
if !is_same_expr(target, value) {
return;
}
// If any of the bound names are used outside of the yield itself, don't rewrite.
if collect_names(value).any(|name| {
checker
.semantic()
.current_scope()
.get_all(name.id.as_str())
.any(|binding_id| {
let binding = checker.semantic().binding(binding_id);
binding.references.iter().any(|reference_id| {
checker.semantic().reference(*reference_id).range() != name.range()
})
})
}) {
return;
}
let mut diagnostic = Diagnostic::new(YieldInForLoop, stmt_for.range());
if checker.patch(diagnostic.kind.rule()) {
let contents = checker.locator().slice(iter.range());
let contents = format!("yield from {contents}");
diagnostic.set_fix(Fix::suggested(Edit::range_replacement(
contents,
stmt_for.range(),
)));
}
checker.diagnostics.push(diagnostic);
}
/// Return `true` if the two expressions are equivalent, and both consistent solely
/// of tuples and names.
fn is_same_expr(a: &Expr, b: &Expr) -> bool {
match (&a, &b) {
(Expr::Name(ast::ExprName { id: a, .. }), Expr::Name(ast::ExprName { id: b, .. })) => {
a == b
fn is_same_expr(left: &Expr, right: &Expr) -> bool {
match (&left, &right) {
(Expr::Name(left), Expr::Name(right)) => left.id == right.id,
(Expr::Tuple(left), Expr::Tuple(right)) => {
left.elts.len() == right.elts.len()
&& left
.elts
.iter()
.zip(right.elts.iter())
.all(|(left, right)| is_same_expr(left, right))
}
(
Expr::Tuple(ast::ExprTuple { elts: a, .. }),
Expr::Tuple(ast::ExprTuple { elts: b, .. }),
) => a.len() == b.len() && a.iter().zip(b).all(|(a, b)| is_same_expr(a, b)),
_ => false,
}
}
/// Collect all named variables in an expression consisting solely of tuples and
/// names.
fn collect_names(expr: &Expr) -> Vec<&str> {
match expr {
Expr::Name(ast::ExprName { id, .. }) => vec![id],
Expr::Tuple(ast::ExprTuple { elts, .. }) => elts.iter().flat_map(collect_names).collect(),
_ => panic!("Expected: Expr::Name | Expr::Tuple"),
}
}
#[derive(Debug)]
struct YieldFrom<'a> {
stmt: &'a Stmt,
body: &'a Stmt,
iter: &'a Expr,
names: Vec<&'a str>,
}
#[derive(Default)]
struct YieldFromVisitor<'a> {
yields: Vec<YieldFrom<'a>>,
}
impl<'a> StatementVisitor<'a> for YieldFromVisitor<'a> {
fn visit_stmt(&mut self, stmt: &'a Stmt) {
match stmt {
Stmt::For(ast::StmtFor {
target,
body,
orelse,
iter,
..
}) => {
// If there is an else statement, don't rewrite.
if !orelse.is_empty() {
return;
}
// If there's any logic besides a yield, don't rewrite.
let [body] = body.as_slice() else {
return;
};
// If the body is not a yield, don't rewrite.
if let Stmt::Expr(ast::StmtExpr { value, range: _ }) = &body {
if let Expr::Yield(ast::ExprYield {
value: Some(value),
range: _,
}) = value.as_ref()
{
if is_same_expr(target, value) {
self.yields.push(YieldFrom {
stmt,
body,
iter,
names: collect_names(target),
});
}
}
}
}
Stmt::FunctionDef(_) | Stmt::ClassDef(_) => {
// Don't recurse into anything that defines a new scope.
}
_ => statement_visitor::walk_stmt(self, stmt),
}
}
}
#[derive(Default)]
struct ReferenceVisitor<'a> {
parent: Option<&'a Stmt>,
references: FxHashMap<StatementKey, Vec<&'a str>>,
}
impl<'a> Visitor<'a> for ReferenceVisitor<'a> {
fn visit_stmt(&mut self, stmt: &'a Stmt) {
let prev_parent = self.parent;
self.parent = Some(stmt);
visitor::walk_stmt(self, stmt);
self.parent = prev_parent;
}
fn visit_expr(&mut self, expr: &'a Expr) {
match expr {
Expr::Name(ast::ExprName { id, ctx, range: _ }) => {
if matches!(ctx, ExprContext::Load | ExprContext::Del) {
if let Some(parent) = self.parent {
self.references
.entry(StatementKey::from(parent))
.or_default()
.push(id);
}
}
}
_ => visitor::walk_expr(self, expr),
}
}
}
/// UP028
pub(crate) fn yield_in_for_loop(checker: &mut Checker, stmt: &Stmt) {
// Intentionally omit async functions.
let Stmt::FunctionDef(ast::StmtFunctionDef {
is_async: false,
body,
..
}) = stmt
else {
return;
};
let yields = {
let mut visitor = YieldFromVisitor::default();
visitor.visit_body(body);
visitor.yields
};
let references = {
let mut visitor = ReferenceVisitor::default();
visitor.visit_body(body);
visitor.references
};
for item in yields {
// If any of the bound names are used outside of the loop, don't rewrite.
if references.iter().any(|(statement, names)| {
*statement != StatementKey::from(item.stmt)
&& *statement != StatementKey::from(item.body)
&& item.names.iter().any(|name| names.contains(name))
}) {
continue;
}
let mut diagnostic = Diagnostic::new(YieldInForLoop, item.stmt.range());
if checker.patch(diagnostic.kind.rule()) {
let contents = checker.locator().slice(item.iter.range());
let contents = format!("yield from {contents}");
diagnostic.set_fix(Fix::suggested(Edit::range_replacement(
contents,
item.stmt.range(),
)));
}
checker.diagnostics.push(diagnostic);
}
fn collect_names<'a>(expr: &'a Expr) -> Box<dyn Iterator<Item = &ast::ExprName> + 'a> {
Box::new(
expr.as_name_expr().into_iter().chain(
expr.as_tuple_expr()
.into_iter()
.flat_map(|tuple| tuple.elts.iter().flat_map(collect_names)),
),
)
}