diff --git a/crates/ruff/src/checkers/ast/analyze/deferred_scopes.rs b/crates/ruff/src/checkers/ast/analyze/deferred_scopes.rs index 2dc19c2cb2..9922f2411e 100644 --- a/crates/ruff/src/checkers/ast/analyze/deferred_scopes.rs +++ b/crates/ruff/src/checkers/ast/analyze/deferred_scopes.rs @@ -1,6 +1,6 @@ use ruff_diagnostics::Diagnostic; use ruff_python_ast::Ranged; -use ruff_python_semantic::analyze::{branch_detection, visibility}; +use ruff_python_semantic::analyze::visibility; use ruff_python_semantic::{Binding, BindingKind, ScopeKind}; use crate::checkers::ast::Checker; @@ -112,11 +112,7 @@ pub(crate) fn deferred_scopes(checker: &mut Checker) { // If the bindings are in different forks, abort. if shadowed.source.map_or(true, |left| { binding.source.map_or(true, |right| { - branch_detection::different_forks( - left, - right, - checker.semantic.statements(), - ) + checker.semantic.different_branches(left, right) }) }) { continue; @@ -208,11 +204,7 @@ pub(crate) fn deferred_scopes(checker: &mut Checker) { // If the bindings are in different forks, abort. if shadowed.source.map_or(true, |left| { binding.source.map_or(true, |right| { - branch_detection::different_forks( - left, - right, - checker.semantic.statements(), - ) + checker.semantic.different_branches(left, right) }) }) { continue; diff --git a/crates/ruff/src/checkers/ast/analyze/statement.rs b/crates/ruff/src/checkers/ast/analyze/statement.rs index 5b6892f8a7..0a5c1c22f3 100644 --- a/crates/ruff/src/checkers/ast/analyze/statement.rs +++ b/crates/ruff/src/checkers/ast/analyze/statement.rs @@ -464,17 +464,17 @@ pub(crate) fn statement(stmt: &Stmt, checker: &mut Checker) { flake8_pyi::rules::pass_statement_stub_body(checker, body); } if checker.enabled(Rule::PassInClassBody) { - flake8_pyi::rules::pass_in_class_body(checker, stmt, body); + flake8_pyi::rules::pass_in_class_body(checker, class_def); } } if checker.enabled(Rule::EllipsisInNonEmptyClassBody) { - flake8_pyi::rules::ellipsis_in_non_empty_class_body(checker, stmt, body); + flake8_pyi::rules::ellipsis_in_non_empty_class_body(checker, body); } if checker.enabled(Rule::PytestIncorrectMarkParenthesesStyle) { flake8_pytest_style::rules::marks(checker, decorator_list); } if checker.enabled(Rule::DuplicateClassFieldDefinition) { - flake8_pie::rules::duplicate_class_field_definition(checker, stmt, body); + flake8_pie::rules::duplicate_class_field_definition(checker, body); } if checker.enabled(Rule::NonUniqueEnums) { flake8_pie::rules::non_unique_enums(checker, stmt, body); diff --git a/crates/ruff/src/checkers/ast/mod.rs b/crates/ruff/src/checkers/ast/mod.rs index df93c18b11..8d86e48867 100644 --- a/crates/ruff/src/checkers/ast/mod.rs +++ b/crates/ruff/src/checkers/ast/mod.rs @@ -32,8 +32,8 @@ use itertools::Itertools; use log::error; use ruff_python_ast::{ self as ast, Arguments, Comprehension, Constant, ElifElseClause, ExceptHandler, Expr, - ExprContext, Keyword, Parameter, ParameterWithDefault, Parameters, Pattern, Ranged, Stmt, - Suite, UnaryOp, + ExprContext, Keyword, MatchCase, Parameter, ParameterWithDefault, Parameters, Pattern, Ranged, + Stmt, Suite, UnaryOp, }; use ruff_text_size::{TextRange, TextSize}; @@ -193,18 +193,22 @@ impl<'a> Checker<'a> { } } - /// Returns the [`IsolationLevel`] for fixes in the current context. + /// Returns the [`IsolationLevel`] to isolate fixes for the current statement. /// /// The primary use-case for fix isolation is to ensure that we don't delete all statements /// in a given indented block, which would cause a syntax error. We therefore need to ensure /// that we delete at most one statement per indented block per fixer pass. Fix isolation should /// thus be applied whenever we delete a statement, but can otherwise be omitted. - pub(crate) fn isolation(&self, parent: Option<&Stmt>) -> IsolationLevel { - parent - .and_then(|stmt| self.semantic.statement_id(stmt)) - .map_or(IsolationLevel::default(), |node_id| { - IsolationLevel::Group(node_id.into()) - }) + pub(crate) fn statement_isolation(&self) -> IsolationLevel { + IsolationLevel::Group(self.semantic.current_statement_id().into()) + } + + /// Returns the [`IsolationLevel`] to isolate fixes in the current statement's parent. + pub(crate) fn parent_isolation(&self) -> IsolationLevel { + self.semantic + .current_statement_parent_id() + .map(|node_id| IsolationLevel::Group(node_id.into())) + .unwrap_or_default() } /// The [`Locator`] for the current file, which enables extraction of source code from byte @@ -619,16 +623,28 @@ where } } + // Iterate over the `body`, then the `handlers`, then the `orelse`, then the + // `finalbody`, but treat the body and the `orelse` as a single branch for + // flow analysis purposes. + let branch = self.semantic.push_branch(); self.semantic.handled_exceptions.push(handled_exceptions); self.visit_body(body); self.semantic.handled_exceptions.pop(); + self.semantic.pop_branch(); for except_handler in handlers { + self.semantic.push_branch(); self.visit_except_handler(except_handler); + self.semantic.pop_branch(); } + self.semantic.set_branch(branch); self.visit_body(orelse); + self.semantic.pop_branch(); + + self.semantic.push_branch(); self.visit_body(finalbody); + self.semantic.pop_branch(); } Stmt::AnnAssign(ast::StmtAnnAssign { target, @@ -708,6 +724,7 @@ where ) => { self.visit_boolean_test(test); + self.semantic.push_branch(); if typing::is_type_checking_block(stmt_if, &self.semantic) { if self.semantic.at_top_level() { self.importer.visit_type_checking_block(stmt); @@ -716,9 +733,12 @@ where } else { self.visit_body(body); } + self.semantic.pop_branch(); for clause in elif_else_clauses { + self.semantic.push_branch(); self.visit_elif_else_clause(clause); + self.semantic.pop_branch(); } } _ => visitor::walk_stmt(self, stmt), @@ -1353,6 +1373,17 @@ where } } + fn visit_match_case(&mut self, match_case: &'b MatchCase) { + self.visit_pattern(&match_case.pattern); + if let Some(expr) = &match_case.guard { + self.visit_expr(expr); + } + + self.semantic.push_branch(); + self.visit_body(&match_case.body); + self.semantic.pop_branch(); + } + fn visit_type_param(&mut self, type_param: &'b ast::TypeParam) { // Step 1: Binding match type_param { diff --git a/crates/ruff/src/rules/flake8_pie/rules/duplicate_class_field_definition.rs b/crates/ruff/src/rules/flake8_pie/rules/duplicate_class_field_definition.rs index 34cd5190b0..922f40ec19 100644 --- a/crates/ruff/src/rules/flake8_pie/rules/duplicate_class_field_definition.rs +++ b/crates/ruff/src/rules/flake8_pie/rules/duplicate_class_field_definition.rs @@ -1,9 +1,9 @@ -use ruff_python_ast::{self as ast, Expr, Ranged, Stmt}; use rustc_hash::FxHashSet; use ruff_diagnostics::Diagnostic; use ruff_diagnostics::{AlwaysAutofixableViolation, Fix}; use ruff_macros::{derive_message_formats, violation}; +use ruff_python_ast::{self as ast, Expr, Ranged, Stmt}; use crate::autofix; use crate::checkers::ast::Checker; @@ -49,11 +49,7 @@ impl AlwaysAutofixableViolation for DuplicateClassFieldDefinition { } /// PIE794 -pub(crate) fn duplicate_class_field_definition( - checker: &mut Checker, - parent: &Stmt, - body: &[Stmt], -) { +pub(crate) fn duplicate_class_field_definition(checker: &mut Checker, body: &[Stmt]) { let mut seen_targets: FxHashSet<&str> = FxHashSet::default(); for stmt in body { // Extract the property name from the assignment statement. @@ -85,11 +81,11 @@ pub(crate) fn duplicate_class_field_definition( if checker.patch(diagnostic.kind.rule()) { let edit = autofix::edits::delete_stmt( stmt, - Some(parent), + Some(stmt), checker.locator(), checker.indexer(), ); - diagnostic.set_fix(Fix::suggested(edit).isolate(checker.isolation(Some(parent)))); + diagnostic.set_fix(Fix::suggested(edit).isolate(checker.statement_isolation())); } checker.diagnostics.push(diagnostic); } diff --git a/crates/ruff/src/rules/flake8_pyi/rules/ellipsis_in_non_empty_class_body.rs b/crates/ruff/src/rules/flake8_pyi/rules/ellipsis_in_non_empty_class_body.rs index 2462b2ad98..51c21e105e 100644 --- a/crates/ruff/src/rules/flake8_pyi/rules/ellipsis_in_non_empty_class_body.rs +++ b/crates/ruff/src/rules/flake8_pyi/rules/ellipsis_in_non_empty_class_body.rs @@ -1,7 +1,6 @@ -use ruff_python_ast::{Expr, ExprConstant, Ranged, Stmt, StmtExpr}; - use ruff_diagnostics::{AutofixKind, Diagnostic, Fix, Violation}; use ruff_macros::{derive_message_formats, violation}; +use ruff_python_ast::{Constant, Expr, ExprConstant, Ranged, Stmt, StmtExpr}; use crate::autofix; use crate::checkers::ast::Checker; @@ -44,11 +43,7 @@ impl Violation for EllipsisInNonEmptyClassBody { } /// PYI013 -pub(crate) fn ellipsis_in_non_empty_class_body( - checker: &mut Checker, - parent: &Stmt, - body: &[Stmt], -) { +pub(crate) fn ellipsis_in_non_empty_class_body(checker: &mut Checker, body: &[Stmt]) { // If the class body contains a single statement, then it's fine for it to be an ellipsis. if body.len() == 1 { return; @@ -59,24 +54,24 @@ pub(crate) fn ellipsis_in_non_empty_class_body( continue; }; - let Expr::Constant(ExprConstant { value, .. }) = value.as_ref() else { - continue; - }; - - if !value.is_ellipsis() { - continue; + if matches!( + value.as_ref(), + Expr::Constant(ExprConstant { + value: Constant::Ellipsis, + .. + }) + ) { + let mut diagnostic = Diagnostic::new(EllipsisInNonEmptyClassBody, stmt.range()); + if checker.patch(diagnostic.kind.rule()) { + let edit = autofix::edits::delete_stmt( + stmt, + Some(stmt), + checker.locator(), + checker.indexer(), + ); + diagnostic.set_fix(Fix::automatic(edit).isolate(checker.statement_isolation())); + } + checker.diagnostics.push(diagnostic); } - - let mut diagnostic = Diagnostic::new(EllipsisInNonEmptyClassBody, stmt.range()); - if checker.patch(diagnostic.kind.rule()) { - let edit = autofix::edits::delete_stmt( - stmt, - Some(parent), - checker.locator(), - checker.indexer(), - ); - diagnostic.set_fix(Fix::automatic(edit).isolate(checker.isolation(Some(parent)))); - } - checker.diagnostics.push(diagnostic); } } diff --git a/crates/ruff/src/rules/flake8_pyi/rules/pass_in_class_body.rs b/crates/ruff/src/rules/flake8_pyi/rules/pass_in_class_body.rs index a16d7437f8..fd18b84174 100644 --- a/crates/ruff/src/rules/flake8_pyi/rules/pass_in_class_body.rs +++ b/crates/ruff/src/rules/flake8_pyi/rules/pass_in_class_body.rs @@ -1,7 +1,6 @@ -use ruff_python_ast::{Ranged, Stmt}; - use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic, Fix}; use ruff_macros::{derive_message_formats, violation}; +use ruff_python_ast::{self as ast, Ranged}; use crate::autofix; use crate::checkers::ast::Checker; @@ -22,26 +21,22 @@ impl AlwaysAutofixableViolation for PassInClassBody { } /// PYI012 -pub(crate) fn pass_in_class_body(checker: &mut Checker, parent: &Stmt, body: &[Stmt]) { +pub(crate) fn pass_in_class_body(checker: &mut Checker, class_def: &ast::StmtClassDef) { // `pass` is required in these situations (or handled by `pass_statement_stub_body`). - if body.len() < 2 { + if class_def.body.len() < 2 { return; } - for stmt in body { + for stmt in &class_def.body { if !stmt.is_pass_stmt() { continue; } let mut diagnostic = Diagnostic::new(PassInClassBody, stmt.range()); if checker.patch(diagnostic.kind.rule()) { - let edit = autofix::edits::delete_stmt( - stmt, - Some(parent), - checker.locator(), - checker.indexer(), - ); - diagnostic.set_fix(Fix::automatic(edit).isolate(checker.isolation(Some(parent)))); + let edit = + autofix::edits::delete_stmt(stmt, Some(stmt), checker.locator(), checker.indexer()); + diagnostic.set_fix(Fix::automatic(edit).isolate(checker.statement_isolation())); } checker.diagnostics.push(diagnostic); } diff --git a/crates/ruff/src/rules/flake8_pyi/rules/pass_statement_stub_body.rs b/crates/ruff/src/rules/flake8_pyi/rules/pass_statement_stub_body.rs index 769cfcfcf9..b9d957b819 100644 --- a/crates/ruff/src/rules/flake8_pyi/rules/pass_statement_stub_body.rs +++ b/crates/ruff/src/rules/flake8_pyi/rules/pass_statement_stub_body.rs @@ -1,7 +1,6 @@ -use ruff_python_ast::{Ranged, Stmt}; - use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic, Edit, Fix}; use ruff_macros::{derive_message_formats, violation}; +use ruff_python_ast::{Ranged, Stmt}; use crate::checkers::ast::Checker; use crate::registry::Rule; @@ -22,15 +21,15 @@ impl AlwaysAutofixableViolation for PassStatementStubBody { /// PYI009 pub(crate) fn pass_statement_stub_body(checker: &mut Checker, body: &[Stmt]) { - if body.len() != 1 { + let [stmt] = body else { return; - } - if body[0].is_pass_stmt() { - let mut diagnostic = Diagnostic::new(PassStatementStubBody, body[0].range()); + }; + if stmt.is_pass_stmt() { + let mut diagnostic = Diagnostic::new(PassStatementStubBody, stmt.range()); if checker.patch(Rule::PassStatementStubBody) { diagnostic.set_fix(Fix::automatic(Edit::range_replacement( format!("..."), - body[0].range(), + stmt.range(), ))); }; checker.diagnostics.push(diagnostic); diff --git a/crates/ruff/src/rules/flake8_pyi/rules/str_or_repr_defined_in_stub.rs b/crates/ruff/src/rules/flake8_pyi/rules/str_or_repr_defined_in_stub.rs index ad3b7cf470..203c201f7b 100644 --- a/crates/ruff/src/rules/flake8_pyi/rules/str_or_repr_defined_in_stub.rs +++ b/crates/ruff/src/rules/flake8_pyi/rules/str_or_repr_defined_in_stub.rs @@ -99,10 +99,7 @@ pub(crate) fn str_or_repr_defined_in_stub(checker: &mut Checker, stmt: &Stmt) { let stmt = checker.semantic().current_statement(); let parent = checker.semantic().current_statement_parent(); let edit = delete_stmt(stmt, parent, checker.locator(), checker.indexer()); - diagnostic.set_fix( - Fix::automatic(edit) - .isolate(checker.isolation(checker.semantic().current_statement_parent())), - ); + diagnostic.set_fix(Fix::automatic(edit).isolate(checker.parent_isolation())); } checker.diagnostics.push(diagnostic); } diff --git a/crates/ruff/src/rules/flake8_type_checking/rules/empty_type_checking_block.rs b/crates/ruff/src/rules/flake8_type_checking/rules/empty_type_checking_block.rs index 05b3f2cb1c..6d249e2d15 100644 --- a/crates/ruff/src/rules/flake8_type_checking/rules/empty_type_checking_block.rs +++ b/crates/ruff/src/rules/flake8_type_checking/rules/empty_type_checking_block.rs @@ -61,7 +61,7 @@ pub(crate) fn empty_type_checking_block(checker: &mut Checker, stmt: &ast::StmtI let stmt = checker.semantic().current_statement(); let parent = checker.semantic().current_statement_parent(); let edit = autofix::edits::delete_stmt(stmt, parent, checker.locator(), checker.indexer()); - diagnostic.set_fix(Fix::automatic(edit).isolate(checker.isolation(parent))); + diagnostic.set_fix(Fix::automatic(edit).isolate(checker.parent_isolation())); } checker.diagnostics.push(diagnostic); } diff --git a/crates/ruff/src/rules/flake8_type_checking/rules/runtime_import_in_type_checking_block.rs b/crates/ruff/src/rules/flake8_type_checking/rules/runtime_import_in_type_checking_block.rs index 0dd50ab22e..28be0841c1 100644 --- a/crates/ruff/src/rules/flake8_type_checking/rules/runtime_import_in_type_checking_block.rs +++ b/crates/ruff/src/rules/flake8_type_checking/rules/runtime_import_in_type_checking_block.rs @@ -244,6 +244,6 @@ fn fix_imports( Ok( Fix::suggested_edits(remove_import_edit, add_import_edit.into_edits()) - .isolate(checker.isolation(parent)), + .isolate(checker.parent_isolation()), ) } diff --git a/crates/ruff/src/rules/flake8_type_checking/rules/typing_only_runtime_import.rs b/crates/ruff/src/rules/flake8_type_checking/rules/typing_only_runtime_import.rs index f9fc0a42a7..c1d282d3ce 100644 --- a/crates/ruff/src/rules/flake8_type_checking/rules/typing_only_runtime_import.rs +++ b/crates/ruff/src/rules/flake8_type_checking/rules/typing_only_runtime_import.rs @@ -491,6 +491,6 @@ fn fix_imports( Ok( Fix::suggested_edits(remove_import_edit, add_import_edit.into_edits()) - .isolate(checker.isolation(parent)), + .isolate(checker.parent_isolation()), ) } diff --git a/crates/ruff/src/rules/pyflakes/rules/unused_import.rs b/crates/ruff/src/rules/pyflakes/rules/unused_import.rs index 72e57ee1c4..94af0eb2db 100644 --- a/crates/ruff/src/rules/pyflakes/rules/unused_import.rs +++ b/crates/ruff/src/rules/pyflakes/rules/unused_import.rs @@ -256,5 +256,5 @@ fn fix_imports( checker.stylist(), checker.indexer(), )?; - Ok(Fix::automatic(edit).isolate(checker.isolation(parent))) + Ok(Fix::automatic(edit).isolate(checker.parent_isolation())) } diff --git a/crates/ruff/src/rules/pyflakes/rules/unused_variable.rs b/crates/ruff/src/rules/pyflakes/rules/unused_variable.rs index 53e176730e..bae464067c 100644 --- a/crates/ruff/src/rules/pyflakes/rules/unused_variable.rs +++ b/crates/ruff/src/rules/pyflakes/rules/unused_variable.rs @@ -1,6 +1,6 @@ use itertools::Itertools; -use ruff_diagnostics::{AutofixKind, Diagnostic, Edit, Fix, Violation}; +use ruff_diagnostics::{AutofixKind, Diagnostic, Edit, Fix, IsolationLevel, Violation}; use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::helpers::contains_effect; use ruff_python_ast::{self as ast, PySourceType, Ranged, Stmt}; @@ -207,6 +207,7 @@ fn remove_unused_variable( stmt: &Stmt, parent: Option<&Stmt>, range: TextRange, + isolation: IsolationLevel, checker: &Checker, ) -> Option { // First case: simple assignment (`x = 1`) @@ -229,7 +230,7 @@ fn remove_unused_variable( } else { // If (e.g.) assigning to a constant (`x = 1`), delete the entire statement. let edit = delete_stmt(stmt, parent, checker.locator(), checker.indexer()); - Some(Fix::suggested(edit).isolate(checker.isolation(parent))) + Some(Fix::suggested(edit).isolate(isolation)) }; } } @@ -257,7 +258,7 @@ fn remove_unused_variable( } else { // If (e.g.) assigning to a constant (`x = 1`), delete the entire statement. let edit = delete_stmt(stmt, parent, checker.locator(), checker.indexer()); - Some(Fix::suggested(edit).isolate(checker.isolation(parent))) + Some(Fix::suggested(edit).isolate(isolation)) }; } } @@ -331,7 +332,14 @@ pub(crate) fn unused_variable(checker: &Checker, scope: &Scope, diagnostics: &mu if let Some(statement_id) = source { let statement = checker.semantic().statement(statement_id); let parent = checker.semantic().parent_statement(statement_id); - if let Some(fix) = remove_unused_variable(statement, parent, range, checker) { + let isolation = checker + .semantic() + .parent_statement_id(statement_id) + .map(|node_id| IsolationLevel::Group(node_id.into())) + .unwrap_or_default(); + if let Some(fix) = + remove_unused_variable(statement, parent, range, isolation, checker) + { diagnostic.set_fix(fix); } } diff --git a/crates/ruff/src/rules/pylint/rules/useless_return.rs b/crates/ruff/src/rules/pylint/rules/useless_return.rs index d38cf9ae39..f3a6457fd1 100644 --- a/crates/ruff/src/rules/pylint/rules/useless_return.rs +++ b/crates/ruff/src/rules/pylint/rules/useless_return.rs @@ -109,7 +109,7 @@ pub(crate) fn useless_return( checker.locator(), checker.indexer(), ); - diagnostic.set_fix(Fix::automatic(edit).isolate(checker.isolation(Some(stmt)))); + diagnostic.set_fix(Fix::automatic(edit).isolate(checker.statement_isolation())); } checker.diagnostics.push(diagnostic); } diff --git a/crates/ruff/src/rules/pyupgrade/rules/unnecessary_builtin_import.rs b/crates/ruff/src/rules/pyupgrade/rules/unnecessary_builtin_import.rs index bf7e6724b2..ba9ab76cd5 100644 --- a/crates/ruff/src/rules/pyupgrade/rules/unnecessary_builtin_import.rs +++ b/crates/ruff/src/rules/pyupgrade/rules/unnecessary_builtin_import.rs @@ -135,7 +135,7 @@ pub(crate) fn unnecessary_builtin_import( checker.stylist(), checker.indexer(), )?; - Ok(Fix::suggested(edit).isolate(checker.isolation(parent))) + Ok(Fix::suggested(edit).isolate(checker.parent_isolation())) }); } checker.diagnostics.push(diagnostic); diff --git a/crates/ruff/src/rules/pyupgrade/rules/unnecessary_future_import.rs b/crates/ruff/src/rules/pyupgrade/rules/unnecessary_future_import.rs index 2cbf8dbe06..b62de683be 100644 --- a/crates/ruff/src/rules/pyupgrade/rules/unnecessary_future_import.rs +++ b/crates/ruff/src/rules/pyupgrade/rules/unnecessary_future_import.rs @@ -124,7 +124,7 @@ pub(crate) fn unnecessary_future_import(checker: &mut Checker, stmt: &Stmt, name checker.stylist(), checker.indexer(), )?; - Ok(Fix::suggested(edit).isolate(checker.isolation(parent))) + Ok(Fix::suggested(edit).isolate(checker.parent_isolation())) }); } checker.diagnostics.push(diagnostic); diff --git a/crates/ruff/src/rules/pyupgrade/rules/useless_metaclass_type.rs b/crates/ruff/src/rules/pyupgrade/rules/useless_metaclass_type.rs index c0ee032ef9..f23fd8339a 100644 --- a/crates/ruff/src/rules/pyupgrade/rules/useless_metaclass_type.rs +++ b/crates/ruff/src/rules/pyupgrade/rules/useless_metaclass_type.rs @@ -66,7 +66,7 @@ pub(crate) fn useless_metaclass_type( let stmt = checker.semantic().current_statement(); let parent = checker.semantic().current_statement_parent(); let edit = autofix::edits::delete_stmt(stmt, parent, checker.locator(), checker.indexer()); - diagnostic.set_fix(Fix::automatic(edit).isolate(checker.isolation(parent))); + diagnostic.set_fix(Fix::automatic(edit).isolate(checker.parent_isolation())); } checker.diagnostics.push(diagnostic); } diff --git a/crates/ruff_python_semantic/src/analyze/branch_detection.rs b/crates/ruff_python_semantic/src/analyze/branch_detection.rs deleted file mode 100644 index e7bbf1b428..0000000000 --- a/crates/ruff_python_semantic/src/analyze/branch_detection.rs +++ /dev/null @@ -1,105 +0,0 @@ -use std::cmp::Ordering; -use std::iter; - -use ruff_python_ast::{self as ast, ExceptHandler, Stmt}; - -use crate::statements::{StatementId, Statements}; - -/// Return the common ancestor of `left` and `right` below `stop`, or `None`. -fn common_ancestor( - left: StatementId, - right: StatementId, - stop: Option, - node_tree: &Statements, -) -> Option { - if stop.is_some_and(|stop| left == stop || right == stop) { - return None; - } - - if left == right { - return Some(left); - } - - let left_depth = node_tree.depth(left); - let right_depth = node_tree.depth(right); - - match left_depth.cmp(&right_depth) { - Ordering::Less => { - let right = node_tree.parent_id(right)?; - common_ancestor(left, right, stop, node_tree) - } - Ordering::Equal => { - let left = node_tree.parent_id(left)?; - let right = node_tree.parent_id(right)?; - common_ancestor(left, right, stop, node_tree) - } - Ordering::Greater => { - let left = node_tree.parent_id(left)?; - common_ancestor(left, right, stop, node_tree) - } - } -} - -/// Return the alternative branches for a given node. -fn alternatives(stmt: &Stmt) -> Vec> { - match stmt { - Stmt::If(ast::StmtIf { - body, - elif_else_clauses, - .. - }) => iter::once(body.iter().collect()) - .chain( - elif_else_clauses - .iter() - .map(|clause| clause.body.iter().collect()), - ) - .collect(), - Stmt::Try(ast::StmtTry { - body, - handlers, - orelse, - .. - }) => vec![body.iter().chain(orelse.iter()).collect()] - .into_iter() - .chain(handlers.iter().map(|handler| { - let ExceptHandler::ExceptHandler(ast::ExceptHandlerExceptHandler { body, .. }) = - handler; - body.iter().collect() - })) - .collect(), - Stmt::Match(ast::StmtMatch { cases, .. }) => cases - .iter() - .map(|case| case.body.iter().collect()) - .collect(), - _ => vec![], - } -} - -/// Return `true` if `stmt` is a descendent of any of the nodes in `ancestors`. -fn descendant_of<'a>( - stmt: StatementId, - ancestors: &[&'a Stmt], - stop: StatementId, - node_tree: &Statements<'a>, -) -> bool { - ancestors.iter().any(|ancestor| { - node_tree.statement_id(ancestor).is_some_and(|ancestor| { - common_ancestor(stmt, ancestor, Some(stop), node_tree).is_some() - }) - }) -} - -/// Return `true` if `left` and `right` are on different branches of an `if` or -/// `try` statement. -pub fn different_forks(left: StatementId, right: StatementId, node_tree: &Statements) -> bool { - if let Some(ancestor) = common_ancestor(left, right, None, node_tree) { - for items in alternatives(node_tree[ancestor]) { - let l = descendant_of(left, &items, ancestor, node_tree); - let r = descendant_of(right, &items, ancestor, node_tree); - if l ^ r { - return true; - } - } - } - false -} diff --git a/crates/ruff_python_semantic/src/analyze/mod.rs b/crates/ruff_python_semantic/src/analyze/mod.rs index f8cb066480..941309a526 100644 --- a/crates/ruff_python_semantic/src/analyze/mod.rs +++ b/crates/ruff_python_semantic/src/analyze/mod.rs @@ -1,4 +1,3 @@ -pub mod branch_detection; pub mod function_type; pub mod logging; pub mod type_inference; diff --git a/crates/ruff_python_semantic/src/branches.rs b/crates/ruff_python_semantic/src/branches.rs new file mode 100644 index 0000000000..477f5522a6 --- /dev/null +++ b/crates/ruff_python_semantic/src/branches.rs @@ -0,0 +1,52 @@ +use std::ops::Index; + +use ruff_index::{newtype_index, IndexVec}; + +/// ID uniquely identifying a branch in a program. +/// +/// For example, given: +/// ```python +/// if x > 0: +/// pass +/// elif x > 1: +/// pass +/// else: +/// pass +/// ``` +/// +/// Each of the three arms of the `if`-`elif`-`else` would be considered a branch, and would be +/// assigned their own unique [`BranchId`]. +#[newtype_index] +#[derive(Ord, PartialOrd)] +pub struct BranchId; + +/// The branches of a program indexed by [`BranchId`] +#[derive(Debug, Default)] +pub(crate) struct Branches(IndexVec>); + +impl Branches { + /// Inserts a new branch into the vector and returns its unique [`BranchID`]. + pub(crate) fn insert(&mut self, parent: Option) -> BranchId { + self.0.push(parent) + } + + /// Return the [`BranchId`] of the parent branch. + #[inline] + pub(crate) fn parent_id(&self, node_id: BranchId) -> Option { + self.0[node_id] + } + + /// Returns an iterator over all [`BranchId`] ancestors, starting from the given [`BranchId`]. + pub(crate) fn ancestor_ids(&self, node_id: BranchId) -> impl Iterator + '_ { + std::iter::successors(Some(node_id), |&node_id| self.0[node_id]) + } +} + +impl Index for Branches { + type Output = Option; + + #[inline] + fn index(&self, index: BranchId) -> &Self::Output { + &self.0[index] + } +} diff --git a/crates/ruff_python_semantic/src/expressions.rs b/crates/ruff_python_semantic/src/expressions.rs index 344e6d073b..47a012013c 100644 --- a/crates/ruff_python_semantic/src/expressions.rs +++ b/crates/ruff_python_semantic/src/expressions.rs @@ -23,20 +23,18 @@ struct ExpressionWithParent<'a> { /// The nodes of a program indexed by [`ExpressionId`] #[derive(Debug, Default)] -pub struct Expressions<'a> { - nodes: IndexVec>, -} +pub struct Expressions<'a>(IndexVec>); impl<'a> Expressions<'a> { /// Inserts a new expression into the node tree and returns its unique id. pub(crate) fn insert(&mut self, node: &'a Expr, parent: Option) -> ExpressionId { - self.nodes.push(ExpressionWithParent { node, parent }) + self.0.push(ExpressionWithParent { node, parent }) } /// Return the [`ExpressionId`] of the parent node. #[inline] pub fn parent_id(&self, node_id: ExpressionId) -> Option { - self.nodes[node_id].parent + self.0[node_id].parent } /// Returns an iterator over all [`ExpressionId`] ancestors, starting from the given [`ExpressionId`]. @@ -44,7 +42,7 @@ impl<'a> Expressions<'a> { &self, node_id: ExpressionId, ) -> impl Iterator + '_ { - std::iter::successors(Some(node_id), |&node_id| self.nodes[node_id].parent) + std::iter::successors(Some(node_id), |&node_id| self.0[node_id].parent) } } @@ -53,6 +51,6 @@ impl<'a> Index for Expressions<'a> { #[inline] fn index(&self, index: ExpressionId) -> &Self::Output { - &self.nodes[index].node + &self.0[index].node } } diff --git a/crates/ruff_python_semantic/src/lib.rs b/crates/ruff_python_semantic/src/lib.rs index 7fd04143d4..6b75e56430 100644 --- a/crates/ruff_python_semantic/src/lib.rs +++ b/crates/ruff_python_semantic/src/lib.rs @@ -1,5 +1,6 @@ pub mod analyze; mod binding; +mod branches; mod context; mod definition; mod expressions; @@ -11,6 +12,7 @@ mod star_import; mod statements; pub use binding::*; +pub use branches::*; pub use context::*; pub use definition::*; pub use expressions::*; diff --git a/crates/ruff_python_semantic/src/model.rs b/crates/ruff_python_semantic/src/model.rs index 1221883920..85e084843d 100644 --- a/crates/ruff_python_semantic/src/model.rs +++ b/crates/ruff_python_semantic/src/model.rs @@ -15,6 +15,7 @@ use crate::binding::{ Binding, BindingFlags, BindingId, BindingKind, Bindings, Exceptions, FromImport, Import, SubmoduleImport, }; +use crate::branches::{BranchId, Branches}; use crate::context::ExecutionContext; use crate::definition::{Definition, DefinitionId, Definitions, Member, Module}; use crate::expressions::{ExpressionId, Expressions}; @@ -32,18 +33,24 @@ pub struct SemanticModel<'a> { typing_modules: &'a [String], module_path: Option<&'a [String]>, - /// Stack of all visited statements. + /// Stack of statements in the program. statements: Statements<'a>, - /// The identifier of the current statement. + /// The ID of the current statement. statement_id: Option, - /// Stack of all visited expressions. + /// Stack of expressions in the program. expressions: Expressions<'a>, - /// The identifier of the current expression. + /// The ID of the current expression. expression_id: Option, + /// Stack of all branches in the program. + branches: Branches, + + /// The ID of the current branch. + branch_id: Option, + /// Stack of all scopes, along with the identifier of the current scope. pub scopes: Scopes<'a>, pub scope_id: ScopeId, @@ -138,6 +145,8 @@ impl<'a> SemanticModel<'a> { statement_id: None, expressions: Expressions::default(), expression_id: None, + branch_id: None, + branches: Branches::default(), scopes: Scopes::default(), scope_id: ScopeId::global(), definitions: Definitions::for_module(module), @@ -781,7 +790,10 @@ impl<'a> SemanticModel<'a> { /// Push a [`Stmt`] onto the stack. pub fn push_statement(&mut self, stmt: &'a Stmt) { - self.statement_id = Some(self.statements.insert(stmt, self.statement_id)); + self.statement_id = Some( + self.statements + .insert(stmt, self.statement_id, self.branch_id), + ); } /// Pop the current [`Stmt`] off the stack. @@ -831,54 +843,78 @@ impl<'a> SemanticModel<'a> { self.definition_id = member.parent; } - /// Return the current `Stmt`. - pub fn current_statement(&self) -> &'a Stmt { - let node_id = self.statement_id.expect("No current statement"); - self.statements[node_id] + /// Push a new branch onto the stack, returning its [`BranchId`]. + pub fn push_branch(&mut self) -> Option { + self.branch_id = Some(self.branches.insert(self.branch_id)); + self.branch_id + } + + /// Pop the current [`BranchId`] off the stack. + pub fn pop_branch(&mut self) { + let node_id = self.branch_id.expect("Attempted to pop without branch"); + self.branch_id = self.branches.parent_id(node_id); + } + + /// Set the current [`BranchId`]. + pub fn set_branch(&mut self, branch_id: Option) { + self.branch_id = branch_id; + } + + /// Returns an [`Iterator`] over the current statement hierarchy represented as [`StatementId`], + /// from the current [`StatementId`] through to any parents. + pub fn current_statement_ids(&self) -> impl Iterator + '_ { + self.statement_id + .iter() + .flat_map(|id| self.statements.ancestor_ids(*id)) } /// Returns an [`Iterator`] over the current statement hierarchy, from the current [`Stmt`] /// through to any parents. pub fn current_statements(&self) -> impl Iterator + '_ { - self.statement_id - .iter() - .flat_map(|id| { - self.statements - .ancestor_ids(*id) - .map(|id| &self.statements[id]) - }) - .copied() + self.current_statement_ids().map(|id| self.statements[id]) } - /// Return the parent `Stmt` of the current `Stmt`, if any. + /// Return the [`StatementId`] of the current [`Stmt`]. + pub fn current_statement_id(&self) -> StatementId { + self.statement_id.expect("No current statement") + } + + /// Return the [`StatementId`] of the current [`Stmt`] parent, if any. + pub fn current_statement_parent_id(&self) -> Option { + self.current_statement_ids().nth(1) + } + + /// Return the current [`Stmt`]. + pub fn current_statement(&self) -> &'a Stmt { + let node_id = self.statement_id.expect("No current statement"); + self.statements[node_id] + } + + /// Return the parent [`Stmt`] of the current [`Stmt`], if any. pub fn current_statement_parent(&self) -> Option<&'a Stmt> { self.current_statements().nth(1) } - /// Return the grandparent `Stmt` of the current `Stmt`, if any. - pub fn current_statement_grandparent(&self) -> Option<&'a Stmt> { - self.current_statements().nth(2) + /// Returns an [`Iterator`] over the current expression hierarchy represented as + /// [`ExpressionId`], from the current [`Expr`] through to any parents. + pub fn current_expression_ids(&self) -> impl Iterator + '_ { + self.expression_id + .iter() + .flat_map(|id| self.expressions.ancestor_ids(*id)) } - /// Return the current `Expr`. + /// Returns an [`Iterator`] over the current expression hierarchy, from the current [`Expr`] + /// through to any parents. + pub fn current_expressions(&self) -> impl Iterator + '_ { + self.current_expression_ids().map(|id| self.expressions[id]) + } + + /// Return the current [`Expr`]. pub fn current_expression(&self) -> Option<&'a Expr> { let node_id = self.expression_id?; Some(self.expressions[node_id]) } - /// Returns an [`Iterator`] over the current statement hierarchy, from the current [`Expr`] - /// through to any parents. - pub fn current_expressions(&self) -> impl Iterator + '_ { - self.expression_id - .iter() - .flat_map(|id| { - self.expressions - .ancestor_ids(*id) - .map(|id| &self.expressions[id]) - }) - .copied() - } - /// Return the parent [`Expr`] of the current [`Expr`], if any. pub fn current_expression_parent(&self) -> Option<&'a Expr> { self.current_expressions().nth(1) @@ -937,17 +973,6 @@ impl<'a> SemanticModel<'a> { None } - /// Return the [`Statements`] vector of all statements. - pub const fn statements(&self) -> &Statements<'a> { - &self.statements - } - - /// Return the [`StatementId`] corresponding to the given [`Stmt`]. - #[inline] - pub fn statement_id(&self, statement: &Stmt) -> Option { - self.statements.statement_id(statement) - } - /// Return the [`Stmt]` corresponding to the given [`StatementId`]. #[inline] pub fn statement(&self, statement_id: StatementId) -> &'a Stmt { @@ -962,6 +987,12 @@ impl<'a> SemanticModel<'a> { .map(|id| self.statements[id]) } + /// Given a [`StatementId`], return the ID of its parent statement, if any. + #[inline] + pub fn parent_statement_id(&self, statement_id: StatementId) -> Option { + self.statements.parent_id(statement_id) + } + /// Set the [`Globals`] for the current [`Scope`]. pub fn set_globals(&mut self, globals: Globals<'a>) { // If any global bindings don't already exist in the global scope, add them. @@ -1066,6 +1097,33 @@ impl<'a> SemanticModel<'a> { false } + /// Returns `true` if `left` and `right` are on different branches of an `if`, `match`, or + /// `try` statement. + /// + /// This implementation assumes that the statements are in the same scope. + pub fn different_branches(&self, left: StatementId, right: StatementId) -> bool { + // Collect the branch path for the left statement. + let left = self + .statements + .branch_id(left) + .iter() + .flat_map(|branch_id| self.branches.ancestor_ids(*branch_id)) + .collect::>(); + + // Collect the branch path for the right statement. + let right = self + .statements + .branch_id(right) + .iter() + .flat_map(|branch_id| self.branches.ancestor_ids(*branch_id)) + .collect::>(); + + !left + .iter() + .zip(right.iter()) + .all(|(left, right)| left == right) + } + /// Returns `true` if the given [`BindingId`] is used. pub fn is_used(&self, binding_id: BindingId) -> bool { self.bindings[binding_id].is_used() diff --git a/crates/ruff_python_semantic/src/statements.rs b/crates/ruff_python_semantic/src/statements.rs index 8295a04a7d..0b3f054aed 100644 --- a/crates/ruff_python_semantic/src/statements.rs +++ b/crates/ruff_python_semantic/src/statements.rs @@ -1,10 +1,9 @@ use std::ops::Index; -use rustc_hash::FxHashMap; - use ruff_index::{newtype_index, IndexVec}; -use ruff_python_ast::{Ranged, Stmt}; -use ruff_text_size::TextSize; +use ruff_python_ast::Stmt; + +use crate::branches::BranchId; /// Id uniquely identifying a statement AST node. /// @@ -22,75 +21,44 @@ struct StatementWithParent<'a> { statement: &'a Stmt, /// The ID of the parent of this node, if any. parent: Option, - /// The depth of this node in the tree. - depth: u32, + /// The branch ID of this node, if any. + branch: Option, } /// The statements of a program indexed by [`StatementId`] #[derive(Debug, Default)] -pub struct Statements<'a> { - statements: IndexVec>, - statement_to_id: FxHashMap, -} - -/// A unique key for a statement AST node. No two statements can appear at the same location -/// in the source code, since compound statements must be delimited by _at least_ one character -/// (a colon), so the starting offset is a cheap and sufficient unique identifier. -#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] -pub struct StatementKey(TextSize); - -impl From<&Stmt> for StatementKey { - fn from(statement: &Stmt) -> Self { - Self(statement.start()) - } -} +pub struct Statements<'a>(IndexVec>); impl<'a> Statements<'a> { /// Inserts a new statement into the statement vector and returns its unique ID. - /// - /// Panics if a statement with the same pointer already exists. pub(crate) fn insert( &mut self, statement: &'a Stmt, parent: Option, + branch: Option, ) -> StatementId { - let next_id = self.statements.next_index(); - if let Some(existing_id) = self - .statement_to_id - .insert(StatementKey::from(statement), next_id) - { - panic!("Statements already exists with ID: {existing_id:?}"); - } - self.statements.push(StatementWithParent { + self.0.push(StatementWithParent { statement, parent, - depth: parent.map_or(0, |parent| self.statements[parent].depth + 1), + branch, }) } - /// Returns the [`StatementId`] of the given statement. - #[inline] - pub fn statement_id(&self, statement: &'a Stmt) -> Option { - self.statement_to_id - .get(&StatementKey::from(statement)) - .copied() - } - /// Return the [`StatementId`] of the parent statement. #[inline] - pub fn parent_id(&self, statement_id: StatementId) -> Option { - self.statements[statement_id].parent + pub(crate) fn parent_id(&self, statement_id: StatementId) -> Option { + self.0[statement_id].parent } - /// Return the depth of the statement. + /// Return the [`StatementId`] of the parent statement. #[inline] - pub(crate) fn depth(&self, id: StatementId) -> u32 { - self.statements[id].depth + pub(crate) fn branch_id(&self, statement_id: StatementId) -> Option { + self.0[statement_id].branch } /// Returns an iterator over all [`StatementId`] ancestors, starting from the given [`StatementId`]. pub(crate) fn ancestor_ids(&self, id: StatementId) -> impl Iterator + '_ { - std::iter::successors(Some(id), |&id| self.statements[id].parent) + std::iter::successors(Some(id), |&id| self.0[id].parent) } } @@ -99,6 +67,6 @@ impl<'a> Index for Statements<'a> { #[inline] fn index(&self, index: StatementId) -> &Self::Output { - &self.statements[index].statement + &self.0[index].statement } }