diff --git a/crates/ruff_python_semantic/src/analyze/branch_detection.rs b/crates/ruff_python_semantic/src/analyze/branch_detection.rs index 97c6b2b9f9..5029d76edf 100644 --- a/crates/ruff_python_semantic/src/analyze/branch_detection.rs +++ b/crates/ruff_python_semantic/src/analyze/branch_detection.rs @@ -1,4 +1,3 @@ -use std::cmp::Ordering; use std::iter; use ruff_python_ast::{self as ast, ExceptHandler, Stmt}; @@ -12,32 +11,19 @@ fn common_ancestor( stop: Option, node_tree: &Statements, ) -> Option { - if stop.is_some_and(|stop| left == stop || right == stop) { - return None; - } - + // Fast path: if the nodes are the same, they are their own common ancestor. if left == right { return Some(left); } - let left_depth = node_tree.depth(left); - let right_depth = node_tree.depth(right); + // Grab all the ancestors of `right`. + let candidates = node_tree.ancestor_ids(right).collect::>(); - match left_depth.cmp(&right_depth) { - Ordering::Less => { - let right = node_tree.parent_id(right)?; - common_ancestor(left, right, stop, node_tree) - } - Ordering::Equal => { - let left = node_tree.parent_id(left)?; - let right = node_tree.parent_id(right)?; - common_ancestor(left, right, stop, node_tree) - } - Ordering::Greater => { - let left = node_tree.parent_id(left)?; - common_ancestor(left, right, stop, node_tree) - } - } + // Find the first ancestor of `left` that is also an ancestor of `right`. + node_tree + .ancestor_ids(left) + .take_while(|id| stop != Some(*id)) + .find(|id| candidates.contains(id)) } /// Return the alternative branches for a given node. diff --git a/crates/ruff_python_semantic/src/statements.rs b/crates/ruff_python_semantic/src/statements.rs index 8388273bd7..b0f2d75ccb 100644 --- a/crates/ruff_python_semantic/src/statements.rs +++ b/crates/ruff_python_semantic/src/statements.rs @@ -22,8 +22,6 @@ struct StatementWithParent<'a> { statement: &'a Stmt, /// The ID of the parent of this node, if any. parent: Option, - /// The depth of this node in the tree. - depth: u32, } /// The statements of a program indexed by [`StatementId`] @@ -46,11 +44,8 @@ impl<'a> Statements<'a> { if let Some(existing_id) = self.statement_to_id.insert(RefEquality(statement), next_id) { panic!("Statements already exists with ID: {existing_id:?}"); } - self.statements.push(StatementWithParent { - statement, - parent, - depth: parent.map_or(0, |parent| self.statements[parent].depth + 1), - }) + self.statements + .push(StatementWithParent { statement, parent }) } /// Returns the [`StatementId`] of the given statement. @@ -65,12 +60,6 @@ impl<'a> Statements<'a> { self.statements[statement_id].parent } - /// Return the depth of the statement. - #[inline] - pub(crate) fn depth(&self, id: StatementId) -> u32 { - self.statements[id].depth - } - /// Returns an iterator over all [`StatementId`] ancestors, starting from the given [`StatementId`]. pub(crate) fn ancestor_ids(&self, id: StatementId) -> impl Iterator + '_ { std::iter::successors(Some(id), |&id| self.statements[id].parent)