From 349f93389e20dd0a77fed36f7521a883ff18f264 Mon Sep 17 00:00:00 2001 From: Junhson Jean-Baptiste Date: Fri, 7 Feb 2025 03:25:20 -0500 Subject: [PATCH] [flake8-simplify] Only trigger SIM401 on known dictionaries (SIM401) (#15995) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary This change resolves #15814 to ensure that `SIM401` is only triggered on known dictionary types. Before, the rule was getting triggered even on types that _resemble_ a dictionary but are not actually a dictionary. I did this using the `is_known_to_be_of_type_dict(...)` functionality. The logic for this function was duplicated in a few spots, so I moved the code to a central location, removed redundant definitions, and updated existing calls to use the single definition of the function! ## Test Plan Since this PR only modifies an existing rule, I made changes to the existing test instead of adding new ones. I made sure that `SIM401` is triggered on types that are clearly dictionaries and that it's not triggered on a simple custom dictionary-like type (using a modified version of [the code in the issue](#15814)) The additional changes to de-duplicate `is_known_to_be_of_type_dict` don't break any existing tests -- I think this should be fine since the logic remains the same (please let me know if you think otherwise, I'm excited to get feedback and work towards a good fix πŸ™‚). --------- Co-authored-by: Junhson Jean-Baptiste Co-authored-by: Micha Reiser --- .../test/fixtures/flake8_simplify/SIM401.py | 26 ++ .../if_else_block_instead_of_dict_get.rs | 15 +- ...ke8_simplify__tests__SIM401_SIM401.py.snap | 270 ++++++++---------- .../ruff/rules/falsy_dict_get_fallback.rs | 13 +- .../rules/ruff/rules/if_key_in_dict_del.rs | 11 +- .../src/analyze/typing.rs | 11 +- 6 files changed, 174 insertions(+), 172 deletions(-) diff --git a/crates/ruff_linter/resources/test/fixtures/flake8_simplify/SIM401.py b/crates/ruff_linter/resources/test/fixtures/flake8_simplify/SIM401.py index 26bc35ee63..0bfa2499a1 100644 --- a/crates/ruff_linter/resources/test/fixtures/flake8_simplify/SIM401.py +++ b/crates/ruff_linter/resources/test/fixtures/flake8_simplify/SIM401.py @@ -2,6 +2,8 @@ # Positive cases ### +a_dict = {} + # SIM401 (pattern-1) if key in a_dict: var = a_dict[key] @@ -26,6 +28,8 @@ if keys[idx] in a_dict: else: var = "default" +dicts = {"key": a_dict} + # SIM401 (complex expression in dict) if key in dicts[idx]: var = dicts[idx][key] @@ -115,6 +119,28 @@ elif key in a_dict: else: vars[idx] = "default" +class NotADictionary: + def __init__(self): + self._dict = {} + + def __getitem__(self, key): + return self._dict[key] + + def __setitem__(self, key, value): + self._dict[key] = value + + def __iter__(self): + return self._dict.__iter__() + +not_dict = NotADictionary() +not_dict["key"] = "value" + +# OK (type `NotADictionary` is not a known dictionary type) +if "key" in not_dict: + value = not_dict["key"] +else: + value = None + ### # Positive cases (preview) ### diff --git a/crates/ruff_linter/src/rules/flake8_simplify/rules/if_else_block_instead_of_dict_get.rs b/crates/ruff_linter/src/rules/flake8_simplify/rules/if_else_block_instead_of_dict_get.rs index 2b9f2fdbc1..d090638e09 100644 --- a/crates/ruff_linter/src/rules/flake8_simplify/rules/if_else_block_instead_of_dict_get.rs +++ b/crates/ruff_linter/src/rules/flake8_simplify/rules/if_else_block_instead_of_dict_get.rs @@ -5,7 +5,9 @@ use ruff_python_ast::helpers::contains_effect; use ruff_python_ast::{ self as ast, Arguments, CmpOp, ElifElseClause, Expr, ExprContext, Identifier, Stmt, }; -use ruff_python_semantic::analyze::typing::{is_sys_version_block, is_type_checking_block}; +use ruff_python_semantic::analyze::typing::{ + is_known_to_be_of_type_dict, is_sys_version_block, is_type_checking_block, +}; use ruff_text_size::{Ranged, TextRange}; use crate::checkers::ast::Checker; @@ -113,18 +115,27 @@ pub(crate) fn if_else_block_instead_of_dict_get(checker: &mut Checker, stmt_if: let [orelse_var] = orelse_var.as_slice() else { return; }; + let Expr::Compare(ast::ExprCompare { left: test_key, ops, comparators: test_dict, range: _, - }) = test.as_ref() + }) = &**test else { return; }; let [test_dict] = &**test_dict else { return; }; + + if !test_dict + .as_name_expr() + .is_some_and(|dict_name| is_known_to_be_of_type_dict(checker.semantic(), dict_name)) + { + return; + } + let (expected_var, expected_value, default_var, default_value) = match ops[..] { [CmpOp::In] => (body_var, body_value, orelse_var, orelse_value.as_ref()), [CmpOp::NotIn] => (orelse_var, orelse_value, body_var, body_value.as_ref()), diff --git a/crates/ruff_linter/src/rules/flake8_simplify/snapshots/ruff_linter__rules__flake8_simplify__tests__SIM401_SIM401.py.snap b/crates/ruff_linter/src/rules/flake8_simplify/snapshots/ruff_linter__rules__flake8_simplify__tests__SIM401_SIM401.py.snap index 5b0983df10..a89ba6468d 100644 --- a/crates/ruff_linter/src/rules/flake8_simplify/snapshots/ruff_linter__rules__flake8_simplify__tests__SIM401_SIM401.py.snap +++ b/crates/ruff_linter/src/rules/flake8_simplify/snapshots/ruff_linter__rules__flake8_simplify__tests__SIM401_SIM401.py.snap @@ -1,199 +1,173 @@ --- source: crates/ruff_linter/src/rules/flake8_simplify/mod.rs --- -SIM401.py:6:1: SIM401 [*] Use `var = a_dict.get(key, "default1")` instead of an `if` block +SIM401.py:8:1: SIM401 [*] Use `var = a_dict.get(key, "default1")` instead of an `if` block | - 5 | # SIM401 (pattern-1) - 6 | / if key in a_dict: - 7 | | var = a_dict[key] - 8 | | else: - 9 | | var = "default1" + 7 | # SIM401 (pattern-1) + 8 | / if key in a_dict: + 9 | | var = a_dict[key] +10 | | else: +11 | | var = "default1" | |____________________^ SIM401 -10 | -11 | # SIM401 (pattern-2) +12 | +13 | # SIM401 (pattern-2) | = help: Replace with `var = a_dict.get(key, "default1")` β„Ή Unsafe fix -3 3 | ### -4 4 | -5 5 | # SIM401 (pattern-1) -6 |-if key in a_dict: -7 |- var = a_dict[key] -8 |-else: -9 |- var = "default1" - 6 |+var = a_dict.get(key, "default1") -10 7 | -11 8 | # SIM401 (pattern-2) -12 9 | if key not in a_dict: +5 5 | a_dict = {} +6 6 | +7 7 | # SIM401 (pattern-1) +8 |-if key in a_dict: +9 |- var = a_dict[key] +10 |-else: +11 |- var = "default1" + 8 |+var = a_dict.get(key, "default1") +12 9 | +13 10 | # SIM401 (pattern-2) +14 11 | if key not in a_dict: -SIM401.py:12:1: SIM401 [*] Use `var = a_dict.get(key, "default2")` instead of an `if` block +SIM401.py:14:1: SIM401 [*] Use `var = a_dict.get(key, "default2")` instead of an `if` block | -11 | # SIM401 (pattern-2) -12 | / if key not in a_dict: -13 | | var = "default2" -14 | | else: -15 | | var = a_dict[key] +13 | # SIM401 (pattern-2) +14 | / if key not in a_dict: +15 | | var = "default2" +16 | | else: +17 | | var = a_dict[key] | |_____________________^ SIM401 -16 | -17 | # OK (default contains effect) +18 | +19 | # OK (default contains effect) | = help: Replace with `var = a_dict.get(key, "default2")` β„Ή Unsafe fix -9 9 | var = "default1" -10 10 | -11 11 | # SIM401 (pattern-2) -12 |-if key not in a_dict: -13 |- var = "default2" -14 |-else: -15 |- var = a_dict[key] - 12 |+var = a_dict.get(key, "default2") -16 13 | -17 14 | # OK (default contains effect) -18 15 | if key in a_dict: +11 11 | var = "default1" +12 12 | +13 13 | # SIM401 (pattern-2) +14 |-if key not in a_dict: +15 |- var = "default2" +16 |-else: +17 |- var = a_dict[key] + 14 |+var = a_dict.get(key, "default2") +18 15 | +19 16 | # OK (default contains effect) +20 17 | if key in a_dict: -SIM401.py:24:1: SIM401 [*] Use `var = a_dict.get(keys[idx], "default")` instead of an `if` block +SIM401.py:26:1: SIM401 [*] Use `var = a_dict.get(keys[idx], "default")` instead of an `if` block | -23 | # SIM401 (complex expression in key) -24 | / if keys[idx] in a_dict: -25 | | var = a_dict[keys[idx]] -26 | | else: -27 | | var = "default" +25 | # SIM401 (complex expression in key) +26 | / if keys[idx] in a_dict: +27 | | var = a_dict[keys[idx]] +28 | | else: +29 | | var = "default" | |___________________^ SIM401 -28 | -29 | # SIM401 (complex expression in dict) +30 | +31 | dicts = {"key": a_dict} | = help: Replace with `var = a_dict.get(keys[idx], "default")` β„Ή Unsafe fix -21 21 | var = val1 + val2 -22 22 | -23 23 | # SIM401 (complex expression in key) -24 |-if keys[idx] in a_dict: -25 |- var = a_dict[keys[idx]] -26 |-else: -27 |- var = "default" - 24 |+var = a_dict.get(keys[idx], "default") -28 25 | -29 26 | # SIM401 (complex expression in dict) -30 27 | if key in dicts[idx]: +23 23 | var = val1 + val2 +24 24 | +25 25 | # SIM401 (complex expression in key) +26 |-if keys[idx] in a_dict: +27 |- var = a_dict[keys[idx]] +28 |-else: +29 |- var = "default" + 26 |+var = a_dict.get(keys[idx], "default") +30 27 | +31 28 | dicts = {"key": a_dict} +32 29 | -SIM401.py:30:1: SIM401 [*] Use `var = dicts[idx].get(key, "default")` instead of an `if` block +SIM401.py:40:1: SIM401 [*] Use `vars[idx] = a_dict.get(key, "defaultß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789")` instead of an `if` block | -29 | # SIM401 (complex expression in dict) -30 | / if key in dicts[idx]: -31 | | var = dicts[idx][key] -32 | | else: -33 | | var = "default" - | |___________________^ SIM401 -34 | -35 | # SIM401 (complex expression in var) - | - = help: Replace with `var = dicts[idx].get(key, "default")` - -β„Ή Unsafe fix -27 27 | var = "default" -28 28 | -29 29 | # SIM401 (complex expression in dict) -30 |-if key in dicts[idx]: -31 |- var = dicts[idx][key] -32 |-else: -33 |- var = "default" - 30 |+var = dicts[idx].get(key, "default") -34 31 | -35 32 | # SIM401 (complex expression in var) -36 33 | if key in a_dict: - -SIM401.py:36:1: SIM401 [*] Use `vars[idx] = a_dict.get(key, "defaultß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789")` instead of an `if` block - | -35 | # SIM401 (complex expression in var) -36 | / if key in a_dict: -37 | | vars[idx] = a_dict[key] -38 | | else: -39 | | vars[idx] = "defaultß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789" +39 | # SIM401 (complex expression in var) +40 | / if key in a_dict: +41 | | vars[idx] = a_dict[key] +42 | | else: +43 | | vars[idx] = "defaultß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789" | |___________________________________________________________________________^ SIM401 -40 | -41 | # SIM401 +44 | +45 | # SIM401 | = help: Replace with `vars[idx] = a_dict.get(key, "defaultß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789")` β„Ή Unsafe fix -33 33 | var = "default" -34 34 | -35 35 | # SIM401 (complex expression in var) -36 |-if key in a_dict: -37 |- vars[idx] = a_dict[key] -38 |-else: -39 |- vars[idx] = "defaultß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789" - 36 |+vars[idx] = a_dict.get(key, "defaultß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789") -40 37 | -41 38 | # SIM401 -42 39 | if foo(): +37 37 | var = "default" +38 38 | +39 39 | # SIM401 (complex expression in var) +40 |-if key in a_dict: +41 |- vars[idx] = a_dict[key] +42 |-else: +43 |- vars[idx] = "defaultß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789" + 40 |+vars[idx] = a_dict.get(key, "defaultß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789ß9πŸ’£2ℝ6789") +44 41 | +45 42 | # SIM401 +46 43 | if foo(): -SIM401.py:45:5: SIM401 [*] Use `vars[idx] = a_dict.get(key, "default")` instead of an `if` block +SIM401.py:49:5: SIM401 [*] Use `vars[idx] = a_dict.get(key, "default")` instead of an `if` block | -43 | pass -44 | else: -45 | / if key in a_dict: -46 | | vars[idx] = a_dict[key] -47 | | else: -48 | | vars[idx] = "default" +47 | pass +48 | else: +49 | / if key in a_dict: +50 | | vars[idx] = a_dict[key] +51 | | else: +52 | | vars[idx] = "default" | |_____________________________^ SIM401 -49 | -50 | ### +53 | +54 | ### | = help: Replace with `vars[idx] = a_dict.get(key, "default")` β„Ή Unsafe fix -42 42 | if foo(): -43 43 | pass -44 44 | else: -45 |- if key in a_dict: -46 |- vars[idx] = a_dict[key] -47 |- else: -48 |- vars[idx] = "default" - 45 |+ vars[idx] = a_dict.get(key, "default") -49 46 | -50 47 | ### -51 48 | # Negative cases +46 46 | if foo(): +47 47 | pass +48 48 | else: +49 |- if key in a_dict: +50 |- vars[idx] = a_dict[key] +51 |- else: +52 |- vars[idx] = "default" + 49 |+ vars[idx] = a_dict.get(key, "default") +53 50 | +54 51 | ### +55 52 | # Negative cases -SIM401.py:123:7: SIM401 [*] Use `a_dict.get(key, "default3")` instead of an `if` block +SIM401.py:149:7: SIM401 [*] Use `a_dict.get(key, "default3")` instead of an `if` block | -122 | # SIM401 -123 | var = a_dict[key] if key in a_dict else "default3" +148 | # SIM401 +149 | var = a_dict[key] if key in a_dict else "default3" | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ SIM401 -124 | -125 | # SIM401 +150 | +151 | # SIM401 | = help: Replace with `a_dict.get(key, "default3")` β„Ή Unsafe fix -120 120 | ### -121 121 | -122 122 | # SIM401 -123 |-var = a_dict[key] if key in a_dict else "default3" - 123 |+var = a_dict.get(key, "default3") -124 124 | -125 125 | # SIM401 -126 126 | var = "default-1" if key not in a_dict else a_dict[key] +146 146 | ### +147 147 | +148 148 | # SIM401 +149 |-var = a_dict[key] if key in a_dict else "default3" + 149 |+var = a_dict.get(key, "default3") +150 150 | +151 151 | # SIM401 +152 152 | var = "default-1" if key not in a_dict else a_dict[key] -SIM401.py:126:7: SIM401 [*] Use `a_dict.get(key, "default-1")` instead of an `if` block +SIM401.py:152:7: SIM401 [*] Use `a_dict.get(key, "default-1")` instead of an `if` block | -125 | # SIM401 -126 | var = "default-1" if key not in a_dict else a_dict[key] +151 | # SIM401 +152 | var = "default-1" if key not in a_dict else a_dict[key] | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ SIM401 -127 | -128 | # OK (default contains effect) +153 | +154 | # OK (default contains effect) | = help: Replace with `a_dict.get(key, "default-1")` β„Ή Unsafe fix -123 123 | var = a_dict[key] if key in a_dict else "default3" -124 124 | -125 125 | # SIM401 -126 |-var = "default-1" if key not in a_dict else a_dict[key] - 126 |+var = a_dict.get(key, "default-1") -127 127 | -128 128 | # OK (default contains effect) -129 129 | var = a_dict[key] if key in a_dict else val1 + val2 +149 149 | var = a_dict[key] if key in a_dict else "default3" +150 150 | +151 151 | # SIM401 +152 |-var = "default-1" if key not in a_dict else a_dict[key] + 152 |+var = a_dict.get(key, "default-1") +153 153 | +154 154 | # OK (default contains effect) +155 155 | var = a_dict[key] if key in a_dict else val1 + val2 diff --git a/crates/ruff_linter/src/rules/ruff/rules/falsy_dict_get_fallback.rs b/crates/ruff_linter/src/rules/ruff/rules/falsy_dict_get_fallback.rs index 18482224fe..1f73cceb29 100644 --- a/crates/ruff_linter/src/rules/ruff/rules/falsy_dict_get_fallback.rs +++ b/crates/ruff_linter/src/rules/ruff/rules/falsy_dict_get_fallback.rs @@ -2,9 +2,8 @@ use crate::checkers::ast::Checker; use crate::fix::edits::{remove_argument, Parentheses}; use ruff_diagnostics::{AlwaysFixableViolation, Applicability, Diagnostic, Fix}; use ruff_macros::{derive_message_formats, ViolationMetadata}; -use ruff_python_ast::{helpers::Truthiness, Expr, ExprAttribute, ExprName}; +use ruff_python_ast::{helpers::Truthiness, Expr, ExprAttribute}; use ruff_python_semantic::analyze::typing; -use ruff_python_semantic::SemanticModel; use ruff_text_size::Ranged; /// ## What it does @@ -69,7 +68,7 @@ pub(crate) fn falsy_dict_get_fallback(checker: &mut Checker, expr: &Expr) { // Check if the object is a dictionary using the semantic model if !value .as_name_expr() - .is_some_and(|name| is_known_to_be_of_type_dict(semantic, name)) + .is_some_and(|name| typing::is_known_to_be_of_type_dict(semantic, name)) { return; } @@ -110,11 +109,3 @@ pub(crate) fn falsy_dict_get_fallback(checker: &mut Checker, expr: &Expr) { checker.diagnostics.push(diagnostic); } - -fn is_known_to_be_of_type_dict(semantic: &SemanticModel, expr: &ExprName) -> bool { - let Some(binding) = semantic.only_binding(expr).map(|id| semantic.binding(id)) else { - return false; - }; - - typing::is_dict(binding, semantic) -} diff --git a/crates/ruff_linter/src/rules/ruff/rules/if_key_in_dict_del.rs b/crates/ruff_linter/src/rules/ruff/rules/if_key_in_dict_del.rs index ac76b68eaa..3b15438ac3 100644 --- a/crates/ruff_linter/src/rules/ruff/rules/if_key_in_dict_del.rs +++ b/crates/ruff_linter/src/rules/ruff/rules/if_key_in_dict_del.rs @@ -3,7 +3,6 @@ use ruff_diagnostics::{AlwaysFixableViolation, Applicability, Diagnostic, Edit, use ruff_macros::{derive_message_formats, ViolationMetadata}; use ruff_python_ast::{CmpOp, Expr, ExprName, ExprSubscript, Stmt, StmtIf}; use ruff_python_semantic::analyze::typing; -use ruff_python_semantic::SemanticModel; type Key = Expr; type Dict = ExprName; @@ -60,7 +59,7 @@ pub(crate) fn if_key_in_dict_del(checker: &mut Checker, stmt: &StmtIf) { return; } - if !is_known_to_be_of_type_dict(checker.semantic(), test_dict) { + if !typing::is_known_to_be_of_type_dict(checker.semantic(), test_dict) { return; } @@ -127,14 +126,6 @@ fn is_same_dict(test: &Dict, del: &Dict) -> bool { test.id.as_str() == del.id.as_str() } -fn is_known_to_be_of_type_dict(semantic: &SemanticModel, dict: &Dict) -> bool { - let Some(binding) = semantic.only_binding(dict).map(|id| semantic.binding(id)) else { - return false; - }; - - typing::is_dict(binding, semantic) -} - fn replace_with_dict_pop_fix(checker: &Checker, stmt: &StmtIf, dict: &Dict, key: &Key) -> Fix { let locator = checker.locator(); let dict_expr = locator.slice(dict); diff --git a/crates/ruff_python_semantic/src/analyze/typing.rs b/crates/ruff_python_semantic/src/analyze/typing.rs index 091ff44673..2b4d644f36 100644 --- a/crates/ruff_python_semantic/src/analyze/typing.rs +++ b/crates/ruff_python_semantic/src/analyze/typing.rs @@ -4,7 +4,8 @@ use ruff_python_ast::helpers::{any_over_expr, is_const_false, map_subscript}; use ruff_python_ast::identifier::Identifier; use ruff_python_ast::name::QualifiedName; use ruff_python_ast::{ - self as ast, Expr, ExprCall, Int, Operator, ParameterWithDefault, Parameters, Stmt, StmtAssign, + self as ast, Expr, ExprCall, ExprName, Int, Operator, ParameterWithDefault, Parameters, Stmt, + StmtAssign, }; use ruff_python_stdlib::typing::{ as_pep_585_generic, has_pep_585_generic, is_immutable_generic_type, @@ -46,6 +47,14 @@ pub enum SubscriptKind { TypedDict, } +pub fn is_known_to_be_of_type_dict(semantic: &SemanticModel, expr: &ExprName) -> bool { + let Some(binding) = semantic.only_binding(expr).map(|id| semantic.binding(id)) else { + return false; + }; + + is_dict(binding, semantic) +} + pub fn match_annotated_subscript<'a>( expr: &Expr, semantic: &SemanticModel,