diff --git a/resources/test/fixtures/flake8_simplify/SIM117.py b/resources/test/fixtures/flake8_simplify/SIM117.py index 8f6d44c24c..aa972ab0f8 100644 --- a/resources/test/fixtures/flake8_simplify/SIM117.py +++ b/resources/test/fixtures/flake8_simplify/SIM117.py @@ -2,6 +2,11 @@ with A() as a: # SIM117 with B() as b: print("hello") +with A(): # SIM117 + with B(): + with C(): + print("hello") + with A() as a: a() with B() as b: diff --git a/src/ast/helpers.rs b/src/ast/helpers.rs index ee68b95ad1..c325910e3f 100644 --- a/src/ast/helpers.rs +++ b/src/ast/helpers.rs @@ -622,6 +622,19 @@ pub fn preceded_by_continuation(stmt: &Stmt, locator: &Locator) -> bool { false } +/// Return the `Range` of the first `Tok::Colon` token in a `Range`. +pub fn first_colon_range(range: Range, locator: &Locator) -> Option { + let contents = locator.slice_source_code_range(&range); + let range = lexer::make_tokenizer_located(&contents, range.location) + .flatten() + .find(|(_, kind, _)| matches!(kind, Tok::Colon)) + .map(|(location, _, end_location)| Range { + location, + end_location, + }); + range +} + /// Return `true` if a `Stmt` appears to be part of a multi-statement line, with /// other statements preceding it. pub fn preceded_by_multi_statement_line(stmt: &Stmt, locator: &Locator) -> bool { @@ -708,7 +721,9 @@ mod tests { use rustpython_ast::Location; use rustpython_parser::parser; - use crate::ast::helpers::{else_range, identifier_range, match_trailing_content}; + use crate::ast::helpers::{ + else_range, first_colon_range, identifier_range, match_trailing_content, + }; use crate::ast::types::Range; use crate::source_code::Locator; @@ -839,4 +854,19 @@ else: assert_eq!(range.end_location.column(), 4); Ok(()) } + + #[test] + fn test_first_colon_range() { + let contents = "with a: pass"; + let locator = Locator::new(contents); + let range = first_colon_range( + Range::new(Location::new(1, 0), Location::new(1, contents.len())), + &locator, + ) + .unwrap(); + assert_eq!(range.location.row(), 1); + assert_eq!(range.location.column(), 6); + assert_eq!(range.end_location.row(), 1); + assert_eq!(range.end_location.column(), 7); + } } diff --git a/src/checkers/ast.rs b/src/checkers/ast.rs index 4436f70bf0..af670b5e1a 100644 --- a/src/checkers/ast.rs +++ b/src/checkers/ast.rs @@ -1283,7 +1283,12 @@ where flake8_pytest_style::rules::complex_raises(self, stmt, items, body); } if self.settings.enabled.contains(&RuleCode::SIM117) { - flake8_simplify::rules::multiple_with_statements(self, stmt); + flake8_simplify::rules::multiple_with_statements( + self, + stmt, + body, + self.current_stmt_parent().map(|parent| parent.0), + ); } } StmtKind::While { body, orelse, .. } => { diff --git a/src/flake8_simplify/rules/ast_with.rs b/src/flake8_simplify/rules/ast_with.rs index 34d3cfaebf..50af0eadc4 100644 --- a/src/flake8_simplify/rules/ast_with.rs +++ b/src/flake8_simplify/rules/ast_with.rs @@ -1,22 +1,51 @@ -use rustpython_ast::{Stmt, StmtKind}; +use rustpython_ast::{Located, Stmt, StmtKind, Withitem}; +use crate::ast::helpers::first_colon_range; use crate::ast::types::Range; use crate::checkers::ast::Checker; use crate::registry::Diagnostic; use crate::violations; +fn find_last_with(body: &[Stmt]) -> Option<(&Vec, &Vec)> { + let [Located { node: StmtKind::With { items, body, .. }, ..}] = body else { return None }; + find_last_with(body).or(Some((items, body))) +} + /// SIM117 -pub fn multiple_with_statements(checker: &mut Checker, stmt: &Stmt) { - let StmtKind::With { body, .. } = &stmt.node else { - return; - }; - if body.len() != 1 { - return; +pub fn multiple_with_statements( + checker: &mut Checker, + with_stmt: &Stmt, + with_body: &[Stmt], + with_parent: Option<&Stmt>, +) { + if let Some(parent) = with_parent { + if let StmtKind::With { body, .. } = &parent.node { + if body.len() == 1 { + return; + } + } } - if matches!(body[0].node, StmtKind::With { .. }) { + if let Some((items, body)) = find_last_with(with_body) { + let last_item = items.last().expect("Expected items to be non-empty"); + let colon = first_colon_range( + Range::new( + last_item + .optional_vars + .as_ref() + .map_or(last_item.context_expr.end_location, |v| v.end_location) + .unwrap(), + body.first() + .expect("Expected body to be non-empty") + .location, + ), + checker.locator, + ); checker.diagnostics.push(Diagnostic::new( violations::MultipleWithStatements, - Range::from_located(stmt), + colon.map_or_else( + || Range::from_located(with_stmt), + |colon| Range::new(with_stmt.location, colon.end_location), + ), )); } } diff --git a/src/flake8_simplify/snapshots/ruff__flake8_simplify__tests__SIM117_SIM117.py.snap b/src/flake8_simplify/snapshots/ruff__flake8_simplify__tests__SIM117_SIM117.py.snap index dd889fbdcc..a3bd3189fe 100644 --- a/src/flake8_simplify/snapshots/ruff__flake8_simplify__tests__SIM117_SIM117.py.snap +++ b/src/flake8_simplify/snapshots/ruff__flake8_simplify__tests__SIM117_SIM117.py.snap @@ -1,6 +1,6 @@ --- source: src/flake8_simplify/mod.rs -expression: checks +expression: diagnostics --- - kind: MultipleWithStatements: ~ @@ -8,8 +8,18 @@ expression: checks row: 1 column: 0 end_location: - row: 3 - column: 22 + row: 2 + column: 18 + fix: ~ + parent: ~ +- kind: + MultipleWithStatements: ~ + location: + row: 5 + column: 0 + end_location: + row: 7 + column: 17 fix: ~ parent: ~