From c1f0661225b22aa39f8839648c8d283ea7c250e2 Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Sat, 6 May 2023 12:12:41 -0400 Subject: [PATCH] Replace `parents` statement stack with a `Nodes` abstraction (#4233) --- crates/ruff/src/checkers/ast/deferred.rs | 6 +- crates/ruff/src/checkers/ast/mod.rs | 141 ++++++++---------- crates/ruff/src/codes.rs | 3 +- .../rules/unused_loop_control_variable.rs | 7 +- .../rules/open_file_with_context_handler.rs | 4 +- .../rules/empty_type_checking_block.rs | 6 +- .../rules/pyflakes/rules/unused_variable.rs | 14 +- .../rules/pylint/rules/global_statement.rs | 4 +- .../pyupgrade/rules/outdated_version_block.rs | 6 +- .../rules/super_call_with_parameters.rs | 5 +- .../rules/unnecessary_builtin_import.rs | 7 +- .../rules/unnecessary_future_import.rs | 7 +- .../pyupgrade/rules/useless_metaclass_type.rs | 7 +- .../ruff_python_ast/src/branch_detection.rs | 113 -------------- crates/ruff_python_ast/src/lib.rs | 1 - crates/ruff_python_ast/src/types.rs | 12 ++ .../src/analyze/branch_detection.rs | 104 +++++++++++++ .../ruff_python_semantic/src/analyze/mod.rs | 1 + crates/ruff_python_semantic/src/binding.rs | 4 +- crates/ruff_python_semantic/src/context.rs | 76 +++++----- crates/ruff_python_semantic/src/lib.rs | 1 + crates/ruff_python_semantic/src/node.rs | 112 ++++++++++++++ 22 files changed, 362 insertions(+), 279 deletions(-) delete mode 100644 crates/ruff_python_ast/src/branch_detection.rs create mode 100644 crates/ruff_python_semantic/src/analyze/branch_detection.rs create mode 100644 crates/ruff_python_semantic/src/node.rs diff --git a/crates/ruff/src/checkers/ast/deferred.rs b/crates/ruff/src/checkers/ast/deferred.rs index 61f2135f9f..871b3fb92a 100644 --- a/crates/ruff/src/checkers/ast/deferred.rs +++ b/crates/ruff/src/checkers/ast/deferred.rs @@ -1,14 +1,16 @@ use ruff_text_size::TextRange; use rustpython_parser::ast::{Expr, Stmt}; -use ruff_python_ast::types::RefEquality; use ruff_python_semantic::analyze::visibility::{Visibility, VisibleScope}; +use ruff_python_semantic::node::NodeId; use ruff_python_semantic::scope::ScopeId; use crate::checkers::ast::AnnotationContext; use crate::docstrings::definition::Definition; -type Context<'a> = (ScopeId, Vec>); +/// A snapshot of the current scope and statement, which will be restored when visiting any +/// deferred definitions. +type Context<'a> = (ScopeId, Option); /// A collection of AST nodes that are deferred for later analysis. /// Used to, e.g., store functions, whose bodies shouldn't be analyzed until all diff --git a/crates/ruff/src/checkers/ast/mod.rs b/crates/ruff/src/checkers/ast/mod.rs index 14febab297..2a03890690 100644 --- a/crates/ruff/src/checkers/ast/mod.rs +++ b/crates/ruff/src/checkers/ast/mod.rs @@ -17,8 +17,9 @@ use ruff_python_ast::source_code::{Indexer, Locator, Stylist}; use ruff_python_ast::types::{Node, RefEquality}; use ruff_python_ast::typing::parse_type_annotation; use ruff_python_ast::visitor::{walk_excepthandler, walk_pattern, Visitor}; -use ruff_python_ast::{branch_detection, cast, helpers, str, visitor}; +use ruff_python_ast::{cast, helpers, str, visitor}; use ruff_python_semantic::analyze; +use ruff_python_semantic::analyze::branch_detection; use ruff_python_semantic::analyze::typing::{Callable, SubscriptKind}; use ruff_python_semantic::binding::{ Binding, BindingId, BindingKind, Exceptions, ExecutionContext, Export, FromImportation, @@ -175,7 +176,7 @@ where 'b: 'a, { fn visit_stmt(&mut self, stmt: &'b Stmt) { - self.ctx.push_parent(stmt); + self.ctx.push_stmt(stmt); // Track whether we've seen docstrings, non-imports, etc. match &stmt.node { @@ -196,7 +197,7 @@ where self.ctx.futures_allowed = false; if !self.ctx.seen_import_boundary && !helpers::is_assignment_to_a_dunder(stmt) - && !helpers::in_nested_block(self.ctx.parents.iter().rev().map(Into::into)) + && !helpers::in_nested_block(self.ctx.parents()) { self.ctx.seen_import_boundary = true; } @@ -230,7 +231,7 @@ where synthetic_usage: usage, typing_usage: None, range: *range, - source: Some(RefEquality(stmt)), + source: Some(stmt), context, exceptions, }); @@ -260,7 +261,7 @@ where synthetic_usage: usage, typing_usage: None, range: *range, - source: Some(RefEquality(stmt)), + source: Some(stmt), context, exceptions, }); @@ -303,10 +304,9 @@ where } StmtKind::Break => { if self.settings.rules.enabled(Rule::BreakOutsideLoop) { - if let Some(diagnostic) = pyflakes::rules::break_outside_loop( - stmt, - &mut self.ctx.parents.iter().rev().map(Into::into).skip(1), - ) { + if let Some(diagnostic) = + pyflakes::rules::break_outside_loop(stmt, &mut self.ctx.parents().skip(1)) + { self.diagnostics.push(diagnostic); } } @@ -315,7 +315,7 @@ where if self.settings.rules.enabled(Rule::ContinueOutsideLoop) { if let Some(diagnostic) = pyflakes::rules::continue_outside_loop( stmt, - &mut self.ctx.parents.iter().rev().map(Into::into).skip(1), + &mut self.ctx.parents().skip(1), ) { self.diagnostics.push(diagnostic); } @@ -688,7 +688,7 @@ where synthetic_usage: None, typing_usage: None, range: stmt.range(), - source: Some(*self.ctx.current_stmt()), + source: Some(self.ctx.current_stmt()), context: self.ctx.execution_context(), exceptions: self.ctx.exceptions(), }, @@ -904,7 +904,7 @@ where synthetic_usage: Some((self.ctx.scope_id, alias.range())), typing_usage: None, range: alias.range(), - source: Some(*self.ctx.current_stmt()), + source: Some(self.ctx.current_stmt()), context: self.ctx.execution_context(), exceptions: self.ctx.exceptions(), }, @@ -934,7 +934,7 @@ where synthetic_usage: None, typing_usage: None, range: alias.range(), - source: Some(*self.ctx.current_stmt()), + source: Some(self.ctx.current_stmt()), context: self.ctx.execution_context(), exceptions: self.ctx.exceptions(), }, @@ -962,7 +962,7 @@ where }, typing_usage: None, range: alias.range(), - source: Some(*self.ctx.current_stmt()), + source: Some(self.ctx.current_stmt()), context: self.ctx.execution_context(), exceptions: self.ctx.exceptions(), }, @@ -1222,7 +1222,7 @@ where synthetic_usage: Some((self.ctx.scope_id, alias.range())), typing_usage: None, range: alias.range(), - source: Some(*self.ctx.current_stmt()), + source: Some(self.ctx.current_stmt()), context: self.ctx.execution_context(), exceptions: self.ctx.exceptions(), }, @@ -1318,7 +1318,7 @@ where }, typing_usage: None, range: alias.range(), - source: Some(*self.ctx.current_stmt()), + source: Some(self.ctx.current_stmt()), context: self.ctx.execution_context(), exceptions: self.ctx.exceptions(), }, @@ -1714,7 +1714,7 @@ where if self.settings.rules.enabled(Rule::UnusedLoopControlVariable) { self.deferred .for_loops - .push((stmt, (self.ctx.scope_id, self.ctx.parents.clone()))); + .push((stmt, (self.ctx.scope_id, self.ctx.stmt_id))); } if self .settings @@ -2003,7 +2003,7 @@ where self.deferred.definitions.push(( definition, scope.visibility, - (self.ctx.scope_id, self.ctx.parents.clone()), + (self.ctx.scope_id, self.ctx.stmt_id), )); self.ctx.visible_scope = scope; @@ -2022,7 +2022,7 @@ where synthetic_usage: None, typing_usage: None, range: stmt.range(), - source: Some(RefEquality(stmt)), + source: Some(stmt), context: self.ctx.execution_context(), exceptions: self.ctx.exceptions(), }); @@ -2041,7 +2041,7 @@ where self.deferred.functions.push(( stmt, - (self.ctx.scope_id, self.ctx.parents.clone()), + (self.ctx.scope_id, self.ctx.stmt_id), self.ctx.visible_scope, )); } @@ -2066,7 +2066,7 @@ where self.deferred.definitions.push(( definition, scope.visibility, - (self.ctx.scope_id, self.ctx.parents.clone()), + (self.ctx.scope_id, self.ctx.stmt_id), )); self.ctx.visible_scope = scope; @@ -2085,7 +2085,7 @@ where synthetic_usage: None, typing_usage: None, range: stmt.range(), - source: Some(RefEquality(stmt)), + source: Some(stmt), context: self.ctx.execution_context(), exceptions: self.ctx.exceptions(), }); @@ -2246,7 +2246,7 @@ where synthetic_usage: None, typing_usage: None, range: stmt.range(), - source: Some(*self.ctx.current_stmt()), + source: Some(self.ctx.current_stmt()), context: self.ctx.execution_context(), exceptions: self.ctx.exceptions(), }, @@ -2255,7 +2255,7 @@ where _ => {} } - self.ctx.pop_parent(); + self.ctx.pop_stmt(); } fn visit_annotation(&mut self, expr: &'b Expr) { @@ -2281,13 +2281,13 @@ where expr.range(), value, (self.ctx.in_annotation, self.ctx.in_type_checking_block), - (self.ctx.scope_id, self.ctx.parents.clone()), + (self.ctx.scope_id, self.ctx.stmt_id), )); } else { self.deferred.type_definitions.push(( expr, (self.ctx.in_annotation, self.ctx.in_type_checking_block), - (self.ctx.scope_id, self.ctx.parents.clone()), + (self.ctx.scope_id, self.ctx.stmt_id), )); } return; @@ -3514,7 +3514,7 @@ where expr.range(), value, (self.ctx.in_annotation, self.ctx.in_type_checking_block), - (self.ctx.scope_id, self.ctx.parents.clone()), + (self.ctx.scope_id, self.ctx.stmt_id), )); } if self @@ -3637,7 +3637,7 @@ where ExprKind::Lambda { .. } => { self.deferred .lambdas - .push((expr, (self.ctx.scope_id, self.ctx.parents.clone()))); + .push((expr, (self.ctx.scope_id, self.ctx.stmt_id))); } ExprKind::IfExp { test, body, orelse } => { visit_boolean_test!(self, test); @@ -4121,7 +4121,7 @@ where synthetic_usage: None, typing_usage: None, range: arg.range(), - source: Some(*self.ctx.current_stmt()), + source: Some(self.ctx.current_stmt()), context: self.ctx.execution_context(), exceptions: self.ctx.exceptions(), }, @@ -4165,7 +4165,7 @@ where synthetic_usage: None, typing_usage: None, range: pattern.range(), - source: Some(*self.ctx.current_stmt()), + source: Some(self.ctx.current_stmt()), context: self.ctx.execution_context(), exceptions: self.ctx.exceptions(), }, @@ -4212,12 +4212,7 @@ impl<'a> Checker<'a> { if !existing.kind.is_builtin() && existing.source.map_or(true, |left| { binding.source.map_or(true, |right| { - !branch_detection::different_forks( - left, - right, - &self.ctx.depths, - &self.ctx.child_to_parent, - ) + !branch_detection::different_forks(left, right, &self.ctx.stmts) }) }) { @@ -4517,7 +4512,7 @@ impl<'a> Checker<'a> { } fn handle_node_store(&mut self, id: &'a str, expr: &Expr) { - let parent = self.ctx.current_stmt().0; + let parent = self.ctx.current_stmt(); if self.settings.rules.enabled(Rule::UndefinedLocal) { pyflakes::rules::undefined_local(self, id); @@ -4576,7 +4571,7 @@ impl<'a> Checker<'a> { synthetic_usage: None, typing_usage: None, range: expr.range(), - source: Some(*self.ctx.current_stmt()), + source: Some(self.ctx.current_stmt()), context: self.ctx.execution_context(), exceptions: self.ctx.exceptions(), }, @@ -4596,7 +4591,7 @@ impl<'a> Checker<'a> { synthetic_usage: None, typing_usage: None, range: expr.range(), - source: Some(*self.ctx.current_stmt()), + source: Some(self.ctx.current_stmt()), context: self.ctx.execution_context(), exceptions: self.ctx.exceptions(), }, @@ -4613,7 +4608,7 @@ impl<'a> Checker<'a> { synthetic_usage: None, typing_usage: None, range: expr.range(), - source: Some(*self.ctx.current_stmt()), + source: Some(self.ctx.current_stmt()), context: self.ctx.execution_context(), exceptions: self.ctx.exceptions(), }, @@ -4696,7 +4691,7 @@ impl<'a> Checker<'a> { synthetic_usage: None, typing_usage: None, range: expr.range(), - source: Some(*self.ctx.current_stmt()), + source: Some(self.ctx.current_stmt()), context: self.ctx.execution_context(), exceptions: self.ctx.exceptions(), }, @@ -4718,7 +4713,7 @@ impl<'a> Checker<'a> { synthetic_usage: None, typing_usage: None, range: expr.range(), - source: Some(*self.ctx.current_stmt()), + source: Some(self.ctx.current_stmt()), context: self.ctx.execution_context(), exceptions: self.ctx.exceptions(), }, @@ -4734,7 +4729,7 @@ impl<'a> Checker<'a> { synthetic_usage: None, typing_usage: None, range: expr.range(), - source: Some(*self.ctx.current_stmt()), + source: Some(self.ctx.current_stmt()), context: self.ctx.execution_context(), exceptions: self.ctx.exceptions(), }, @@ -4745,7 +4740,7 @@ impl<'a> Checker<'a> { let ExprKind::Name { id, .. } = &expr.node else { return; }; - if helpers::on_conditional_branch(&mut self.ctx.parents.iter().rev().map(Into::into)) { + if helpers::on_conditional_branch(&mut self.ctx.parents()) { return; } @@ -4780,7 +4775,7 @@ impl<'a> Checker<'a> { docstring, }, self.ctx.visible_scope.visibility, - (self.ctx.scope_id, self.ctx.parents.clone()), + (self.ctx.scope_id, self.ctx.stmt_id), )); docstring.is_some() } @@ -4788,11 +4783,11 @@ impl<'a> Checker<'a> { fn check_deferred_type_definitions(&mut self) { while !self.deferred.type_definitions.is_empty() { let type_definitions = std::mem::take(&mut self.deferred.type_definitions); - for (expr, (in_annotation, in_type_checking_block), (scope_id, parents)) in + for (expr, (in_annotation, in_type_checking_block), (scope_id, node_id)) in type_definitions { self.ctx.scope_id = scope_id; - self.ctx.parents = parents; + self.ctx.stmt_id = node_id; self.ctx.in_annotation = in_annotation; self.ctx.in_type_checking_block = in_type_checking_block; self.ctx.in_type_definition = true; @@ -4807,7 +4802,7 @@ impl<'a> Checker<'a> { fn check_deferred_string_type_definitions(&mut self, allocator: &'a typed_arena::Arena) { while !self.deferred.string_type_definitions.is_empty() { let type_definitions = std::mem::take(&mut self.deferred.string_type_definitions); - for (range, value, (in_annotation, in_type_checking_block), (scope_id, parents)) in + for (range, value, (in_annotation, in_type_checking_block), (scope_id, node_id)) in type_definitions { if let Ok((expr, kind)) = parse_type_annotation(value, range, self.locator) { @@ -4825,7 +4820,7 @@ impl<'a> Checker<'a> { let expr = allocator.alloc(expr); self.ctx.scope_id = scope_id; - self.ctx.parents = parents; + self.ctx.stmt_id = node_id; self.ctx.in_annotation = in_annotation; self.ctx.in_type_checking_block = in_type_checking_block; self.ctx.in_type_definition = true; @@ -4854,10 +4849,9 @@ impl<'a> Checker<'a> { fn check_deferred_functions(&mut self) { while !self.deferred.functions.is_empty() { let deferred_functions = std::mem::take(&mut self.deferred.functions); - for (stmt, (scope_id, parents), visibility) in deferred_functions { - let parents_snapshot = parents.len(); + for (stmt, (scope_id, node_id), visibility) in deferred_functions { self.ctx.scope_id = scope_id; - self.ctx.parents = parents; + self.ctx.stmt_id = node_id; self.ctx.visible_scope = visibility; match &stmt.node { @@ -4871,10 +4865,7 @@ impl<'a> Checker<'a> { } } - let mut parents = std::mem::take(&mut self.ctx.parents); - parents.truncate(parents_snapshot); - - self.deferred.assignments.push((scope_id, parents)); + self.deferred.assignments.push((scope_id, node_id)); } } } @@ -4882,11 +4873,9 @@ impl<'a> Checker<'a> { fn check_deferred_lambdas(&mut self) { while !self.deferred.lambdas.is_empty() { let lambdas = std::mem::take(&mut self.deferred.lambdas); - for (expr, (scope_id, parents)) in lambdas { - let parents_snapshot = parents.len(); - + for (expr, (scope_id, node_id)) in lambdas { self.ctx.scope_id = scope_id; - self.ctx.parents = parents; + self.ctx.stmt_id = node_id; if let ExprKind::Lambda { args, body } = &expr.node { self.visit_arguments(args); @@ -4895,9 +4884,7 @@ impl<'a> Checker<'a> { unreachable!("Expected ExprKind::Lambda"); } - let mut parents = std::mem::take(&mut self.ctx.parents); - parents.truncate(parents_snapshot); - self.deferred.assignments.push((scope_id, parents)); + self.deferred.assignments.push((scope_id, node_id)); } } } @@ -4942,9 +4929,9 @@ impl<'a> Checker<'a> { while !self.deferred.for_loops.is_empty() { let for_loops = std::mem::take(&mut self.deferred.for_loops); - for (stmt, (scope_id, parents)) in for_loops { + for (stmt, (scope_id, node_id)) in for_loops { self.ctx.scope_id = scope_id; - self.ctx.parents = parents; + self.ctx.stmt_id = node_id; if let StmtKind::For { target, body, .. } | StmtKind::AsyncFor { target, body, .. } = &stmt.node @@ -5216,9 +5203,9 @@ impl<'a> Checker<'a> { // Collect all unused imports by location. (Multiple unused imports at the same // location indicates an `import from`.) type UnusedImport<'a> = (&'a str, &'a TextRange); - type BindingContext<'a, 'b> = ( - &'a RefEquality<'b, Stmt>, - Option<&'a RefEquality<'b, Stmt>>, + type BindingContext<'a> = ( + RefEquality<'a, Stmt>, + Option>, Exceptions, ); @@ -5245,10 +5232,9 @@ impl<'a> Checker<'a> { continue; } - let defined_by = binding.source.as_ref().unwrap(); - let defined_in = self.ctx.child_to_parent.get(defined_by); + let child = binding.source.unwrap(); + let parent = self.ctx.stmts.parent(child); let exceptions = binding.exceptions; - let child: &Stmt = defined_by.into(); let diagnostic_offset = binding.range.start(); let parent_offset = if matches!(child.node, StmtKind::ImportFrom { .. }) { @@ -5263,12 +5249,12 @@ impl<'a> Checker<'a> { }) { ignored - .entry((defined_by, defined_in, exceptions)) + .entry((RefEquality(child), parent.map(RefEquality), exceptions)) .or_default() .push((full_name, &binding.range)); } else { unused - .entry((defined_by, defined_in, exceptions)) + .entry((RefEquality(child), parent.map(RefEquality), exceptions)) .or_default() .push((full_name, &binding.range)); } @@ -5299,7 +5285,7 @@ impl<'a> Checker<'a> { ) { Ok(fix) => { if fix.is_deletion() || fix.content() == Some("pass") { - self.deletions.insert(*defined_by); + self.deletions.insert(defined_by); } Some(fix) } @@ -5336,11 +5322,10 @@ impl<'a> Checker<'a> { diagnostics.push(diagnostic); } } - for ((defined_by, .., exceptions), unused_imports) in ignored + for ((child, .., exceptions), unused_imports) in ignored .into_iter() .sorted_by_key(|((defined_by, ..), ..)| defined_by.start()) { - let child: &Stmt = defined_by.into(); let multiple = unused_imports.len() > 1; let in_except_handler = exceptions .intersects(Exceptions::MODULE_NOT_FOUND_ERROR | Exceptions::IMPORT_ERROR); @@ -5436,9 +5421,9 @@ impl<'a> Checker<'a> { let mut overloaded_name: Option = None; while !self.deferred.definitions.is_empty() { let definitions = std::mem::take(&mut self.deferred.definitions); - for (definition, visibility, (scope_id, parents)) in definitions { + for (definition, visibility, (scope_id, node_id)) in definitions { self.ctx.scope_id = scope_id; - self.ctx.parents = parents; + self.ctx.stmt_id = node_id; // flake8-annotations if enforce_annotations { diff --git a/crates/ruff/src/codes.rs b/crates/ruff/src/codes.rs index cffdf12abe..9ae775ba70 100644 --- a/crates/ruff/src/codes.rs +++ b/crates/ruff/src/codes.rs @@ -1,6 +1,7 @@ -use crate::registry::{Linter, Rule}; use std::fmt::Formatter; +use crate::registry::{Linter, Rule}; + #[derive(PartialEq, Eq, PartialOrd, Ord)] pub struct NoqaCode(&'static str, &'static str); diff --git a/crates/ruff/src/rules/flake8_bugbear/rules/unused_loop_control_variable.rs b/crates/ruff/src/rules/flake8_bugbear/rules/unused_loop_control_variable.rs index 7e53161e5d..ad58504ecd 100644 --- a/crates/ruff/src/rules/flake8_bugbear/rules/unused_loop_control_variable.rs +++ b/crates/ruff/src/rules/flake8_bugbear/rules/unused_loop_control_variable.rs @@ -168,10 +168,9 @@ pub fn unused_loop_control_variable( let scope = checker.ctx.scope(); let binding = scope.bindings_for_name(name).find_map(|index| { let binding = &checker.ctx.bindings[*index]; - binding - .source - .as_ref() - .and_then(|source| (source == &RefEquality(stmt)).then_some(binding)) + binding.source.and_then(|source| { + (RefEquality(source) == RefEquality(stmt)).then_some(binding) + }) }); if let Some(binding) = binding { if binding.kind.is_loop_var() { diff --git a/crates/ruff/src/rules/flake8_simplify/rules/open_file_with_context_handler.rs b/crates/ruff/src/rules/flake8_simplify/rules/open_file_with_context_handler.rs index cc315cffe4..2e435dc6dc 100644 --- a/crates/ruff/src/rules/flake8_simplify/rules/open_file_with_context_handler.rs +++ b/crates/ruff/src/rules/flake8_simplify/rules/open_file_with_context_handler.rs @@ -33,7 +33,7 @@ fn match_async_exit_stack(checker: &Checker) -> bool { if attr != "enter_async_context" { return false; } - for parent in &checker.ctx.parents { + for parent in checker.ctx.parents() { if let StmtKind::With { items, .. } = &parent.node { for item in items { if let ExprKind::Call { func, .. } = &item.context_expr.node { @@ -68,7 +68,7 @@ fn match_exit_stack(checker: &Checker) -> bool { if attr != "enter_context" { return false; } - for parent in &checker.ctx.parents { + for parent in checker.ctx.parents() { if let StmtKind::With { items, .. } = &parent.node { for item in items { if let ExprKind::Call { func, .. } = &item.context_expr.node { diff --git a/crates/ruff/src/rules/flake8_type_checking/rules/empty_type_checking_block.rs b/crates/ruff/src/rules/flake8_type_checking/rules/empty_type_checking_block.rs index 35d1d0aa45..0f01c066fc 100644 --- a/crates/ruff/src/rules/flake8_type_checking/rules/empty_type_checking_block.rs +++ b/crates/ruff/src/rules/flake8_type_checking/rules/empty_type_checking_block.rs @@ -60,11 +60,7 @@ pub fn empty_type_checking_block<'a, 'b>( // Delete the entire type-checking block. if checker.patch(diagnostic.kind.rule()) { - let parent = checker - .ctx - .child_to_parent - .get(&RefEquality(stmt)) - .map(Into::into); + let parent = checker.ctx.stmts.parent(stmt); let deleted: Vec<&Stmt> = checker.deletions.iter().map(Into::into).collect(); match delete_stmt( stmt, diff --git a/crates/ruff/src/rules/pyflakes/rules/unused_variable.rs b/crates/ruff/src/rules/pyflakes/rules/unused_variable.rs index d7da2e2b69..e2ae12fb8d 100644 --- a/crates/ruff/src/rules/pyflakes/rules/unused_variable.rs +++ b/crates/ruff/src/rules/pyflakes/rules/unused_variable.rs @@ -214,11 +214,7 @@ fn remove_unused_variable( )) } else { // If (e.g.) assigning to a constant (`x = 1`), delete the entire statement. - let parent = checker - .ctx - .child_to_parent - .get(&RefEquality(stmt)) - .map(Into::into); + let parent = checker.ctx.stmts.parent(stmt); let deleted: Vec<&Stmt> = checker.deletions.iter().map(Into::into).collect(); match delete_stmt( stmt, @@ -259,11 +255,7 @@ fn remove_unused_variable( )) } else { // If assigning to a constant (`x = 1`), delete the entire statement. - let parent = checker - .ctx - .child_to_parent - .get(&RefEquality(stmt)) - .map(Into::into); + let parent = checker.ctx.stmts.parent(stmt); let deleted: Vec<&Stmt> = checker.deletions.iter().map(Into::into).collect(); match delete_stmt( stmt, @@ -336,7 +328,7 @@ pub fn unused_variable(checker: &mut Checker, scope: ScopeId) { binding.range, ); if checker.patch(diagnostic.kind.rule()) { - if let Some(stmt) = binding.source.as_ref().map(Into::into) { + if let Some(stmt) = binding.source { if let Some((kind, fix)) = remove_unused_variable(stmt, binding.range, checker) { if matches!(kind, DeletionKind::Whole) { diff --git a/crates/ruff/src/rules/pylint/rules/global_statement.rs b/crates/ruff/src/rules/pylint/rules/global_statement.rs index afb8aa843a..e1efbf26cb 100644 --- a/crates/ruff/src/rules/pylint/rules/global_statement.rs +++ b/crates/ruff/src/rules/pylint/rules/global_statement.rs @@ -57,9 +57,7 @@ pub fn global_statement(checker: &mut Checker, name: &str) { if binding.kind.is_global() { let source: &Stmt = binding .source - .as_ref() - .expect("`global` bindings should always have a `source`") - .into(); + .expect("`global` bindings should always have a `source`"); let diagnostic = Diagnostic::new( GlobalStatement { name: name.to_string(), diff --git a/crates/ruff/src/rules/pyupgrade/rules/outdated_version_block.rs b/crates/ruff/src/rules/pyupgrade/rules/outdated_version_block.rs index be064dec61..208cd1ce59 100644 --- a/crates/ruff/src/rules/pyupgrade/rules/outdated_version_block.rs +++ b/crates/ruff/src/rules/pyupgrade/rules/outdated_version_block.rs @@ -164,9 +164,9 @@ fn fix_py2_block( let defined_by = checker.ctx.current_stmt(); let defined_in = checker.ctx.current_stmt_parent(); return match delete_stmt( - defined_by.into(), + defined_by, if block.starter == Tok::If { - defined_in.map(Into::into) + defined_in } else { None }, @@ -176,7 +176,7 @@ fn fix_py2_block( checker.stylist, ) { Ok(fix) => { - checker.deletions.insert(RefEquality(defined_by.into())); + checker.deletions.insert(RefEquality(defined_by)); Some(fix) } Err(err) => { diff --git a/crates/ruff/src/rules/pyupgrade/rules/super_call_with_parameters.rs b/crates/ruff/src/rules/pyupgrade/rules/super_call_with_parameters.rs index b221ced002..6cfa877511 100644 --- a/crates/ruff/src/rules/pyupgrade/rules/super_call_with_parameters.rs +++ b/crates/ruff/src/rules/pyupgrade/rules/super_call_with_parameters.rs @@ -1,4 +1,4 @@ -use rustpython_parser::ast::{ArgData, Expr, ExprKind, Stmt, StmtKind}; +use rustpython_parser::ast::{ArgData, Expr, ExprKind, StmtKind}; use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic}; use ruff_macros::{derive_message_formats, violation}; @@ -39,14 +39,13 @@ pub fn super_call_with_parameters(checker: &mut Checker, expr: &Expr, func: &Exp return; } let scope = checker.ctx.scope(); - let parents: Vec<&Stmt> = checker.ctx.parents.iter().map(Into::into).collect(); // Check: are we in a Function scope? if !matches!(scope.kind, ScopeKind::Function { .. }) { return; } - let mut parents = parents.iter().rev(); + let mut parents = checker.ctx.parents(); // For a `super` invocation to be unnecessary, the first argument needs to match // the enclosing class, and the second argument needs to match the first diff --git a/crates/ruff/src/rules/pyupgrade/rules/unnecessary_builtin_import.rs b/crates/ruff/src/rules/pyupgrade/rules/unnecessary_builtin_import.rs index 6867c381a0..6ee57660bf 100644 --- a/crates/ruff/src/rules/pyupgrade/rules/unnecessary_builtin_import.rs +++ b/crates/ruff/src/rules/pyupgrade/rules/unnecessary_builtin_import.rs @@ -4,6 +4,7 @@ use rustpython_parser::ast::{Alias, AliasData, Located, Stmt}; use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic}; use ruff_macros::{derive_message_formats, violation}; +use ruff_python_ast::types::RefEquality; use crate::autofix; use crate::checkers::ast::Checker; @@ -113,8 +114,8 @@ pub fn unnecessary_builtin_import( .collect(); match autofix::actions::remove_unused_imports( unused_imports.iter().map(String::as_str), - defined_by.into(), - defined_in.map(Into::into), + defined_by, + defined_in, &deleted, checker.locator, checker.indexer, @@ -122,7 +123,7 @@ pub fn unnecessary_builtin_import( ) { Ok(fix) => { if fix.is_deletion() || fix.content() == Some("pass") { - checker.deletions.insert(*defined_by); + checker.deletions.insert(RefEquality(defined_by)); } diagnostic.set_fix(fix); } diff --git a/crates/ruff/src/rules/pyupgrade/rules/unnecessary_future_import.rs b/crates/ruff/src/rules/pyupgrade/rules/unnecessary_future_import.rs index df31d7695a..a171b95832 100644 --- a/crates/ruff/src/rules/pyupgrade/rules/unnecessary_future_import.rs +++ b/crates/ruff/src/rules/pyupgrade/rules/unnecessary_future_import.rs @@ -4,6 +4,7 @@ use rustpython_parser::ast::{Alias, AliasData, Located, Stmt}; use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic}; use ruff_macros::{derive_message_formats, violation}; +use ruff_python_ast::types::RefEquality; use crate::autofix; use crate::checkers::ast::Checker; @@ -93,8 +94,8 @@ pub fn unnecessary_future_import(checker: &mut Checker, stmt: &Stmt, names: &[Lo .collect(); match autofix::actions::remove_unused_imports( unused_imports.iter().map(String::as_str), - defined_by.into(), - defined_in.map(Into::into), + defined_by, + defined_in, &deleted, checker.locator, checker.indexer, @@ -102,7 +103,7 @@ pub fn unnecessary_future_import(checker: &mut Checker, stmt: &Stmt, names: &[Lo ) { Ok(fix) => { if fix.is_deletion() || fix.content() == Some("pass") { - checker.deletions.insert(*defined_by); + checker.deletions.insert(RefEquality(defined_by)); } diagnostic.set_fix(fix); } diff --git a/crates/ruff/src/rules/pyupgrade/rules/useless_metaclass_type.rs b/crates/ruff/src/rules/pyupgrade/rules/useless_metaclass_type.rs index e1b2839828..d6640fb2a5 100644 --- a/crates/ruff/src/rules/pyupgrade/rules/useless_metaclass_type.rs +++ b/crates/ruff/src/rules/pyupgrade/rules/useless_metaclass_type.rs @@ -4,6 +4,7 @@ use rustpython_parser::ast::{Expr, ExprKind, Stmt}; use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic}; use ruff_macros::{derive_message_formats, violation}; +use ruff_python_ast::types::RefEquality; use crate::autofix::actions; use crate::checkers::ast::Checker; @@ -53,8 +54,8 @@ pub fn useless_metaclass_type(checker: &mut Checker, stmt: &Stmt, value: &Expr, let defined_by = checker.ctx.current_stmt(); let defined_in = checker.ctx.current_stmt_parent(); match actions::delete_stmt( - defined_by.into(), - defined_in.map(Into::into), + defined_by, + defined_in, &deleted, checker.locator, checker.indexer, @@ -62,7 +63,7 @@ pub fn useless_metaclass_type(checker: &mut Checker, stmt: &Stmt, value: &Expr, ) { Ok(fix) => { if fix.is_deletion() || fix.content() == Some("pass") { - checker.deletions.insert(*defined_by); + checker.deletions.insert(RefEquality(defined_by)); } diagnostic.set_fix(fix); } diff --git a/crates/ruff_python_ast/src/branch_detection.rs b/crates/ruff_python_ast/src/branch_detection.rs deleted file mode 100644 index 2e3b2ced28..0000000000 --- a/crates/ruff_python_ast/src/branch_detection.rs +++ /dev/null @@ -1,113 +0,0 @@ -use std::cmp::Ordering; - -use rustc_hash::FxHashMap; -use rustpython_parser::ast::ExcepthandlerKind::ExceptHandler; -use rustpython_parser::ast::{Stmt, StmtKind}; - -use crate::types::RefEquality; - -/// Return the common ancestor of `left` and `right` below `stop`, or `None`. -fn common_ancestor<'a>( - left: RefEquality<'a, Stmt>, - right: RefEquality<'a, Stmt>, - stop: Option>, - depths: &'a FxHashMap, usize>, - child_to_parent: &'a FxHashMap, RefEquality<'a, Stmt>>, -) -> Option> { - if Some(left) == stop || Some(right) == stop { - return None; - } - - if left == right { - return Some(left); - } - - let left_depth = depths.get(&left)?; - let right_depth = depths.get(&right)?; - match left_depth.cmp(right_depth) { - Ordering::Less => common_ancestor( - left, - *child_to_parent.get(&right)?, - stop, - depths, - child_to_parent, - ), - Ordering::Equal => common_ancestor( - *child_to_parent.get(&left)?, - *child_to_parent.get(&right)?, - stop, - depths, - child_to_parent, - ), - Ordering::Greater => common_ancestor( - *child_to_parent.get(&left)?, - right, - stop, - depths, - child_to_parent, - ), - } -} - -/// Return the alternative branches for a given node. -fn alternatives(stmt: RefEquality) -> Vec>> { - match &stmt.as_ref().node { - StmtKind::If { body, .. } => vec![body.iter().map(RefEquality).collect()], - StmtKind::Try { - body, - handlers, - orelse, - .. - } - | StmtKind::TryStar { - body, - handlers, - orelse, - .. - } => vec![body.iter().chain(orelse.iter()).map(RefEquality).collect()] - .into_iter() - .chain(handlers.iter().map(|handler| { - let ExceptHandler { body, .. } = &handler.node; - body.iter().map(RefEquality).collect() - })) - .collect(), - StmtKind::Match { cases, .. } => cases - .iter() - .map(|case| case.body.iter().map(RefEquality).collect()) - .collect(), - _ => vec![], - } -} - -/// Return `true` if `stmt` is a descendent of any of the nodes in `ancestors`. -fn descendant_of<'a>( - stmt: RefEquality<'a, Stmt>, - ancestors: &[RefEquality<'a, Stmt>], - stop: RefEquality<'a, Stmt>, - depths: &FxHashMap, usize>, - child_to_parent: &FxHashMap, RefEquality<'a, Stmt>>, -) -> bool { - ancestors.iter().any(|ancestor| { - common_ancestor(stmt, *ancestor, Some(stop), depths, child_to_parent).is_some() - }) -} - -/// Return `true` if `left` and `right` are on different branches of an `if` or -/// `try` statement. -pub fn different_forks<'a>( - left: RefEquality<'a, Stmt>, - right: RefEquality<'a, Stmt>, - depths: &FxHashMap, usize>, - child_to_parent: &FxHashMap, RefEquality<'a, Stmt>>, -) -> bool { - if let Some(ancestor) = common_ancestor(left, right, None, depths, child_to_parent) { - for items in alternatives(ancestor) { - let l = descendant_of(left, &items, ancestor, depths, child_to_parent); - let r = descendant_of(right, &items, ancestor, depths, child_to_parent); - if l ^ r { - return true; - } - } - } - false -} diff --git a/crates/ruff_python_ast/src/lib.rs b/crates/ruff_python_ast/src/lib.rs index 14afd50a66..f57158ff73 100644 --- a/crates/ruff_python_ast/src/lib.rs +++ b/crates/ruff_python_ast/src/lib.rs @@ -1,5 +1,4 @@ pub mod all; -pub mod branch_detection; pub mod call_path; pub mod cast; pub mod comparable; diff --git a/crates/ruff_python_ast/src/types.rs b/crates/ruff_python_ast/src/types.rs index 6ed092d684..7596cfae44 100644 --- a/crates/ruff_python_ast/src/types.rs +++ b/crates/ruff_python_ast/src/types.rs @@ -70,3 +70,15 @@ impl<'a> From<&RefEquality<'a, Expr>> for &'a Expr { r.0 } } + +impl<'a> From> for &'a Stmt { + fn from(r: RefEquality<'a, Stmt>) -> Self { + r.0 + } +} + +impl<'a> From> for &'a Expr { + fn from(r: RefEquality<'a, Expr>) -> Self { + r.0 + } +} diff --git a/crates/ruff_python_semantic/src/analyze/branch_detection.rs b/crates/ruff_python_semantic/src/analyze/branch_detection.rs new file mode 100644 index 0000000000..3a1499bea1 --- /dev/null +++ b/crates/ruff_python_semantic/src/analyze/branch_detection.rs @@ -0,0 +1,104 @@ +use std::cmp::Ordering; + +use ruff_python_ast::types::RefEquality; +use rustpython_parser::ast::ExcepthandlerKind::ExceptHandler; +use rustpython_parser::ast::{Stmt, StmtKind}; + +use crate::node::Nodes; + +/// Return the common ancestor of `left` and `right` below `stop`, or `None`. +fn common_ancestor<'a>( + left: &'a Stmt, + right: &'a Stmt, + stop: Option<&'a Stmt>, + node_tree: &Nodes<'a>, +) -> Option<&'a Stmt> { + if stop.map_or(false, |stop| { + RefEquality(left) == RefEquality(stop) || RefEquality(right) == RefEquality(stop) + }) { + return None; + } + + if RefEquality(left) == RefEquality(right) { + return Some(left); + } + + let left_id = node_tree.node_id(left)?; + let right_id = node_tree.node_id(right)?; + + let left_depth = node_tree.depth(left_id); + let right_depth = node_tree.depth(right_id); + + match left_depth.cmp(&right_depth) { + Ordering::Less => { + let right_id = node_tree.parent_id(right_id)?; + common_ancestor(left, node_tree[right_id], stop, node_tree) + } + Ordering::Equal => { + let left_id = node_tree.parent_id(left_id)?; + let right_id = node_tree.parent_id(right_id)?; + common_ancestor(node_tree[left_id], node_tree[right_id], stop, node_tree) + } + Ordering::Greater => { + let left_id = node_tree.parent_id(left_id)?; + common_ancestor(node_tree[left_id], right, stop, node_tree) + } + } +} + +/// Return the alternative branches for a given node. +fn alternatives(stmt: &Stmt) -> Vec> { + match &stmt.node { + StmtKind::If { body, .. } => vec![body.iter().collect()], + StmtKind::Try { + body, + handlers, + orelse, + .. + } + | StmtKind::TryStar { + body, + handlers, + orelse, + .. + } => vec![body.iter().chain(orelse.iter()).collect()] + .into_iter() + .chain(handlers.iter().map(|handler| { + let ExceptHandler { body, .. } = &handler.node; + body.iter().collect() + })) + .collect(), + StmtKind::Match { cases, .. } => cases + .iter() + .map(|case| case.body.iter().collect()) + .collect(), + _ => vec![], + } +} + +/// Return `true` if `stmt` is a descendent of any of the nodes in `ancestors`. +fn descendant_of<'a>( + stmt: &'a Stmt, + ancestors: &[&'a Stmt], + stop: &'a Stmt, + node_tree: &Nodes<'a>, +) -> bool { + ancestors + .iter() + .any(|ancestor| common_ancestor(stmt, ancestor, Some(stop), node_tree).is_some()) +} + +/// Return `true` if `left` and `right` are on different branches of an `if` or +/// `try` statement. +pub fn different_forks<'a>(left: &'a Stmt, right: &'a Stmt, node_tree: &Nodes<'a>) -> bool { + if let Some(ancestor) = common_ancestor(left, right, None, node_tree) { + for items in alternatives(ancestor) { + let l = descendant_of(left, &items, ancestor, node_tree); + let r = descendant_of(right, &items, ancestor, node_tree); + if l ^ r { + return true; + } + } + } + false +} diff --git a/crates/ruff_python_semantic/src/analyze/mod.rs b/crates/ruff_python_semantic/src/analyze/mod.rs index 7298cbc87a..a4cd2fdf50 100644 --- a/crates/ruff_python_semantic/src/analyze/mod.rs +++ b/crates/ruff_python_semantic/src/analyze/mod.rs @@ -1,3 +1,4 @@ +pub mod branch_detection; pub mod function_type; pub mod logging; pub mod typing; diff --git a/crates/ruff_python_semantic/src/binding.rs b/crates/ruff_python_semantic/src/binding.rs index 9c50cfd132..6628fa235b 100644 --- a/crates/ruff_python_semantic/src/binding.rs +++ b/crates/ruff_python_semantic/src/binding.rs @@ -5,8 +5,6 @@ use bitflags::bitflags; use ruff_text_size::TextRange; use rustpython_parser::ast::Stmt; -use ruff_python_ast::types::RefEquality; - use crate::scope::ScopeId; #[derive(Debug, Clone)] @@ -16,7 +14,7 @@ pub struct Binding<'a> { /// The context in which the binding was created. pub context: ExecutionContext, /// The statement in which the [`Binding`] was defined. - pub source: Option>, + pub source: Option<&'a Stmt>, /// Tuple of (scope index, range) indicating the scope and range at which /// the binding was last used in a runtime context. pub runtime_usage: Option<(ScopeId, TextRange)>, diff --git a/crates/ruff_python_semantic/src/context.rs b/crates/ruff_python_semantic/src/context.rs index 661bd7a88b..bdd10a5146 100644 --- a/crates/ruff_python_semantic/src/context.rs +++ b/crates/ruff_python_semantic/src/context.rs @@ -1,7 +1,6 @@ use std::path::Path; use nohash_hasher::{BuildNoHashHasher, IntMap}; -use rustc_hash::FxHashMap; use rustpython_parser::ast::{Expr, Stmt}; use smallvec::smallvec; @@ -17,26 +16,26 @@ use crate::binding::{ Binding, BindingId, BindingKind, Bindings, Exceptions, ExecutionContext, FromImportation, Importation, SubmoduleImportation, }; +use crate::node::{NodeId, Nodes}; use crate::scope::{Scope, ScopeId, ScopeKind, Scopes}; #[allow(clippy::struct_excessive_bools)] pub struct Context<'a> { pub typing_modules: &'a [String], pub module_path: Option>, - // Retain all scopes and parent nodes, along with a stack of indices to track which are active - // at various points in time. - pub parents: Vec>, - pub depths: FxHashMap, usize>, - pub child_to_parent: FxHashMap, RefEquality<'a, Stmt>>, + // Stack of all visited statements, along with the identifier of the current statement. + pub stmts: Nodes<'a>, + pub stmt_id: Option, + // Stack of all scopes, along with the identifier of the current scope. + pub scopes: Scopes<'a>, + pub scope_id: ScopeId, + pub dead_scopes: Vec, // A stack of all bindings created in any scope, at any point in execution. pub bindings: Bindings<'a>, // Map from binding index to indexes of bindings that shadow it in other scopes. pub shadowed_bindings: std::collections::HashMap, BuildNoHashHasher>, pub exprs: Vec>, - pub scopes: Scopes<'a>, - pub scope_id: ScopeId, - pub dead_scopes: Vec, // Body iteration; used to peek at siblings. pub body: &'a [Stmt], pub body_index: usize, @@ -68,15 +67,14 @@ impl<'a> Context<'a> { Self { typing_modules, module_path, - parents: Vec::default(), - depths: FxHashMap::default(), - child_to_parent: FxHashMap::default(), - bindings: Bindings::default(), - shadowed_bindings: IntMap::default(), - exprs: Vec::default(), + stmts: Nodes::default(), + stmt_id: None, scopes: Scopes::default(), scope_id: ScopeId::global(), dead_scopes: Vec::default(), + bindings: Bindings::default(), + shadowed_bindings: IntMap::default(), + exprs: Vec::default(), body: &[], body_index: 0, visible_scope: VisibleScope { @@ -254,10 +252,7 @@ impl<'a> Context<'a> { .take(scope_index) .all(|scope| scope.get(name).is_none()) { - return Some(( - binding.source.as_ref().unwrap().into(), - format!("{name}.{member}"), - )); + return Some((binding.source.unwrap(), format!("{name}.{member}"))); } } } @@ -273,10 +268,7 @@ impl<'a> Context<'a> { .take(scope_index) .all(|scope| scope.get(name).is_none()) { - return Some(( - binding.source.as_ref().unwrap().into(), - (*name).to_string(), - )); + return Some((binding.source.unwrap(), (*name).to_string())); } } } @@ -291,10 +283,7 @@ impl<'a> Context<'a> { .take(scope_index) .all(|scope| scope.get(name).is_none()) { - return Some(( - binding.source.as_ref().unwrap().into(), - format!("{name}.{member}"), - )); + return Some((binding.source.unwrap(), format!("{name}.{member}"))); } } } @@ -306,18 +295,15 @@ impl<'a> Context<'a> { }) } - pub fn push_parent(&mut self, parent: &'a Stmt) { - let num_existing = self.parents.len(); - self.parents.push(RefEquality(parent)); - self.depths.insert(self.parents[num_existing], num_existing); - if num_existing > 0 { - self.child_to_parent - .insert(self.parents[num_existing], self.parents[num_existing - 1]); - } + /// Push a [`Stmt`] onto the stack. + pub fn push_stmt(&mut self, stmt: &'a Stmt) { + self.stmt_id = Some(self.stmts.insert(stmt, self.stmt_id)); } - pub fn pop_parent(&mut self) { - self.parents.pop().expect("Attempted to pop without parent"); + /// Pop the current [`Stmt`] off the stack. + pub fn pop_stmt(&mut self) { + let node_id = self.stmt_id.expect("Attempted to pop without statement"); + self.stmt_id = self.stmts.parent_id(node_id); } pub fn push_expr(&mut self, expr: &'a Expr) { @@ -345,13 +331,16 @@ impl<'a> Context<'a> { } /// Return the current `Stmt`. - pub fn current_stmt(&self) -> &RefEquality<'a, Stmt> { - self.parents.iter().rev().next().expect("No parent found") + pub fn current_stmt(&self) -> &'a Stmt { + let node_id = self.stmt_id.expect("No current statement"); + self.stmts[node_id] } /// Return the parent `Stmt` of the current `Stmt`, if any. - pub fn current_stmt_parent(&self) -> Option<&RefEquality<'a, Stmt>> { - self.parents.iter().rev().nth(1) + pub fn current_stmt_parent(&self) -> Option<&'a Stmt> { + let node_id = self.stmt_id.expect("No current statement"); + let parent_id = self.stmts.parent_id(node_id)?; + Some(self.stmts[parent_id]) } /// Return the parent `Expr` of the current `Expr`. @@ -399,6 +388,11 @@ impl<'a> Context<'a> { self.scopes.ancestors(self.scope_id) } + pub fn parents(&self) -> impl Iterator + '_ { + let node_id = self.stmt_id.expect("No current statement"); + self.stmts.ancestor_ids(node_id).map(|id| self.stmts[id]) + } + /// Returns `true` if the context is in an exception handler. pub const fn in_exception_handler(&self) -> bool { self.in_exception_handler diff --git a/crates/ruff_python_semantic/src/lib.rs b/crates/ruff_python_semantic/src/lib.rs index f6c8f3fd46..e85172a6a1 100644 --- a/crates/ruff_python_semantic/src/lib.rs +++ b/crates/ruff_python_semantic/src/lib.rs @@ -1,4 +1,5 @@ pub mod analyze; pub mod binding; pub mod context; +pub mod node; pub mod scope; diff --git a/crates/ruff_python_semantic/src/node.rs b/crates/ruff_python_semantic/src/node.rs new file mode 100644 index 0000000000..7d84e03461 --- /dev/null +++ b/crates/ruff_python_semantic/src/node.rs @@ -0,0 +1,112 @@ +use std::num::{NonZeroU32, TryFromIntError}; +use std::ops::{Index, IndexMut}; + +use rustc_hash::FxHashMap; +use rustpython_parser::ast::Stmt; + +use ruff_python_ast::types::RefEquality; + +/// Id uniquely identifying a statement in a program. +/// +/// Using a `u32` is sufficient because Ruff only supports parsing documents with a size of max `u32::max` +/// and it is impossible to have more statements than characters in the file. We use a `NonZeroU32` to +/// take advantage of memory layout optimizations. +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Ord, PartialOrd)] +pub struct NodeId(NonZeroU32); + +/// Convert a `usize` to a `NodeId` (by adding 1 to the value, and casting to `NonZeroU32`). +impl TryFrom for NodeId { + type Error = TryFromIntError; + + fn try_from(value: usize) -> Result { + Ok(Self(NonZeroU32::try_from(u32::try_from(value)? + 1)?)) + } +} + +/// Convert a `NodeId` to a `usize` (by subtracting 1 from the value, and casting to `usize`). +impl From for usize { + fn from(value: NodeId) -> Self { + value.0.get() as usize - 1 + } +} + +#[derive(Debug)] +struct Node<'a> { + /// The statement this node represents. + stmt: &'a Stmt, + /// The ID of the parent of this node, if any. + parent: Option, + /// The depth of this node in the tree. + depth: u32, +} + +/// The nodes of a program indexed by [`NodeId`] +#[derive(Debug, Default)] +pub struct Nodes<'a> { + nodes: Vec>, + node_to_id: FxHashMap, NodeId>, +} + +impl<'a> Nodes<'a> { + /// Inserts a new node into the node tree and returns its unique id. + /// + /// Panics if a node with the same pointer already exists. + pub fn insert(&mut self, stmt: &'a Stmt, parent: Option) -> NodeId { + let next_id = NodeId::try_from(self.nodes.len()).unwrap(); + if let Some(existing_id) = self.node_to_id.insert(RefEquality(stmt), next_id) { + panic!("Node already exists with id {existing_id:?}"); + } + self.nodes.push(Node { + stmt, + parent, + depth: parent.map_or(0, |parent| self.nodes[usize::from(parent)].depth + 1), + }); + next_id + } + + /// Returns the [`NodeId`] of the given node. + #[inline] + pub fn node_id(&self, node: &'a Stmt) -> Option { + self.node_to_id.get(&RefEquality(node)).copied() + } + + /// Return the [`NodeId`] of the parent node. + #[inline] + pub fn parent_id(&self, node_id: NodeId) -> Option { + self.nodes[usize::from(node_id)].parent + } + + /// Return the depth of the node. + #[inline] + pub fn depth(&self, node_id: NodeId) -> u32 { + self.nodes[usize::from(node_id)].depth + } + + /// Returns an iterator over all [`NodeId`] ancestors, starting from the given [`NodeId`]. + pub fn ancestor_ids(&self, node_id: NodeId) -> impl Iterator + '_ { + std::iter::successors(Some(node_id), |&node_id| { + self.nodes[usize::from(node_id)].parent + }) + } + + /// Return the parent of the given node. + pub fn parent(&self, node: &'a Stmt) -> Option<&'a Stmt> { + let node_id = self.node_to_id.get(&RefEquality(node))?; + let parent_id = self.nodes[usize::from(*node_id)].parent?; + Some(self[parent_id]) + } +} + +impl<'a> Index for Nodes<'a> { + type Output = &'a Stmt; + + fn index(&self, index: NodeId) -> &Self::Output { + &self.nodes[usize::from(index)].stmt + } +} + +impl<'a> IndexMut for Nodes<'a> { + fn index_mut(&mut self, index: NodeId) -> &mut Self::Output { + &mut self.nodes[usize::from(index)].stmt + } +}