flake8_simplify : SIM401 (#1778)

Ref #998 

- Implements SIM401 with fix
- Added tests

Notes: 
- only recognize simple ExprKind::Name variables in expr patterns for
now
- bug-fix from reference implementation: check 3-conditions (dict-key,
target-variable, dict-name) to be equal, `flake8_simplify` only test
first two (only first in second pattern)
This commit is contained in:
Chammika Mannakkara 2023-01-12 09:51:37 +09:00 committed by GitHub
parent de81b0cd38
commit 4523885268
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 334 additions and 7 deletions

View File

@ -999,6 +999,7 @@ For more, see [flake8-simplify](https://pypi.org/project/flake8-simplify/0.19.3/
| SIM222 | OrTrue | Use `True` instead of `... or True` | 🛠 |
| SIM223 | AndFalse | Use `False` instead of `... and False` | 🛠 |
| SIM300 | YodaConditions | Yoda conditions are discouraged, use `left == right` instead | 🛠 |
| SIM401 | DictGetWithDefault | Use `var = dict.get(key, "default")` instead of an `if` block | 🛠 |
### flake8-tidy-imports (TID)

View File

@ -0,0 +1,81 @@
###
# Positive cases
###
# SIM401 (pattern-1)
if key in a_dict:
var = a_dict[key]
else:
var = "default1"
# SIM401 (pattern-2)
if key not in a_dict:
var = "default2"
else:
var = a_dict[key]
# SIM401 (default with a complex expression)
if key in a_dict:
var = a_dict[key]
else:
var = val1 + val2
# SIM401 (complex expression in key)
if keys[idx] in a_dict:
var = a_dict[keys[idx]]
else:
var = "default"
# SIM401 (complex expression in dict)
if key in dicts[idx]:
var = dicts[idx][key]
else:
var = "default"
# SIM401 (complex expression in var)
if key in a_dict:
vars[idx] = a_dict[key]
else:
vars[idx] = "default"
###
# Negative cases
###
# OK (false negative)
if not key in a_dict:
var = "default"
else:
var = a_dict[key]
# OK (different dict)
if key in a_dict:
var = other_dict[key]
else:
var = "default"
# OK (different key)
if key in a_dict:
var = a_dict[other_key]
else:
var = "default"
# OK (different var)
if key in a_dict:
var = a_dict[key]
else:
other_var = "default"
# OK (extra vars in body)
if key in a_dict:
var = a_dict[key]
var2 = value2
else:
var = "default"
# OK (extra vars in orelse)
if key in a_dict:
var = a_dict[key]
else:
var2 = value2
var = "default"

View File

@ -1540,6 +1540,9 @@
"SIM3",
"SIM30",
"SIM300",
"SIM4",
"SIM40",
"SIM401",
"T",
"T1",
"T10",

View File

@ -388,6 +388,12 @@ impl<'a> From<&'a Box<Expr>> for Box<ComparableExpr<'a>> {
}
}
impl<'a> From<&'a Box<Expr>> for ComparableExpr<'a> {
fn from(expr: &'a Box<Expr>) -> Self {
(&**expr).into()
}
}
impl<'a> From<&'a Expr> for ComparableExpr<'a> {
fn from(expr: &'a Expr) -> Self {
match &expr.node {

View File

@ -1214,7 +1214,7 @@ where
StmtKind::AugAssign { target, .. } => {
self.handle_node_load(target);
}
StmtKind::If { test, .. } => {
StmtKind::If { test, body, orelse } => {
if self.settings.enabled.contains(&RuleCode::F634) {
pyflakes::rules::if_tuple(self, stmt, test);
}
@ -1231,6 +1231,11 @@ where
self.current_stmt_parent().map(|parent| parent.0),
);
}
if self.settings.enabled.contains(&RuleCode::SIM401) {
flake8_simplify::rules::use_dict_get_with_default(
self, stmt, test, body, orelse,
);
}
}
StmtKind::Assert { test, msg } => {
if self.settings.enabled.contains(&RuleCode::F631) {

View File

@ -36,6 +36,7 @@ mod tests {
#[test_case(RuleCode::SIM222, Path::new("SIM222.py"); "SIM222")]
#[test_case(RuleCode::SIM223, Path::new("SIM223.py"); "SIM223")]
#[test_case(RuleCode::SIM300, Path::new("SIM300.py"); "SIM300")]
#[test_case(RuleCode::SIM401, Path::new("SIM401.py"); "SIM401")]
fn rules(rule_code: RuleCode, path: &Path) -> Result<()> {
let snapshot = format!("{}_{}", rule_code.as_ref(), path.to_string_lossy());
let diagnostics = test_path(

View File

@ -1,5 +1,6 @@
use rustpython_ast::{Constant, Expr, ExprKind, Stmt, StmtKind};
use rustpython_ast::{Cmpop, Constant, Expr, ExprContext, ExprKind, Stmt, StmtKind};
use crate::ast::comparable::ComparableExpr;
use crate::ast::helpers::{
contains_call_path, create_expr, create_stmt, has_comments, unparse_expr, unparse_stmt,
};
@ -228,3 +229,103 @@ pub fn use_ternary_operator(checker: &mut Checker, stmt: &Stmt, parent: Option<&
}
checker.diagnostics.push(diagnostic);
}
fn compare_expr(expr1: &ComparableExpr, expr2: &ComparableExpr) -> bool {
expr1.eq(&expr2)
}
/// SIM401
pub fn use_dict_get_with_default(
checker: &mut Checker,
stmt: &Stmt,
test: &Expr,
body: &Vec<Stmt>,
orelse: &Vec<Stmt>,
) {
if body.len() != 1 || orelse.len() != 1 {
return;
}
let StmtKind::Assign { targets: body_var, value: body_val, ..} = &body[0].node else {
return;
};
if body_var.len() != 1 {
return;
};
let StmtKind::Assign { targets: orelse_var, value: orelse_val, .. } = &orelse[0].node else {
return;
};
if orelse_var.len() != 1 {
return;
};
let ExprKind::Compare { left: test_key, ops , comparators: test_dict } = &test.node else {
return;
};
if test_dict.len() != 1 {
return;
}
let (expected_var, expected_val, default_var, default_val) = match ops[..] {
[Cmpop::In] => (&body_var[0], body_val, &orelse_var[0], orelse_val),
[Cmpop::NotIn] => (&orelse_var[0], orelse_val, &body_var[0], body_val),
_ => {
return;
}
};
let test_dict = &test_dict[0];
let ExprKind::Subscript { value: expected_subscript, slice: expected_slice, .. } = &expected_val.node else {
return;
};
// Check that the dictionary key, target variables, and dictionary name are all
// equivalent.
if !compare_expr(&expected_slice.into(), &test_key.into())
|| !compare_expr(&expected_var.into(), &default_var.into())
|| !compare_expr(&test_dict.into(), &expected_subscript.into())
{
return;
}
let contents = unparse_stmt(
&create_stmt(StmtKind::Assign {
targets: vec![create_expr(expected_var.node.clone())],
value: Box::new(create_expr(ExprKind::Call {
func: Box::new(create_expr(ExprKind::Attribute {
value: expected_subscript.clone(),
attr: "get".to_string(),
ctx: ExprContext::Load,
})),
args: vec![
create_expr(test_key.node.clone()),
create_expr(default_val.node.clone()),
],
keywords: vec![],
})),
type_comment: None,
}),
checker.style,
);
// Don't flag for simplified `dict.get` if the resulting expression would exceed
// the maximum line length.
if stmt.location.column() + contents.len() > checker.settings.line_length {
return;
}
// Don't flag for simplified `dict.get` if the if-expression contains any
// comments.
if has_comments(stmt, checker.locator) {
return;
}
let mut diagnostic = Diagnostic::new(
violations::DictGetWithDefault(contents.clone()),
Range::from_located(stmt),
);
if checker.patch(&RuleCode::SIM401) {
diagnostic.amend(Fix::replacement(
contents,
stmt.location,
stmt.end_location.unwrap(),
));
}
checker.diagnostics.push(diagnostic);
}

View File

@ -3,7 +3,10 @@ pub use ast_bool_op::{
};
pub use ast_expr::use_capital_environment_variables;
pub use ast_for::convert_loop_to_any_all;
pub use ast_if::{nested_if_statements, return_bool_condition_directly, use_ternary_operator};
pub use ast_if::{
nested_if_statements, return_bool_condition_directly, use_dict_get_with_default,
use_ternary_operator,
};
pub use ast_ifexp::{
explicit_false_true_in_ifexpr, explicit_true_false_in_ifexpr, twisted_arms_in_ifexpr,
};

View File

@ -0,0 +1,107 @@
---
source: src/flake8_simplify/mod.rs
expression: diagnostics
---
- kind:
DictGetWithDefault: "var = a_dict.get(key, \"default1\")"
location:
row: 6
column: 0
end_location:
row: 9
column: 20
fix:
content: "var = a_dict.get(key, \"default1\")"
location:
row: 6
column: 0
end_location:
row: 9
column: 20
parent: ~
- kind:
DictGetWithDefault: "var = a_dict.get(key, \"default2\")"
location:
row: 12
column: 0
end_location:
row: 15
column: 21
fix:
content: "var = a_dict.get(key, \"default2\")"
location:
row: 12
column: 0
end_location:
row: 15
column: 21
parent: ~
- kind:
DictGetWithDefault: "var = a_dict.get(key, val1 + val2)"
location:
row: 18
column: 0
end_location:
row: 21
column: 21
fix:
content: "var = a_dict.get(key, val1 + val2)"
location:
row: 18
column: 0
end_location:
row: 21
column: 21
parent: ~
- kind:
DictGetWithDefault: "var = a_dict.get(keys[idx], \"default\")"
location:
row: 24
column: 0
end_location:
row: 27
column: 19
fix:
content: "var = a_dict.get(keys[idx], \"default\")"
location:
row: 24
column: 0
end_location:
row: 27
column: 19
parent: ~
- kind:
DictGetWithDefault: "var = dicts[idx].get(key, \"default\")"
location:
row: 30
column: 0
end_location:
row: 33
column: 19
fix:
content: "var = dicts[idx].get(key, \"default\")"
location:
row: 30
column: 0
end_location:
row: 33
column: 19
parent: ~
- kind:
DictGetWithDefault: "vars[idx] = a_dict.get(key, \"default\")"
location:
row: 36
column: 0
end_location:
row: 39
column: 25
fix:
content: "vars[idx] = a_dict.get(key, \"default\")"
location:
row: 36
column: 0
end_location:
row: 39
column: 25
parent: ~

View File

@ -318,6 +318,7 @@ define_rule_mapping!(
SIM222 => violations::OrTrue,
SIM223 => violations::AndFalse,
SIM300 => violations::YodaConditions,
SIM401 => violations::DictGetWithDefault,
// pyupgrade
UP001 => violations::UselessMetaclassType,
UP003 => violations::TypeOfPrimitive,

View File

@ -2793,13 +2793,13 @@ define_violation!(
);
impl AlwaysAutofixableViolation for UseTernaryOperator {
fn message(&self) -> String {
let UseTernaryOperator(new_code) = self;
format!("Use ternary operator `{new_code}` instead of if-else-block")
let UseTernaryOperator(contents) = self;
format!("Use ternary operator `{contents}` instead of if-else-block")
}
fn autofix_title(&self) -> String {
let UseTernaryOperator(new_code) = self;
format!("Replace if-else-block with `{new_code}`")
let UseTernaryOperator(contents) = self;
format!("Replace if-else-block with `{contents}`")
}
fn placeholder() -> Self {
@ -3107,6 +3107,24 @@ impl AlwaysAutofixableViolation for IfExprWithTwistedArms {
}
}
define_violation!(
pub struct DictGetWithDefault(pub String);
);
impl AlwaysAutofixableViolation for DictGetWithDefault {
fn message(&self) -> String {
let DictGetWithDefault(contents) = self;
format!("Use `{contents}` instead of an `if` block")
}
fn autofix_title(&self) -> String {
let DictGetWithDefault(contents) = self;
format!("Replace with `{contents}`")
}
fn placeholder() -> Self {
DictGetWithDefault("var = dict.get(key, \"default\")".to_string())
}
}
// pyupgrade
define_violation!(