mirror of https://github.com/astral-sh/ruff
flake8_simplify : SIM401 (#1778)
Ref #998 - Implements SIM401 with fix - Added tests Notes: - only recognize simple ExprKind::Name variables in expr patterns for now - bug-fix from reference implementation: check 3-conditions (dict-key, target-variable, dict-name) to be equal, `flake8_simplify` only test first two (only first in second pattern)
This commit is contained in:
parent
de81b0cd38
commit
4523885268
|
|
@ -999,6 +999,7 @@ For more, see [flake8-simplify](https://pypi.org/project/flake8-simplify/0.19.3/
|
|||
| SIM222 | OrTrue | Use `True` instead of `... or True` | 🛠 |
|
||||
| SIM223 | AndFalse | Use `False` instead of `... and False` | 🛠 |
|
||||
| SIM300 | YodaConditions | Yoda conditions are discouraged, use `left == right` instead | 🛠 |
|
||||
| SIM401 | DictGetWithDefault | Use `var = dict.get(key, "default")` instead of an `if` block | 🛠 |
|
||||
|
||||
### flake8-tidy-imports (TID)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,81 @@
|
|||
###
|
||||
# Positive cases
|
||||
###
|
||||
|
||||
# SIM401 (pattern-1)
|
||||
if key in a_dict:
|
||||
var = a_dict[key]
|
||||
else:
|
||||
var = "default1"
|
||||
|
||||
# SIM401 (pattern-2)
|
||||
if key not in a_dict:
|
||||
var = "default2"
|
||||
else:
|
||||
var = a_dict[key]
|
||||
|
||||
# SIM401 (default with a complex expression)
|
||||
if key in a_dict:
|
||||
var = a_dict[key]
|
||||
else:
|
||||
var = val1 + val2
|
||||
|
||||
# SIM401 (complex expression in key)
|
||||
if keys[idx] in a_dict:
|
||||
var = a_dict[keys[idx]]
|
||||
else:
|
||||
var = "default"
|
||||
|
||||
# SIM401 (complex expression in dict)
|
||||
if key in dicts[idx]:
|
||||
var = dicts[idx][key]
|
||||
else:
|
||||
var = "default"
|
||||
|
||||
# SIM401 (complex expression in var)
|
||||
if key in a_dict:
|
||||
vars[idx] = a_dict[key]
|
||||
else:
|
||||
vars[idx] = "default"
|
||||
|
||||
###
|
||||
# Negative cases
|
||||
###
|
||||
|
||||
# OK (false negative)
|
||||
if not key in a_dict:
|
||||
var = "default"
|
||||
else:
|
||||
var = a_dict[key]
|
||||
|
||||
# OK (different dict)
|
||||
if key in a_dict:
|
||||
var = other_dict[key]
|
||||
else:
|
||||
var = "default"
|
||||
|
||||
# OK (different key)
|
||||
if key in a_dict:
|
||||
var = a_dict[other_key]
|
||||
else:
|
||||
var = "default"
|
||||
|
||||
# OK (different var)
|
||||
if key in a_dict:
|
||||
var = a_dict[key]
|
||||
else:
|
||||
other_var = "default"
|
||||
|
||||
# OK (extra vars in body)
|
||||
if key in a_dict:
|
||||
var = a_dict[key]
|
||||
var2 = value2
|
||||
else:
|
||||
var = "default"
|
||||
|
||||
# OK (extra vars in orelse)
|
||||
if key in a_dict:
|
||||
var = a_dict[key]
|
||||
else:
|
||||
var2 = value2
|
||||
var = "default"
|
||||
|
|
@ -1540,6 +1540,9 @@
|
|||
"SIM3",
|
||||
"SIM30",
|
||||
"SIM300",
|
||||
"SIM4",
|
||||
"SIM40",
|
||||
"SIM401",
|
||||
"T",
|
||||
"T1",
|
||||
"T10",
|
||||
|
|
|
|||
|
|
@ -388,6 +388,12 @@ impl<'a> From<&'a Box<Expr>> for Box<ComparableExpr<'a>> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a Box<Expr>> for ComparableExpr<'a> {
|
||||
fn from(expr: &'a Box<Expr>) -> Self {
|
||||
(&**expr).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a Expr> for ComparableExpr<'a> {
|
||||
fn from(expr: &'a Expr) -> Self {
|
||||
match &expr.node {
|
||||
|
|
|
|||
|
|
@ -1214,7 +1214,7 @@ where
|
|||
StmtKind::AugAssign { target, .. } => {
|
||||
self.handle_node_load(target);
|
||||
}
|
||||
StmtKind::If { test, .. } => {
|
||||
StmtKind::If { test, body, orelse } => {
|
||||
if self.settings.enabled.contains(&RuleCode::F634) {
|
||||
pyflakes::rules::if_tuple(self, stmt, test);
|
||||
}
|
||||
|
|
@ -1231,6 +1231,11 @@ where
|
|||
self.current_stmt_parent().map(|parent| parent.0),
|
||||
);
|
||||
}
|
||||
if self.settings.enabled.contains(&RuleCode::SIM401) {
|
||||
flake8_simplify::rules::use_dict_get_with_default(
|
||||
self, stmt, test, body, orelse,
|
||||
);
|
||||
}
|
||||
}
|
||||
StmtKind::Assert { test, msg } => {
|
||||
if self.settings.enabled.contains(&RuleCode::F631) {
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ mod tests {
|
|||
#[test_case(RuleCode::SIM222, Path::new("SIM222.py"); "SIM222")]
|
||||
#[test_case(RuleCode::SIM223, Path::new("SIM223.py"); "SIM223")]
|
||||
#[test_case(RuleCode::SIM300, Path::new("SIM300.py"); "SIM300")]
|
||||
#[test_case(RuleCode::SIM401, Path::new("SIM401.py"); "SIM401")]
|
||||
fn rules(rule_code: RuleCode, path: &Path) -> Result<()> {
|
||||
let snapshot = format!("{}_{}", rule_code.as_ref(), path.to_string_lossy());
|
||||
let diagnostics = test_path(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
use rustpython_ast::{Constant, Expr, ExprKind, Stmt, StmtKind};
|
||||
use rustpython_ast::{Cmpop, Constant, Expr, ExprContext, ExprKind, Stmt, StmtKind};
|
||||
|
||||
use crate::ast::comparable::ComparableExpr;
|
||||
use crate::ast::helpers::{
|
||||
contains_call_path, create_expr, create_stmt, has_comments, unparse_expr, unparse_stmt,
|
||||
};
|
||||
|
|
@ -228,3 +229,103 @@ pub fn use_ternary_operator(checker: &mut Checker, stmt: &Stmt, parent: Option<&
|
|||
}
|
||||
checker.diagnostics.push(diagnostic);
|
||||
}
|
||||
|
||||
fn compare_expr(expr1: &ComparableExpr, expr2: &ComparableExpr) -> bool {
|
||||
expr1.eq(&expr2)
|
||||
}
|
||||
|
||||
/// SIM401
|
||||
pub fn use_dict_get_with_default(
|
||||
checker: &mut Checker,
|
||||
stmt: &Stmt,
|
||||
test: &Expr,
|
||||
body: &Vec<Stmt>,
|
||||
orelse: &Vec<Stmt>,
|
||||
) {
|
||||
if body.len() != 1 || orelse.len() != 1 {
|
||||
return;
|
||||
}
|
||||
let StmtKind::Assign { targets: body_var, value: body_val, ..} = &body[0].node else {
|
||||
return;
|
||||
};
|
||||
if body_var.len() != 1 {
|
||||
return;
|
||||
};
|
||||
let StmtKind::Assign { targets: orelse_var, value: orelse_val, .. } = &orelse[0].node else {
|
||||
return;
|
||||
};
|
||||
if orelse_var.len() != 1 {
|
||||
return;
|
||||
};
|
||||
let ExprKind::Compare { left: test_key, ops , comparators: test_dict } = &test.node else {
|
||||
return;
|
||||
};
|
||||
if test_dict.len() != 1 {
|
||||
return;
|
||||
}
|
||||
let (expected_var, expected_val, default_var, default_val) = match ops[..] {
|
||||
[Cmpop::In] => (&body_var[0], body_val, &orelse_var[0], orelse_val),
|
||||
[Cmpop::NotIn] => (&orelse_var[0], orelse_val, &body_var[0], body_val),
|
||||
_ => {
|
||||
return;
|
||||
}
|
||||
};
|
||||
let test_dict = &test_dict[0];
|
||||
let ExprKind::Subscript { value: expected_subscript, slice: expected_slice, .. } = &expected_val.node else {
|
||||
return;
|
||||
};
|
||||
|
||||
// Check that the dictionary key, target variables, and dictionary name are all
|
||||
// equivalent.
|
||||
if !compare_expr(&expected_slice.into(), &test_key.into())
|
||||
|| !compare_expr(&expected_var.into(), &default_var.into())
|
||||
|| !compare_expr(&test_dict.into(), &expected_subscript.into())
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
let contents = unparse_stmt(
|
||||
&create_stmt(StmtKind::Assign {
|
||||
targets: vec![create_expr(expected_var.node.clone())],
|
||||
value: Box::new(create_expr(ExprKind::Call {
|
||||
func: Box::new(create_expr(ExprKind::Attribute {
|
||||
value: expected_subscript.clone(),
|
||||
attr: "get".to_string(),
|
||||
ctx: ExprContext::Load,
|
||||
})),
|
||||
args: vec![
|
||||
create_expr(test_key.node.clone()),
|
||||
create_expr(default_val.node.clone()),
|
||||
],
|
||||
keywords: vec![],
|
||||
})),
|
||||
type_comment: None,
|
||||
}),
|
||||
checker.style,
|
||||
);
|
||||
|
||||
// Don't flag for simplified `dict.get` if the resulting expression would exceed
|
||||
// the maximum line length.
|
||||
if stmt.location.column() + contents.len() > checker.settings.line_length {
|
||||
return;
|
||||
}
|
||||
|
||||
// Don't flag for simplified `dict.get` if the if-expression contains any
|
||||
// comments.
|
||||
if has_comments(stmt, checker.locator) {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut diagnostic = Diagnostic::new(
|
||||
violations::DictGetWithDefault(contents.clone()),
|
||||
Range::from_located(stmt),
|
||||
);
|
||||
if checker.patch(&RuleCode::SIM401) {
|
||||
diagnostic.amend(Fix::replacement(
|
||||
contents,
|
||||
stmt.location,
|
||||
stmt.end_location.unwrap(),
|
||||
));
|
||||
}
|
||||
checker.diagnostics.push(diagnostic);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,7 +3,10 @@ pub use ast_bool_op::{
|
|||
};
|
||||
pub use ast_expr::use_capital_environment_variables;
|
||||
pub use ast_for::convert_loop_to_any_all;
|
||||
pub use ast_if::{nested_if_statements, return_bool_condition_directly, use_ternary_operator};
|
||||
pub use ast_if::{
|
||||
nested_if_statements, return_bool_condition_directly, use_dict_get_with_default,
|
||||
use_ternary_operator,
|
||||
};
|
||||
pub use ast_ifexp::{
|
||||
explicit_false_true_in_ifexpr, explicit_true_false_in_ifexpr, twisted_arms_in_ifexpr,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -0,0 +1,107 @@
|
|||
---
|
||||
source: src/flake8_simplify/mod.rs
|
||||
expression: diagnostics
|
||||
---
|
||||
- kind:
|
||||
DictGetWithDefault: "var = a_dict.get(key, \"default1\")"
|
||||
location:
|
||||
row: 6
|
||||
column: 0
|
||||
end_location:
|
||||
row: 9
|
||||
column: 20
|
||||
fix:
|
||||
content: "var = a_dict.get(key, \"default1\")"
|
||||
location:
|
||||
row: 6
|
||||
column: 0
|
||||
end_location:
|
||||
row: 9
|
||||
column: 20
|
||||
parent: ~
|
||||
- kind:
|
||||
DictGetWithDefault: "var = a_dict.get(key, \"default2\")"
|
||||
location:
|
||||
row: 12
|
||||
column: 0
|
||||
end_location:
|
||||
row: 15
|
||||
column: 21
|
||||
fix:
|
||||
content: "var = a_dict.get(key, \"default2\")"
|
||||
location:
|
||||
row: 12
|
||||
column: 0
|
||||
end_location:
|
||||
row: 15
|
||||
column: 21
|
||||
parent: ~
|
||||
- kind:
|
||||
DictGetWithDefault: "var = a_dict.get(key, val1 + val2)"
|
||||
location:
|
||||
row: 18
|
||||
column: 0
|
||||
end_location:
|
||||
row: 21
|
||||
column: 21
|
||||
fix:
|
||||
content: "var = a_dict.get(key, val1 + val2)"
|
||||
location:
|
||||
row: 18
|
||||
column: 0
|
||||
end_location:
|
||||
row: 21
|
||||
column: 21
|
||||
parent: ~
|
||||
- kind:
|
||||
DictGetWithDefault: "var = a_dict.get(keys[idx], \"default\")"
|
||||
location:
|
||||
row: 24
|
||||
column: 0
|
||||
end_location:
|
||||
row: 27
|
||||
column: 19
|
||||
fix:
|
||||
content: "var = a_dict.get(keys[idx], \"default\")"
|
||||
location:
|
||||
row: 24
|
||||
column: 0
|
||||
end_location:
|
||||
row: 27
|
||||
column: 19
|
||||
parent: ~
|
||||
- kind:
|
||||
DictGetWithDefault: "var = dicts[idx].get(key, \"default\")"
|
||||
location:
|
||||
row: 30
|
||||
column: 0
|
||||
end_location:
|
||||
row: 33
|
||||
column: 19
|
||||
fix:
|
||||
content: "var = dicts[idx].get(key, \"default\")"
|
||||
location:
|
||||
row: 30
|
||||
column: 0
|
||||
end_location:
|
||||
row: 33
|
||||
column: 19
|
||||
parent: ~
|
||||
- kind:
|
||||
DictGetWithDefault: "vars[idx] = a_dict.get(key, \"default\")"
|
||||
location:
|
||||
row: 36
|
||||
column: 0
|
||||
end_location:
|
||||
row: 39
|
||||
column: 25
|
||||
fix:
|
||||
content: "vars[idx] = a_dict.get(key, \"default\")"
|
||||
location:
|
||||
row: 36
|
||||
column: 0
|
||||
end_location:
|
||||
row: 39
|
||||
column: 25
|
||||
parent: ~
|
||||
|
||||
|
|
@ -318,6 +318,7 @@ define_rule_mapping!(
|
|||
SIM222 => violations::OrTrue,
|
||||
SIM223 => violations::AndFalse,
|
||||
SIM300 => violations::YodaConditions,
|
||||
SIM401 => violations::DictGetWithDefault,
|
||||
// pyupgrade
|
||||
UP001 => violations::UselessMetaclassType,
|
||||
UP003 => violations::TypeOfPrimitive,
|
||||
|
|
|
|||
|
|
@ -2793,13 +2793,13 @@ define_violation!(
|
|||
);
|
||||
impl AlwaysAutofixableViolation for UseTernaryOperator {
|
||||
fn message(&self) -> String {
|
||||
let UseTernaryOperator(new_code) = self;
|
||||
format!("Use ternary operator `{new_code}` instead of if-else-block")
|
||||
let UseTernaryOperator(contents) = self;
|
||||
format!("Use ternary operator `{contents}` instead of if-else-block")
|
||||
}
|
||||
|
||||
fn autofix_title(&self) -> String {
|
||||
let UseTernaryOperator(new_code) = self;
|
||||
format!("Replace if-else-block with `{new_code}`")
|
||||
let UseTernaryOperator(contents) = self;
|
||||
format!("Replace if-else-block with `{contents}`")
|
||||
}
|
||||
|
||||
fn placeholder() -> Self {
|
||||
|
|
@ -3107,6 +3107,24 @@ impl AlwaysAutofixableViolation for IfExprWithTwistedArms {
|
|||
}
|
||||
}
|
||||
|
||||
define_violation!(
|
||||
pub struct DictGetWithDefault(pub String);
|
||||
);
|
||||
impl AlwaysAutofixableViolation for DictGetWithDefault {
|
||||
fn message(&self) -> String {
|
||||
let DictGetWithDefault(contents) = self;
|
||||
format!("Use `{contents}` instead of an `if` block")
|
||||
}
|
||||
|
||||
fn autofix_title(&self) -> String {
|
||||
let DictGetWithDefault(contents) = self;
|
||||
format!("Replace with `{contents}`")
|
||||
}
|
||||
|
||||
fn placeholder() -> Self {
|
||||
DictGetWithDefault("var = dict.get(key, \"default\")".to_string())
|
||||
}
|
||||
}
|
||||
// pyupgrade
|
||||
|
||||
define_violation!(
|
||||
|
|
|
|||
Loading…
Reference in New Issue