diff --git a/crates/ruff/resources/test/fixtures/flake8_simplify/SIM114.py b/crates/ruff/resources/test/fixtures/flake8_simplify/SIM114.py index c66dd33ad2..27a9dc0c41 100644 --- a/crates/ruff/resources/test/fixtures/flake8_simplify/SIM114.py +++ b/crates/ruff/resources/test/fixtures/flake8_simplify/SIM114.py @@ -1,4 +1,4 @@ -# These SHOULD change +# Errors if a: b elif c: @@ -52,9 +52,23 @@ if ( and k == 14 ): pass -elif 1 == 2: pass +elif 1 == 2: + pass -# These SHOULD NOT change +failures = errors = skipped = disabled = 0 +if result.eofs == "O": + pass +elif result.eofs == "S": + skipped = 1 +elif result.eofs == "F": + failures = 1 +elif result.eofs == "E": + errors = 1 +else: + errors = 1 + + +# OK def complicated_calc(*arg, **kwargs): return 42 diff --git a/crates/ruff/src/ast/comparable.rs b/crates/ruff/src/ast/comparable.rs index 6d6e6447cf..a6b15bc6ef 100644 --- a/crates/ruff/src/ast/comparable.rs +++ b/crates/ruff/src/ast/comparable.rs @@ -3,8 +3,9 @@ use num_bigint::BigInt; use rustpython_parser::ast::{ - Arg, Arguments, Boolop, Cmpop, Comprehension, Constant, Expr, ExprContext, ExprKind, Keyword, - Operator, Unaryop, + Alias, Arg, Arguments, Boolop, Cmpop, Comprehension, Constant, Excepthandler, + ExcepthandlerKind, Expr, ExprContext, ExprKind, Keyword, Operator, Stmt, StmtKind, Unaryop, + Withitem, }; #[derive(Debug, PartialEq, Eq, Hash)] @@ -126,6 +127,36 @@ impl From<&Cmpop> for ComparableCmpop { } } +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ComparableAlias<'a> { + pub name: &'a str, + pub asname: Option<&'a str>, +} + +impl<'a> From<&'a Alias> for ComparableAlias<'a> { + fn from(alias: &'a Alias) -> Self { + Self { + name: &alias.node.name, + asname: alias.node.asname.as_deref(), + } + } +} + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ComparableWithitem<'a> { + pub context_expr: ComparableExpr<'a>, + pub optional_vars: Option>, +} + +impl<'a> From<&'a Withitem> for ComparableWithitem<'a> { + fn from(withitem: &'a Withitem) -> Self { + Self { + context_expr: (&withitem.context_expr).into(), + optional_vars: withitem.optional_vars.as_ref().map(Into::into), + } + } +} + #[derive(Debug, PartialEq, Eq, Hash)] pub enum ComparableConstant<'a> { None, @@ -147,9 +178,7 @@ impl<'a> From<&'a Constant> for ComparableConstant<'a> { Constant::Str(value) => Self::Str(value), Constant::Bytes(value) => Self::Bytes(value), Constant::Int(value) => Self::Int(value), - Constant::Tuple(value) => { - Self::Tuple(value.iter().map(std::convert::Into::into).collect()) - } + Constant::Tuple(value) => Self::Tuple(value.iter().map(Into::into).collect()), Constant::Float(value) => Self::Float(value.to_bits()), Constant::Complex { real, imag } => Self::Complex { real: real.to_bits(), @@ -174,37 +203,23 @@ pub struct ComparableArguments<'a> { impl<'a> From<&'a Arguments> for ComparableArguments<'a> { fn from(arguments: &'a Arguments) -> Self { Self { - posonlyargs: arguments - .posonlyargs - .iter() - .map(std::convert::Into::into) - .collect(), - args: arguments - .args - .iter() - .map(std::convert::Into::into) - .collect(), - vararg: arguments.vararg.as_ref().map(std::convert::Into::into), - kwonlyargs: arguments - .kwonlyargs - .iter() - .map(std::convert::Into::into) - .collect(), - kw_defaults: arguments - .kw_defaults - .iter() - .map(std::convert::Into::into) - .collect(), - kwarg: arguments.vararg.as_ref().map(std::convert::Into::into), - defaults: arguments - .defaults - .iter() - .map(std::convert::Into::into) - .collect(), + posonlyargs: arguments.posonlyargs.iter().map(Into::into).collect(), + args: arguments.args.iter().map(Into::into).collect(), + vararg: arguments.vararg.as_ref().map(Into::into), + kwonlyargs: arguments.kwonlyargs.iter().map(Into::into).collect(), + kw_defaults: arguments.kw_defaults.iter().map(Into::into).collect(), + kwarg: arguments.vararg.as_ref().map(Into::into), + defaults: arguments.defaults.iter().map(Into::into).collect(), } } } +impl<'a> From<&'a Box> for ComparableArguments<'a> { + fn from(arguments: &'a Box) -> Self { + (&**arguments).into() + } +} + impl<'a> From<&'a Box> for ComparableArg<'a> { fn from(arg: &'a Box) -> Self { (&**arg).into() @@ -222,7 +237,7 @@ impl<'a> From<&'a Arg> for ComparableArg<'a> { fn from(arg: &'a Arg) -> Self { Self { arg: &arg.node.arg, - annotation: arg.node.annotation.as_ref().map(std::convert::Into::into), + annotation: arg.node.annotation.as_ref().map(Into::into), type_comment: arg.node.type_comment.as_deref(), } } @@ -256,16 +271,32 @@ impl<'a> From<&'a Comprehension> for ComparableComprehension<'a> { Self { target: (&comprehension.target).into(), iter: (&comprehension.iter).into(), - ifs: comprehension - .ifs - .iter() - .map(std::convert::Into::into) - .collect(), + ifs: comprehension.ifs.iter().map(Into::into).collect(), is_async: &comprehension.is_async, } } } +#[derive(Debug, PartialEq, Eq, Hash)] +pub enum ComparableExcepthandler<'a> { + ExceptHandler { + type_: Option>, + name: Option<&'a str>, + body: Vec>, + }, +} + +impl<'a> From<&'a Excepthandler> for ComparableExcepthandler<'a> { + fn from(excepthandler: &'a Excepthandler) -> Self { + let ExcepthandlerKind::ExceptHandler { type_, name, body } = &excepthandler.node; + Self::ExceptHandler { + type_: type_.as_ref().map(Into::into), + name: name.as_deref(), + body: body.iter().map(Into::into).collect(), + } + } +} + #[derive(Debug, PartialEq, Eq, Hash)] pub enum ComparableExpr<'a> { BoolOp { @@ -399,7 +430,7 @@ impl<'a> From<&'a Expr> for ComparableExpr<'a> { match &expr.node { ExprKind::BoolOp { op, values } => Self::BoolOp { op: op.into(), - values: values.iter().map(std::convert::Into::into).collect(), + values: values.iter().map(Into::into).collect(), }, ExprKind::NamedExpr { target, value } => Self::NamedExpr { target: target.into(), @@ -426,20 +457,20 @@ impl<'a> From<&'a Expr> for ComparableExpr<'a> { ExprKind::Dict { keys, values } => Self::Dict { keys: keys .iter() - .map(|expr| expr.as_ref().map(std::convert::Into::into)) + .map(|expr| expr.as_ref().map(Into::into)) .collect(), - values: values.iter().map(std::convert::Into::into).collect(), + values: values.iter().map(Into::into).collect(), }, ExprKind::Set { elts } => Self::Set { - elts: elts.iter().map(std::convert::Into::into).collect(), + elts: elts.iter().map(Into::into).collect(), }, ExprKind::ListComp { elt, generators } => Self::ListComp { elt: elt.into(), - generators: generators.iter().map(std::convert::Into::into).collect(), + generators: generators.iter().map(Into::into).collect(), }, ExprKind::SetComp { elt, generators } => Self::SetComp { elt: elt.into(), - generators: generators.iter().map(std::convert::Into::into).collect(), + generators: generators.iter().map(Into::into).collect(), }, ExprKind::DictComp { key, @@ -448,17 +479,17 @@ impl<'a> From<&'a Expr> for ComparableExpr<'a> { } => Self::DictComp { key: key.into(), value: value.into(), - generators: generators.iter().map(std::convert::Into::into).collect(), + generators: generators.iter().map(Into::into).collect(), }, ExprKind::GeneratorExp { elt, generators } => Self::GeneratorExp { elt: elt.into(), - generators: generators.iter().map(std::convert::Into::into).collect(), + generators: generators.iter().map(Into::into).collect(), }, ExprKind::Await { value } => Self::Await { value: value.into(), }, ExprKind::Yield { value } => Self::Yield { - value: value.as_ref().map(std::convert::Into::into), + value: value.as_ref().map(Into::into), }, ExprKind::YieldFrom { value } => Self::YieldFrom { value: value.into(), @@ -469,8 +500,8 @@ impl<'a> From<&'a Expr> for ComparableExpr<'a> { comparators, } => Self::Compare { left: left.into(), - ops: ops.iter().map(std::convert::Into::into).collect(), - comparators: comparators.iter().map(std::convert::Into::into).collect(), + ops: ops.iter().map(Into::into).collect(), + comparators: comparators.iter().map(Into::into).collect(), }, ExprKind::Call { func, @@ -478,8 +509,8 @@ impl<'a> From<&'a Expr> for ComparableExpr<'a> { keywords, } => Self::Call { func: func.into(), - args: args.iter().map(std::convert::Into::into).collect(), - keywords: keywords.iter().map(std::convert::Into::into).collect(), + args: args.iter().map(Into::into).collect(), + keywords: keywords.iter().map(Into::into).collect(), }, ExprKind::FormattedValue { value, @@ -488,10 +519,10 @@ impl<'a> From<&'a Expr> for ComparableExpr<'a> { } => Self::FormattedValue { value: value.into(), conversion, - format_spec: format_spec.as_ref().map(std::convert::Into::into), + format_spec: format_spec.as_ref().map(Into::into), }, ExprKind::JoinedStr { values } => Self::JoinedStr { - values: values.iter().map(std::convert::Into::into).collect(), + values: values.iter().map(Into::into).collect(), }, ExprKind::Constant { value, kind } => Self::Constant { value: value.into(), @@ -516,18 +547,314 @@ impl<'a> From<&'a Expr> for ComparableExpr<'a> { ctx: ctx.into(), }, ExprKind::List { elts, ctx } => Self::List { - elts: elts.iter().map(std::convert::Into::into).collect(), + elts: elts.iter().map(Into::into).collect(), ctx: ctx.into(), }, ExprKind::Tuple { elts, ctx } => Self::Tuple { - elts: elts.iter().map(std::convert::Into::into).collect(), + elts: elts.iter().map(Into::into).collect(), ctx: ctx.into(), }, ExprKind::Slice { lower, upper, step } => Self::Slice { - lower: lower.as_ref().map(std::convert::Into::into), - upper: upper.as_ref().map(std::convert::Into::into), - step: step.as_ref().map(std::convert::Into::into), + lower: lower.as_ref().map(Into::into), + upper: upper.as_ref().map(Into::into), + step: step.as_ref().map(Into::into), }, } } } + +#[derive(Debug, PartialEq, Eq, Hash)] +pub enum ComparableStmt<'a> { + FunctionDef { + name: &'a str, + args: ComparableArguments<'a>, + body: Vec>, + decorator_list: Vec>, + returns: Option>, + type_comment: Option<&'a str>, + }, + AsyncFunctionDef { + name: &'a str, + args: ComparableArguments<'a>, + body: Vec>, + decorator_list: Vec>, + returns: Option>, + type_comment: Option<&'a str>, + }, + ClassDef { + name: &'a str, + bases: Vec>, + keywords: Vec>, + body: Vec>, + decorator_list: Vec>, + }, + Return { + value: Option>, + }, + Delete { + targets: Vec>, + }, + Assign { + targets: Vec>, + value: ComparableExpr<'a>, + type_comment: Option<&'a str>, + }, + AugAssign { + target: ComparableExpr<'a>, + op: ComparableOperator, + value: ComparableExpr<'a>, + }, + AnnAssign { + target: ComparableExpr<'a>, + annotation: ComparableExpr<'a>, + value: Option>, + simple: usize, + }, + For { + target: ComparableExpr<'a>, + iter: ComparableExpr<'a>, + body: Vec>, + orelse: Vec>, + type_comment: Option<&'a str>, + }, + AsyncFor { + target: ComparableExpr<'a>, + iter: ComparableExpr<'a>, + body: Vec>, + orelse: Vec>, + type_comment: Option<&'a str>, + }, + While { + test: ComparableExpr<'a>, + body: Vec>, + orelse: Vec>, + }, + If { + test: ComparableExpr<'a>, + body: Vec>, + orelse: Vec>, + }, + With { + items: Vec>, + body: Vec>, + type_comment: Option<&'a str>, + }, + AsyncWith { + items: Vec>, + body: Vec>, + type_comment: Option<&'a str>, + }, + Raise { + exc: Option>, + cause: Option>, + }, + Try { + body: Vec>, + handlers: Vec>, + orelse: Vec>, + finalbody: Vec>, + }, + Assert { + test: ComparableExpr<'a>, + msg: Option>, + }, + Import { + names: Vec>, + }, + ImportFrom { + module: Option<&'a str>, + names: Vec>, + level: Option, + }, + Global { + names: Vec<&'a str>, + }, + Nonlocal { + names: Vec<&'a str>, + }, + Expr { + value: ComparableExpr<'a>, + }, + Pass, + Break, + Continue, +} + +impl<'a> From<&'a Stmt> for ComparableStmt<'a> { + fn from(stmt: &'a Stmt) -> Self { + match &stmt.node { + StmtKind::FunctionDef { + name, + args, + body, + decorator_list, + returns, + type_comment, + } => Self::FunctionDef { + name, + args: args.into(), + body: body.iter().map(Into::into).collect(), + decorator_list: decorator_list.iter().map(Into::into).collect(), + returns: returns.as_ref().map(Into::into), + type_comment: type_comment.as_ref().map(std::string::String::as_str), + }, + StmtKind::AsyncFunctionDef { + name, + args, + body, + decorator_list, + returns, + type_comment, + } => Self::AsyncFunctionDef { + name, + args: args.into(), + body: body.iter().map(Into::into).collect(), + decorator_list: decorator_list.iter().map(Into::into).collect(), + returns: returns.as_ref().map(Into::into), + type_comment: type_comment.as_ref().map(std::string::String::as_str), + }, + StmtKind::ClassDef { + name, + bases, + keywords, + body, + decorator_list, + } => Self::ClassDef { + name, + bases: bases.iter().map(Into::into).collect(), + keywords: keywords.iter().map(Into::into).collect(), + body: body.iter().map(Into::into).collect(), + decorator_list: decorator_list.iter().map(Into::into).collect(), + }, + StmtKind::Return { value } => Self::Return { + value: value.as_ref().map(Into::into), + }, + StmtKind::Delete { targets } => Self::Delete { + targets: targets.iter().map(Into::into).collect(), + }, + StmtKind::Assign { + targets, + value, + type_comment, + } => Self::Assign { + targets: targets.iter().map(Into::into).collect(), + value: value.into(), + type_comment: type_comment.as_ref().map(std::string::String::as_str), + }, + StmtKind::AugAssign { target, op, value } => Self::AugAssign { + target: target.into(), + op: op.into(), + value: value.into(), + }, + StmtKind::AnnAssign { + target, + annotation, + value, + simple, + } => Self::AnnAssign { + target: target.into(), + annotation: annotation.into(), + value: value.as_ref().map(Into::into), + simple: *simple, + }, + StmtKind::For { + target, + iter, + body, + orelse, + type_comment, + } => Self::For { + target: target.into(), + iter: iter.into(), + body: body.iter().map(Into::into).collect(), + orelse: orelse.iter().map(Into::into).collect(), + type_comment: type_comment.as_ref().map(String::as_str), + }, + StmtKind::AsyncFor { + target, + iter, + body, + orelse, + type_comment, + } => Self::AsyncFor { + target: target.into(), + iter: iter.into(), + body: body.iter().map(Into::into).collect(), + orelse: orelse.iter().map(Into::into).collect(), + type_comment: type_comment.as_ref().map(String::as_str), + }, + StmtKind::While { test, body, orelse } => Self::While { + test: test.into(), + body: body.iter().map(Into::into).collect(), + orelse: orelse.iter().map(Into::into).collect(), + }, + StmtKind::If { test, body, orelse } => Self::If { + test: test.into(), + body: body.iter().map(Into::into).collect(), + orelse: orelse.iter().map(Into::into).collect(), + }, + StmtKind::With { + items, + body, + type_comment, + } => Self::With { + items: items.iter().map(Into::into).collect(), + body: body.iter().map(Into::into).collect(), + type_comment: type_comment.as_ref().map(String::as_str), + }, + StmtKind::AsyncWith { + items, + body, + type_comment, + } => Self::AsyncWith { + items: items.iter().map(Into::into).collect(), + body: body.iter().map(Into::into).collect(), + type_comment: type_comment.as_ref().map(String::as_str), + }, + StmtKind::Match { .. } => unreachable!("StmtKind::Match is not supported"), + StmtKind::Raise { exc, cause } => Self::Raise { + exc: exc.as_ref().map(Into::into), + cause: cause.as_ref().map(Into::into), + }, + StmtKind::Try { + body, + handlers, + orelse, + finalbody, + } => Self::Try { + body: body.iter().map(Into::into).collect(), + handlers: handlers.iter().map(Into::into).collect(), + orelse: orelse.iter().map(Into::into).collect(), + finalbody: finalbody.iter().map(Into::into).collect(), + }, + StmtKind::Assert { test, msg } => Self::Assert { + test: test.into(), + msg: msg.as_ref().map(Into::into), + }, + StmtKind::Import { names } => Self::Import { + names: names.iter().map(Into::into).collect(), + }, + StmtKind::ImportFrom { + module, + names, + level, + } => Self::ImportFrom { + module: module.as_ref().map(String::as_str), + names: names.iter().map(Into::into).collect(), + level: *level, + }, + StmtKind::Global { names } => Self::Global { + names: names.iter().map(String::as_str).collect(), + }, + StmtKind::Nonlocal { names } => Self::Nonlocal { + names: names.iter().map(String::as_str).collect(), + }, + StmtKind::Expr { value } => Self::Expr { + value: value.into(), + }, + StmtKind::Pass => Self::Pass, + StmtKind::Break => Self::Break, + StmtKind::Continue => Self::Continue, + } + } +} diff --git a/crates/ruff/src/rules/flake8_simplify/rules/ast_if.rs b/crates/ruff/src/rules/flake8_simplify/rules/ast_if.rs index 820530e93b..31e277b8ad 100644 --- a/crates/ruff/src/rules/flake8_simplify/rules/ast_if.rs +++ b/crates/ruff/src/rules/flake8_simplify/rules/ast_if.rs @@ -1,11 +1,9 @@ use log::error; use rustpython_parser::ast::{Cmpop, Constant, Expr, ExprContext, ExprKind, Stmt, StmtKind}; -use rustpython_parser::lexer; -use rustpython_parser::lexer::Tok; use ruff_macros::{define_violation, derive_message_formats}; -use crate::ast::comparable::ComparableExpr; +use crate::ast::comparable::{ComparableExpr, ComparableStmt}; use crate::ast::helpers::{ contains_call_path, contains_effect, create_expr, create_stmt, first_colon_range, has_comments, has_comments_in, unparse_expr, unparse_stmt, @@ -15,9 +13,26 @@ use crate::checkers::ast::Checker; use crate::fix::Fix; use crate::registry::Diagnostic; use crate::rules::flake8_simplify::rules::fix_if; -use crate::source_code::Locator; use crate::violation::{AutofixKind, Availability, Violation}; +fn compare_expr(expr1: &ComparableExpr, expr2: &ComparableExpr) -> bool { + expr1.eq(expr2) +} + +fn compare_stmt(stmt1: &ComparableStmt, stmt2: &ComparableStmt) -> bool { + stmt1.eq(stmt2) +} + +fn compare_body(body1: &[Stmt], body2: &[Stmt]) -> bool { + if body1.len() != body2.len() { + return false; + } + body1 + .iter() + .zip(body2.iter()) + .all(|(stmt1, stmt2)| compare_stmt(&stmt1.into(), &stmt2.into())) +} + define_violation!( pub struct CollapsibleIf { pub fixable: bool, @@ -482,10 +497,6 @@ pub fn use_ternary_operator(checker: &mut Checker, stmt: &Stmt, parent: Option<& checker.diagnostics.push(diagnostic); } -fn compare_expr(expr1: &ComparableExpr, expr2: &ComparableExpr) -> bool { - expr1.eq(expr2) -} - fn get_if_body_pairs(orelse: &[Stmt], result: &mut Vec>) { if orelse.is_empty() { return; @@ -515,28 +526,6 @@ fn get_if_body_pairs(orelse: &[Stmt], result: &mut Vec>) { } } -pub fn is_equal(locator: &Locator, stmts1: &[Stmt], stmts2: &[Stmt]) -> bool { - if stmts1.len() != stmts2.len() { - return false; - } - for (stmt1, stmt2) in stmts1.iter().zip(stmts2.iter()) { - let text1 = locator.slice_source_code_range(&Range::from_located(stmt1)); - let text2 = locator.slice_source_code_range(&Range::from_located(stmt2)); - let lexer1: Vec = lexer::make_tokenizer(text1) - .flatten() - .map(|(_, tok, _)| tok) - .collect(); - let lexer2: Vec = lexer::make_tokenizer(text2) - .flatten() - .map(|(_, tok, _)| tok) - .collect(); - if lexer1 != lexer2 { - return false; - } - } - true -} - /// SIM114 pub fn if_with_same_arms(checker: &mut Checker, body: &[Stmt], orelse: &[Stmt]) { if orelse.is_empty() { @@ -553,7 +542,7 @@ pub fn if_with_same_arms(checker: &mut Checker, body: &[Stmt], orelse: &[Stmt]) } for i in 0..(if_statements - 1) { - if is_equal(checker.locator, &final_stmts[i], &final_stmts[i + 1]) { + if compare_body(&final_stmts[i], &final_stmts[i + 1]) { let first = &final_stmts[i].first().unwrap(); let last = &final_stmts[i].last().unwrap(); checker.diagnostics.push(Diagnostic::new( diff --git a/crates/ruff/src/rules/flake8_simplify/snapshots/ruff__rules__flake8_simplify__tests__SIM114_SIM114.py.snap b/crates/ruff/src/rules/flake8_simplify/snapshots/ruff__rules__flake8_simplify__tests__SIM114_SIM114.py.snap index fc9964771d..695f966d26 100644 --- a/crates/ruff/src/rules/flake8_simplify/snapshots/ruff__rules__flake8_simplify__tests__SIM114_SIM114.py.snap +++ b/crates/ruff/src/rules/flake8_simplify/snapshots/ruff__rules__flake8_simplify__tests__SIM114_SIM114.py.snap @@ -72,4 +72,44 @@ expression: diagnostics column: 8 fix: ~ parent: ~ +- kind: + IfWithSameArms: ~ + location: + row: 66 + column: 4 + end_location: + row: 66 + column: 14 + fix: ~ + parent: ~ +- kind: + IfWithSameArms: ~ + location: + row: 66 + column: 4 + end_location: + row: 66 + column: 14 + fix: ~ + parent: ~ +- kind: + IfWithSameArms: ~ + location: + row: 66 + column: 4 + end_location: + row: 66 + column: 14 + fix: ~ + parent: ~ +- kind: + IfWithSameArms: ~ + location: + row: 66 + column: 4 + end_location: + row: 66 + column: 14 + fix: ~ + parent: ~