diff --git a/crates/ruff/src/checkers/ast/mod.rs b/crates/ruff/src/checkers/ast/mod.rs index c1bf7fb8a5..d979952140 100644 --- a/crates/ruff/src/checkers/ast/mod.rs +++ b/crates/ruff/src/checkers/ast/mod.rs @@ -1404,11 +1404,7 @@ where } if stmt.is_for_stmt() { if self.enabled(Rule::ReimplementedBuiltin) { - flake8_simplify::rules::convert_for_loop_to_any_all( - self, - stmt, - self.semantic.sibling_stmt(), - ); + flake8_simplify::rules::convert_for_loop_to_any_all(self, stmt); } if self.enabled(Rule::InDictKeys) { flake8_simplify::rules::key_in_dict_for(self, target, iter); @@ -4237,21 +4233,10 @@ where flake8_pie::rules::no_unnecessary_pass(self, body); } - // Step 2: Binding - let prev_body = self.semantic.body; - let prev_body_index = self.semantic.body_index; - self.semantic.body = body; - self.semantic.body_index = 0; - // Step 3: Traversal for stmt in body { self.visit_stmt(stmt); - self.semantic.body_index += 1; } - - // Step 4: Clean-up - self.semantic.body = prev_body; - self.semantic.body_index = prev_body_index; } } diff --git a/crates/ruff/src/rules/flake8_simplify/rules/reimplemented_builtin.rs b/crates/ruff/src/rules/flake8_simplify/rules/reimplemented_builtin.rs index 4a30d94859..953e77b843 100644 --- a/crates/ruff/src/rules/flake8_simplify/rules/reimplemented_builtin.rs +++ b/crates/ruff/src/rules/flake8_simplify/rules/reimplemented_builtin.rs @@ -1,4 +1,4 @@ -use ruff_text_size::{TextRange, TextSize}; +use ruff_text_size::TextRange; use rustpython_parser::ast::{ self, CmpOp, Comprehension, Constant, Expr, ExprContext, Ranged, Stmt, UnaryOp, }; @@ -7,10 +7,11 @@ use ruff_diagnostics::{AutofixKind, Diagnostic, Edit, Fix, Violation}; use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::helpers::any_over_expr; use ruff_python_ast::source_code::Generator; +use ruff_python_ast::traversal; use crate::checkers::ast::Checker; use crate::line_width::LineWidth; -use crate::registry::{AsRule, Rule}; +use crate::registry::AsRule; /// ## 
What it does /// Checks for `for` loops that can be replaced with a builtin function, like @@ -38,7 +39,7 @@ use crate::registry::{AsRule, Rule}; /// - [Python documentation: `all`](https://docs.python.org/3/library/functions.html#all) #[violation] pub struct ReimplementedBuiltin { - repl: String, + replacement: String, } impl Violation for ReimplementedBuiltin { @@ -46,200 +47,222 @@ impl Violation for ReimplementedBuiltin { #[derive_message_formats] fn message(&self) -> String { - let ReimplementedBuiltin { repl } = self; - format!("Use `{repl}` instead of `for` loop") + let ReimplementedBuiltin { replacement } = self; + format!("Use `{replacement}` instead of `for` loop") } fn autofix_title(&self) -> Option { - let ReimplementedBuiltin { repl } = self; - Some(format!("Replace with `{repl}`")) + let ReimplementedBuiltin { replacement } = self; + Some(format!("Replace with `{replacement}`")) } } /// SIM110, SIM111 -pub(crate) fn convert_for_loop_to_any_all( - checker: &mut Checker, - stmt: &Stmt, - sibling: Option<&Stmt>, -) { - // There are two cases to consider: +pub(crate) fn convert_for_loop_to_any_all(checker: &mut Checker, stmt: &Stmt) { + if !checker.semantic().scope().kind.is_any_function() { + return; + } + + // The `for` loop itself must consist of an `if` with a `return`. + let Some(loop_) = match_loop(stmt) else { + return; + }; + + // Afterwards, there are two cases to consider: // - `for` loop with an `else: return True` or `else: return False`. - // - `for` loop followed by `return True` or `return False` - if let Some(loop_info) = return_values_for_else(stmt) - .or_else(|| sibling.and_then(|sibling| return_values_for_siblings(stmt, sibling))) - { - // Check if loop_info.target, loop_info.iter, or loop_info.test contains `await`. 
- if contains_await(loop_info.target) - || contains_await(loop_info.iter) - || contains_await(loop_info.test) - { - return; - } - if loop_info.return_value && !loop_info.next_return_value { - if checker.enabled(Rule::ReimplementedBuiltin) { - let contents = return_stmt( - "any", - loop_info.test, - loop_info.target, - loop_info.iter, - checker.generator(), - ); + // - `for` loop followed by `return True` or `return False`. + let Some(terminal) = match_else_return(stmt).or_else(|| { + let parent = checker.semantic().stmt_parent()?; + let suite = traversal::suite(stmt, parent)?; + let sibling = traversal::next_sibling(stmt, suite)?; + match_sibling_return(stmt, sibling) + }) else { + return; + }; - // Don't flag if the resulting expression would exceed the maximum line length. - let line_start = checker.locator.line_start(stmt.start()); - if LineWidth::new(checker.settings.tab_size) - .add_str(&checker.locator.contents()[TextRange::new(line_start, stmt.start())]) - .add_str(&contents) - > checker.settings.line_length - { - return; - } + // Check if any of the expressions contain an `await` expression. + if contains_await(loop_.target) || contains_await(loop_.iter) || contains_await(loop_.test) { + return; + } - let mut diagnostic = Diagnostic::new( - ReimplementedBuiltin { - repl: contents.clone(), - }, - TextRange::new(stmt.start(), loop_info.terminal), - ); - if checker.patch(diagnostic.kind.rule()) && checker.semantic().is_builtin("any") { - diagnostic.set_fix(Fix::suggested(Edit::replacement( - contents, - stmt.start(), - loop_info.terminal, - ))); - } - checker.diagnostics.push(diagnostic); + match (loop_.return_value, terminal.return_value) { + // Replace with `any`. + (true, false) => { + let contents = return_stmt( + "any", + loop_.test, + loop_.target, + loop_.iter, + checker.generator(), + ); + + // Don't flag if the resulting expression would exceed the maximum line length. 
+ let line_start = checker.locator.line_start(stmt.start()); + if LineWidth::new(checker.settings.tab_size) + .add_str(&checker.locator.contents()[TextRange::new(line_start, stmt.start())]) + .add_str(&contents) + > checker.settings.line_length + { + return; } - } - if !loop_info.return_value && loop_info.next_return_value { - if checker.enabled(Rule::ReimplementedBuiltin) { - // Invert the condition. - let test = { - if let Expr::UnaryOp(ast::ExprUnaryOp { - op: UnaryOp::Not, - operand, - range: _, - }) = &loop_info.test - { - *operand.clone() - } else if let Expr::Compare(ast::ExprCompare { - left, - ops, - comparators, - range: _, - }) = &loop_info.test - { - if let ([op], [comparator]) = (ops.as_slice(), comparators.as_slice()) { - let op = match op { - CmpOp::Eq => CmpOp::NotEq, - CmpOp::NotEq => CmpOp::Eq, - CmpOp::Lt => CmpOp::GtE, - CmpOp::LtE => CmpOp::Gt, - CmpOp::Gt => CmpOp::LtE, - CmpOp::GtE => CmpOp::Lt, - CmpOp::Is => CmpOp::IsNot, - CmpOp::IsNot => CmpOp::Is, - CmpOp::In => CmpOp::NotIn, - CmpOp::NotIn => CmpOp::In, - }; - let node = ast::ExprCompare { - left: left.clone(), - ops: vec![op], - comparators: vec![comparator.clone()], - range: TextRange::default(), - }; - node.into() - } else { - let node = ast::ExprUnaryOp { - op: UnaryOp::Not, - operand: Box::new(loop_info.test.clone()), - range: TextRange::default(), - }; - node.into() - } + let mut diagnostic = Diagnostic::new( + ReimplementedBuiltin { + replacement: contents.to_string(), + }, + TextRange::new(stmt.start(), terminal.stmt.end()), + ); + if checker.patch(diagnostic.kind.rule()) && checker.semantic().is_builtin("any") { + diagnostic.set_fix(Fix::suggested(Edit::replacement( + contents, + stmt.start(), + terminal.stmt.end(), + ))); + } + checker.diagnostics.push(diagnostic); + } + // Replace with `all`. + (false, true) => { + // Invert the condition. 
+ let test = { + if let Expr::UnaryOp(ast::ExprUnaryOp { + op: UnaryOp::Not, + operand, + range: _, + }) = &loop_.test + { + *operand.clone() + } else if let Expr::Compare(ast::ExprCompare { + left, + ops, + comparators, + range: _, + }) = &loop_.test + { + if let ([op], [comparator]) = (ops.as_slice(), comparators.as_slice()) { + let op = match op { + CmpOp::Eq => CmpOp::NotEq, + CmpOp::NotEq => CmpOp::Eq, + CmpOp::Lt => CmpOp::GtE, + CmpOp::LtE => CmpOp::Gt, + CmpOp::Gt => CmpOp::LtE, + CmpOp::GtE => CmpOp::Lt, + CmpOp::Is => CmpOp::IsNot, + CmpOp::IsNot => CmpOp::Is, + CmpOp::In => CmpOp::NotIn, + CmpOp::NotIn => CmpOp::In, + }; + let node = ast::ExprCompare { + left: left.clone(), + ops: vec![op], + comparators: vec![comparator.clone()], + range: TextRange::default(), + }; + node.into() } else { let node = ast::ExprUnaryOp { op: UnaryOp::Not, - operand: Box::new(loop_info.test.clone()), + operand: Box::new(loop_.test.clone()), range: TextRange::default(), }; node.into() } - }; - let contents = return_stmt( - "all", - &test, - loop_info.target, - loop_info.iter, - checker.generator(), - ); - - // Don't flag if the resulting expression would exceed the maximum line length. 
- let line_start = checker.locator.line_start(stmt.start()); - if LineWidth::new(checker.settings.tab_size) - .add_str(&checker.locator.contents()[TextRange::new(line_start, stmt.start())]) - .add_str(&contents) - > checker.settings.line_length - { - return; + } else { + let node = ast::ExprUnaryOp { + op: UnaryOp::Not, + operand: Box::new(loop_.test.clone()), + range: TextRange::default(), + }; + node.into() } + }; + let contents = return_stmt("all", &test, loop_.target, loop_.iter, checker.generator()); - let mut diagnostic = Diagnostic::new( - ReimplementedBuiltin { - repl: contents.clone(), - }, - TextRange::new(stmt.start(), loop_info.terminal), - ); - if checker.patch(diagnostic.kind.rule()) && checker.semantic().is_builtin("all") { - diagnostic.set_fix(Fix::suggested(Edit::replacement( - contents, - stmt.start(), - loop_info.terminal, - ))); - } - checker.diagnostics.push(diagnostic); + // Don't flag if the resulting expression would exceed the maximum line length. + let line_start = checker.locator.line_start(stmt.start()); + if LineWidth::new(checker.settings.tab_size) + .add_str(&checker.locator.contents()[TextRange::new(line_start, stmt.start())]) + .add_str(&contents) + > checker.settings.line_length + { + return; } + + let mut diagnostic = Diagnostic::new( + ReimplementedBuiltin { + replacement: contents.to_string(), + }, + TextRange::new(stmt.start(), terminal.stmt.end()), + ); + if checker.patch(diagnostic.kind.rule()) && checker.semantic().is_builtin("all") { + diagnostic.set_fix(Fix::suggested(Edit::replacement( + contents, + stmt.start(), + terminal.stmt.end(), + ))); + } + checker.diagnostics.push(diagnostic); } + _ => {} } } +/// Represents a `for` loop with a conditional `return`, like: +/// ```python +/// for x in y: +/// if x == 0: +/// return True +/// ``` +#[derive(Debug)] struct Loop<'a> { + /// The `return` value of the loop. return_value: bool, - next_return_value: bool, + /// The test condition in the loop. 
test: &'a Expr, + /// The target of the loop. target: &'a Expr, + /// The iterator of the loop. iter: &'a Expr, - terminal: TextSize, } -/// Extract the returned boolean values a `Stmt::For` with an `else` body. -fn return_values_for_else(stmt: &Stmt) -> Option { +/// Represents a `return` statement following a `for` loop, like: +/// ```python +/// for x in y: +/// if x == 0: +/// return True +/// return False +/// ``` +/// +/// Or: +/// ```python +/// for x in y: +/// if x == 0: +/// return True +/// else: +/// return False +/// ``` +#[derive(Debug)] +struct Terminal<'a> { + return_value: bool, + stmt: &'a Stmt, +} + +fn match_loop(stmt: &Stmt) -> Option { let Stmt::For(ast::StmtFor { - body, - target, - iter, - orelse, - .. + body, target, iter, .. }) = stmt else { return None; }; - // The loop itself should contain a single `if` statement, with an `else` - // containing a single `return True` or `return False`. - if body.len() != 1 { - return None; - } - if orelse.len() != 1 { - return None; - } - let Stmt::If(ast::StmtIf { + // The loop itself should contain a single `if` statement, with a single `return` statement in + // the body. + let [Stmt::If(ast::StmtIf { body: nested_body, test: nested_test, elif_else_clauses: nested_elif_else_clauses, range: _, - }) = &body[0] + })] = body.as_slice() else { return None; }; @@ -263,15 +286,35 @@ fn return_values_for_else(stmt: &Stmt) -> Option { return None; }; - // The `else` block has to contain a single `return True` or `return False`. - let Stmt::Return(ast::StmtReturn { - value: next_value, - range: _, - }) = &orelse[0] - else { + Some(Loop { + return_value: *value, + test: nested_test, + target, + iter, + }) +} + +/// If a `Stmt::For` contains an `else` with a single boolean `return`, return the [`Terminal`] +/// representing that `return`. 
+/// +/// For example, matches the `return` in: +/// ```python +/// for x in y: +/// if x == 0: +/// return True +/// else: +/// return False +/// ``` +fn match_else_return(stmt: &Stmt) -> Option { + let Stmt::For(ast::StmtFor { orelse, .. }) = stmt else { return None; }; - let Some(next_value) = next_value else { + + // The `else` block has to contain a single `return True` or `return False`. + let [Stmt::Return(ast::StmtReturn { + value: Some(next_value), + range: _, + })] = orelse.as_slice() + else { return None; }; let Expr::Constant(ast::ExprConstant { @@ -282,78 +325,41 @@ fn return_values_for_else(stmt: &Stmt) -> Option { return None; }; - Some(Loop { - return_value: *value, - next_return_value: *next_value, - test: nested_test, - target, - iter, - terminal: stmt.end(), + Some(Terminal { + return_value: *next_value, + stmt, }) } -/// Extract the returned boolean values from subsequent `Stmt::For` and -/// `Stmt::Return` statements, or `None`. -fn return_values_for_siblings<'a>(stmt: &'a Stmt, sibling: &'a Stmt) -> Option> { - let Stmt::For(ast::StmtFor { - body, - target, - iter, - orelse, - .. - }) = stmt - else { +/// If a `Stmt::For` is followed by a boolean `return`, return the [`Terminal`] representing that +/// `return`. +/// +/// For example, matches the `return` in: +/// ```python +/// for x in y: +/// if x == 0: +/// return True +/// return False +/// ``` +fn match_sibling_return<'a>(stmt: &'a Stmt, sibling: &'a Stmt) -> Option> { + let Stmt::For(ast::StmtFor { orelse, .. }) = stmt else { return None; }; - // The loop itself should contain a single `if` statement, with a single `return - // True` or `return False`. - if body.len() != 1 { - return None; - } + // The loop itself shouldn't have an `else` block. 
if !orelse.is_empty() { return None; } - let Stmt::If(ast::StmtIf { - body: nested_body, - test: nested_test, - elif_else_clauses: nested_elif_else_clauses, - range: _, - }) = &body[0] - else { - return None; - }; - if nested_body.len() != 1 { - return None; - } - if !nested_elif_else_clauses.is_empty() { - return None; - } - let Stmt::Return(ast::StmtReturn { value, range: _ }) = &nested_body[0] else { - return None; - }; - let Some(value) = value else { - return None; - }; - let Expr::Constant(ast::ExprConstant { - value: Constant::Bool(value), - .. - }) = value.as_ref() - else { - return None; - }; // The next statement has to be a `return True` or `return False`. let Stmt::Return(ast::StmtReturn { - value: next_value, + value: Some(next_value), range: _, }) = &sibling else { return None; }; - let Some(next_value) = next_value else { - return None; - }; let Expr::Constant(ast::ExprConstant { value: Constant::Bool(next_value), .. @@ -362,13 +368,9 @@ fn return_values_for_siblings<'a>(stmt: &'a Stmt, sibling: &'a Stmt) -> Option(stmt: &'a Stmt, parent: &'a Stmt) -> Option<&'a Suite> { + match parent { + Stmt::FunctionDef(ast::StmtFunctionDef { body, .. }) => Some(body), + Stmt::AsyncFunctionDef(ast::StmtAsyncFunctionDef { body, .. }) => Some(body), + Stmt::ClassDef(ast::StmtClassDef { body, .. }) => Some(body), + Stmt::For(ast::StmtFor { body, orelse, .. }) => { + if body.contains(stmt) { + Some(body) + } else if orelse.contains(stmt) { + Some(orelse) + } else { + None + } + } + Stmt::AsyncFor(ast::StmtAsyncFor { body, orelse, .. }) => { + if body.contains(stmt) { + Some(body) + } else if orelse.contains(stmt) { + Some(orelse) + } else { + None + } + } + Stmt::While(ast::StmtWhile { body, orelse, .. }) => { + if body.contains(stmt) { + Some(body) + } else if orelse.contains(stmt) { + Some(orelse) + } else { + None + } + } + Stmt::If(ast::StmtIf { + body, + elif_else_clauses, + .. 
+ }) => { + if body.contains(stmt) { + Some(body) + } else { + elif_else_clauses + .iter() + .map(|elif_else_clause| &elif_else_clause.body) + .find(|body| body.contains(stmt)) + } + } + Stmt::With(ast::StmtWith { body, .. }) => Some(body), + Stmt::AsyncWith(ast::StmtAsyncWith { body, .. }) => Some(body), + Stmt::Match(ast::StmtMatch { cases, .. }) => cases + .iter() + .map(|case| &case.body) + .find(|body| body.contains(stmt)), + Stmt::Try(ast::StmtTry { + body, + handlers, + orelse, + finalbody, + .. + }) => { + if body.contains(stmt) { + Some(body) + } else if orelse.contains(stmt) { + Some(orelse) + } else if finalbody.contains(stmt) { + Some(finalbody) + } else { + handlers + .iter() + .filter_map(ExceptHandler::as_except_handler) + .map(|handler| &handler.body) + .find(|body| body.contains(stmt)) + } + } + Stmt::TryStar(ast::StmtTryStar { + body, + handlers, + orelse, + finalbody, + .. + }) => { + if body.contains(stmt) { + Some(body) + } else if orelse.contains(stmt) { + Some(orelse) + } else if finalbody.contains(stmt) { + Some(finalbody) + } else { + handlers + .iter() + .filter_map(ExceptHandler::as_except_handler) + .map(|handler| &handler.body) + .find(|body| body.contains(stmt)) + } + } + _ => None, + } +} + +/// Given a [`Stmt`] and its containing [`Suite`], return the next [`Stmt`] in the [`Suite`]. +pub fn next_sibling<'a>(stmt: &'a Stmt, suite: &'a Suite) -> Option<&'a Stmt> { + let mut iter = suite.iter(); + while let Some(sibling) = iter.next() { + if sibling == stmt { + return iter.next(); + } + } + None +} diff --git a/crates/ruff_python_semantic/src/model.rs b/crates/ruff_python_semantic/src/model.rs index 53138dbf5d..0106dfcfc0 100644 --- a/crates/ruff_python_semantic/src/model.rs +++ b/crates/ruff_python_semantic/src/model.rs @@ -108,10 +108,6 @@ pub struct SemanticModel<'a> { /// by way of the `global x` statement. rebinding_scopes: HashMap, BuildNoHashHasher>, - /// Body iteration; used to peek at siblings. 
- pub body: &'a [Stmt], - pub body_index: usize, - /// Flags for the semantic model. pub flags: SemanticModelFlags, @@ -137,8 +133,6 @@ impl<'a> SemanticModel<'a> { shadowed_bindings: IntMap::default(), delayed_annotations: IntMap::default(), rebinding_scopes: IntMap::default(), - body: &[], - body_index: 0, flags: SemanticModelFlags::new(path), handled_exceptions: Vec::default(), } @@ -757,11 +751,6 @@ impl<'a> SemanticModel<'a> { self.exprs.iter().rev().skip(1) } - /// Return the `Stmt` that immediately follows the current `Stmt`, if any. - pub fn sibling_stmt(&self) -> Option<&'a Stmt> { - self.body.get(self.body_index + 1) - } - /// Returns a reference to the global scope pub fn global_scope(&self) -> &Scope<'a> { self.scopes.global()