diff --git a/crates/ruff/src/checkers/ast/analyze/deferred_scopes.rs b/crates/ruff/src/checkers/ast/analyze/deferred_scopes.rs index 9922f2411e..e9f28c852d 100644 --- a/crates/ruff/src/checkers/ast/analyze/deferred_scopes.rs +++ b/crates/ruff/src/checkers/ast/analyze/deferred_scopes.rs @@ -168,7 +168,7 @@ pub(crate) fn deferred_scopes(checker: &mut Checker) { continue; } - let Some(statement_id) = shadowed.source else { + let Some(node_id) = shadowed.source else { continue; }; @@ -176,7 +176,7 @@ pub(crate) fn deferred_scopes(checker: &mut Checker) { if shadowed.kind.is_function_definition() { if checker .semantic - .statement(statement_id) + .statement(node_id) .as_function_def_stmt() .is_some_and(|function| { visibility::is_overload( diff --git a/crates/ruff/src/checkers/ast/mod.rs b/crates/ruff/src/checkers/ast/mod.rs index 8d86e48867..ef7b794783 100644 --- a/crates/ruff/src/checkers/ast/mod.rs +++ b/crates/ruff/src/checkers/ast/mod.rs @@ -267,7 +267,7 @@ where { fn visit_stmt(&mut self, stmt: &'b Stmt) { // Step 0: Pre-processing - self.semantic.push_statement(stmt); + self.semantic.push_node(stmt); // Track whether we've seen docstrings, non-imports, etc. match stmt { @@ -779,7 +779,7 @@ where analyze::statement(stmt, self); self.semantic.flags = flags_snapshot; - self.semantic.pop_statement(); + self.semantic.pop_node(); } fn visit_annotation(&mut self, expr: &'b Expr) { @@ -815,7 +815,7 @@ where return; } - self.semantic.push_expression(expr); + self.semantic.push_node(expr); // Store the flags prior to any further descent, so that we can restore them after visiting // the node. @@ -1235,7 +1235,7 @@ where analyze::expression(expr, self); self.semantic.flags = flags_snapshot; - self.semantic.pop_expression(); + self.semantic.pop_node(); } fn visit_except_handler(&mut self, except_handler: &'b ExceptHandler) { diff --git a/crates/ruff/src/rules/flake8_type_checking/rules/runtime_import_in_type_checking_block.rs b/crates/ruff/src/rules/flake8_type_checking/rules/runtime_import_in_type_checking_block.rs index 28be0841c1..9fb04c5cee 100644 --- a/crates/ruff/src/rules/flake8_type_checking/rules/runtime_import_in_type_checking_block.rs +++ b/crates/ruff/src/rules/flake8_type_checking/rules/runtime_import_in_type_checking_block.rs @@ -6,7 +6,7 @@ use rustc_hash::FxHashMap; use ruff_diagnostics::{AutofixKind, Diagnostic, Fix, Violation}; use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::Ranged; -use ruff_python_semantic::{AnyImport, Imported, ResolvedReferenceId, Scope, StatementId}; +use ruff_python_semantic::{AnyImport, Imported, NodeId, ResolvedReferenceId, Scope}; use ruff_text_size::TextRange; use crate::autofix; @@ -72,8 +72,8 @@ pub(crate) fn runtime_import_in_type_checking_block( diagnostics: &mut Vec, ) { // Collect all runtime imports by statement. 
- let mut errors_by_statement: FxHashMap> = FxHashMap::default(); - let mut ignores_by_statement: FxHashMap> = FxHashMap::default(); + let mut errors_by_statement: FxHashMap> = FxHashMap::default(); + let mut ignores_by_statement: FxHashMap> = FxHashMap::default(); for binding_id in scope.binding_ids() { let binding = checker.semantic().binding(binding_id); @@ -95,7 +95,7 @@ pub(crate) fn runtime_import_in_type_checking_block( .is_runtime() }) { - let Some(statement_id) = binding.source else { + let Some(node_id) = binding.source else { continue; }; @@ -115,23 +115,20 @@ pub(crate) fn runtime_import_in_type_checking_block( }) { ignores_by_statement - .entry(statement_id) + .entry(node_id) .or_default() .push(import); } else { - errors_by_statement - .entry(statement_id) - .or_default() - .push(import); + errors_by_statement.entry(node_id).or_default().push(import); } } } // Generate a diagnostic for every import, but share a fix across all imports within the same // statement (excluding those that are ignored). - for (statement_id, imports) in errors_by_statement { + for (node_id, imports) in errors_by_statement { let fix = if checker.patch(Rule::RuntimeImportInTypeCheckingBlock) { - fix_imports(checker, statement_id, &imports).ok() + fix_imports(checker, node_id, &imports).ok() } else { None }; @@ -200,13 +197,9 @@ impl Ranged for ImportBinding<'_> { } /// Generate a [`Fix`] to remove runtime imports from a type-checking block. -fn fix_imports( - checker: &Checker, - statement_id: StatementId, - imports: &[ImportBinding], -) -> Result { - let statement = checker.semantic().statement(statement_id); - let parent = checker.semantic().parent_statement(statement_id); +fn fix_imports(checker: &Checker, node_id: NodeId, imports: &[ImportBinding]) -> Result { + let statement = checker.semantic().statement(node_id); + let parent = checker.semantic().parent_statement(node_id); let member_names: Vec> = imports .iter() diff --git a/crates/ruff/src/rules/flake8_type_checking/rules/typing_only_runtime_import.rs b/crates/ruff/src/rules/flake8_type_checking/rules/typing_only_runtime_import.rs index c1d282d3ce..150a07e3d2 100644 --- a/crates/ruff/src/rules/flake8_type_checking/rules/typing_only_runtime_import.rs +++ b/crates/ruff/src/rules/flake8_type_checking/rules/typing_only_runtime_import.rs @@ -6,7 +6,7 @@ use rustc_hash::FxHashMap; use ruff_diagnostics::{AutofixKind, Diagnostic, DiagnosticKind, Fix, Violation}; use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::Ranged; -use ruff_python_semantic::{AnyImport, Binding, Imported, ResolvedReferenceId, Scope, StatementId}; +use ruff_python_semantic::{AnyImport, Binding, Imported, NodeId, ResolvedReferenceId, Scope}; use ruff_text_size::TextRange; use crate::autofix; @@ -227,9 +227,9 @@ pub(crate) fn typing_only_runtime_import( diagnostics: &mut Vec, ) { // Collect all typing-only imports by statement and import type. 
- let mut errors_by_statement: FxHashMap<(StatementId, ImportType), Vec> = + let mut errors_by_statement: FxHashMap<(NodeId, ImportType), Vec> = FxHashMap::default(); - let mut ignores_by_statement: FxHashMap<(StatementId, ImportType), Vec> = + let mut ignores_by_statement: FxHashMap<(NodeId, ImportType), Vec> = FxHashMap::default(); for binding_id in scope.binding_ids() { @@ -302,7 +302,7 @@ pub(crate) fn typing_only_runtime_import( continue; } - let Some(statement_id) = binding.source else { + let Some(node_id) = binding.source else { continue; }; @@ -319,12 +319,12 @@ pub(crate) fn typing_only_runtime_import( }) { ignores_by_statement - .entry((statement_id, import_type)) + .entry((node_id, import_type)) .or_default() .push(import); } else { errors_by_statement - .entry((statement_id, import_type)) + .entry((node_id, import_type)) .or_default() .push(import); } @@ -333,9 +333,9 @@ pub(crate) fn typing_only_runtime_import( // Generate a diagnostic for every import, but share a fix across all imports within the same // statement (excluding those that are ignored). - for ((statement_id, import_type), imports) in errors_by_statement { + for ((node_id, import_type), imports) in errors_by_statement { let fix = if checker.patch(rule_for(import_type)) { - fix_imports(checker, statement_id, &imports).ok() + fix_imports(checker, node_id, &imports).ok() } else { None }; @@ -445,13 +445,9 @@ fn is_exempt(name: &str, exempt_modules: &[&str]) -> bool { } /// Generate a [`Fix`] to remove typing-only imports from a runtime context. -fn fix_imports( - checker: &Checker, - statement_id: StatementId, - imports: &[ImportBinding], -) -> Result { - let statement = checker.semantic().statement(statement_id); - let parent = checker.semantic().parent_statement(statement_id); +fn fix_imports(checker: &Checker, node_id: NodeId, imports: &[ImportBinding]) -> Result { + let statement = checker.semantic().statement(node_id); + let parent = checker.semantic().parent_statement(node_id); let member_names: Vec> = imports .iter() diff --git a/crates/ruff/src/rules/pyflakes/rules/unused_import.rs b/crates/ruff/src/rules/pyflakes/rules/unused_import.rs index 94af0eb2db..0ec33b2ad8 100644 --- a/crates/ruff/src/rules/pyflakes/rules/unused_import.rs +++ b/crates/ruff/src/rules/pyflakes/rules/unused_import.rs @@ -6,7 +6,7 @@ use rustc_hash::FxHashMap; use ruff_diagnostics::{AutofixKind, Diagnostic, Fix, Violation}; use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::Ranged; -use ruff_python_semantic::{AnyImport, Exceptions, Imported, Scope, StatementId}; +use ruff_python_semantic::{AnyImport, Exceptions, Imported, NodeId, Scope}; use ruff_text_size::TextRange; use crate::autofix; @@ -100,9 +100,8 @@ impl Violation for UnusedImport { pub(crate) fn unused_import(checker: &Checker, scope: &Scope, diagnostics: &mut Vec) { // Collect all unused imports by statement. 
- let mut unused: FxHashMap<(StatementId, Exceptions), Vec> = FxHashMap::default(); - let mut ignored: FxHashMap<(StatementId, Exceptions), Vec> = - FxHashMap::default(); + let mut unused: FxHashMap<(NodeId, Exceptions), Vec> = FxHashMap::default(); + let mut ignored: FxHashMap<(NodeId, Exceptions), Vec> = FxHashMap::default(); for binding_id in scope.binding_ids() { let binding = checker.semantic().binding(binding_id); @@ -119,7 +118,7 @@ pub(crate) fn unused_import(checker: &Checker, scope: &Scope, diagnostics: &mut continue; }; - let Some(statement_id) = binding.source else { + let Some(node_id) = binding.source else { continue; }; @@ -135,12 +134,12 @@ pub(crate) fn unused_import(checker: &Checker, scope: &Scope, diagnostics: &mut }) { ignored - .entry((statement_id, binding.exceptions)) + .entry((node_id, binding.exceptions)) .or_default() .push(import); } else { unused - .entry((statement_id, binding.exceptions)) + .entry((node_id, binding.exceptions)) .or_default() .push(import); } @@ -151,13 +150,13 @@ pub(crate) fn unused_import(checker: &Checker, scope: &Scope, diagnostics: &mut // Generate a diagnostic for every import, but share a fix across all imports within the same // statement (excluding those that are ignored). - for ((statement_id, exceptions), imports) in unused { + for ((node_id, exceptions), imports) in unused { let in_except_handler = exceptions.intersects(Exceptions::MODULE_NOT_FOUND_ERROR | Exceptions::IMPORT_ERROR); let multiple = imports.len() > 1; let fix = if !in_init && !in_except_handler && checker.patch(Rule::UnusedImport) { - fix_imports(checker, statement_id, &imports).ok() + fix_imports(checker, node_id, &imports).ok() } else { None }; @@ -234,13 +233,9 @@ impl Ranged for ImportBinding<'_> { } /// Generate a [`Fix`] to remove unused imports from a statement. -fn fix_imports( - checker: &Checker, - statement_id: StatementId, - imports: &[ImportBinding], -) -> Result { - let statement = checker.semantic().statement(statement_id); - let parent = checker.semantic().parent_statement(statement_id); +fn fix_imports(checker: &Checker, node_id: NodeId, imports: &[ImportBinding]) -> Result { + let statement = checker.semantic().statement(node_id); + let parent = checker.semantic().parent_statement(node_id); let member_names: Vec> = imports .iter() diff --git a/crates/ruff/src/rules/pyflakes/rules/unused_variable.rs b/crates/ruff/src/rules/pyflakes/rules/unused_variable.rs index f973f237e2..024aa0eb0f 100644 --- a/crates/ruff/src/rules/pyflakes/rules/unused_variable.rs +++ b/crates/ruff/src/rules/pyflakes/rules/unused_variable.rs @@ -203,12 +203,12 @@ where /// Generate a [`Edit`] to remove an unused variable assignment to a [`Binding`]. 
fn remove_unused_variable(binding: &Binding, checker: &Checker) -> Option { - let statement_id = binding.source?; - let statement = checker.semantic().statement(statement_id); - let parent = checker.semantic().parent_statement(statement_id); + let node_id = binding.source?; + let statement = checker.semantic().statement(node_id); + let parent = checker.semantic().parent_statement(node_id); let isolation = checker .semantic() - .parent_statement_id(statement_id) + .parent_statement_id(node_id) .map(|node_id| IsolationLevel::Group(node_id.into())) .unwrap_or_default(); diff --git a/crates/ruff_python_semantic/src/binding.rs b/crates/ruff_python_semantic/src/binding.rs index ca4c4d5327..ea4113e6c8 100644 --- a/crates/ruff_python_semantic/src/binding.rs +++ b/crates/ruff_python_semantic/src/binding.rs @@ -11,8 +11,8 @@ use ruff_text_size::TextRange; use crate::context::ExecutionContext; use crate::model::SemanticModel; +use crate::nodes::NodeId; use crate::reference::ResolvedReferenceId; -use crate::statements::StatementId; use crate::ScopeId; #[derive(Debug, Clone)] @@ -24,7 +24,7 @@ pub struct Binding<'a> { /// The context in which the [`Binding`] was created. pub context: ExecutionContext, /// The statement in which the [`Binding`] was defined. - pub source: Option, + pub source: Option, /// The references to the [`Binding`]. pub references: Vec, /// The exceptions that were handled when the [`Binding`] was defined. @@ -185,7 +185,7 @@ impl<'a> Binding<'a> { /// Returns the range of the binding's parent. pub fn parent_range(&self, semantic: &SemanticModel) -> Option { self.source - .map(|statement_id| semantic.statement(statement_id)) + .map(|id| semantic.statement(id)) .and_then(|parent| { if parent.is_import_from_stmt() { Some(parent.range()) diff --git a/crates/ruff_python_semantic/src/expressions.rs b/crates/ruff_python_semantic/src/expressions.rs deleted file mode 100644 index 47a012013c..0000000000 --- a/crates/ruff_python_semantic/src/expressions.rs +++ /dev/null @@ -1,56 +0,0 @@ -use std::ops::Index; - -use ruff_index::{newtype_index, IndexVec}; -use ruff_python_ast::Expr; - -/// Id uniquely identifying an expression in a program. -/// -/// Using a `u32` is sufficient because Ruff only supports parsing documents with a size of max -/// `u32::max` and it is impossible to have more nodes than characters in the file. We use a -/// `NonZeroU32` to take advantage of memory layout optimizations. -#[newtype_index] -#[derive(Ord, PartialOrd)] -pub struct ExpressionId; - -/// An [`Expr`] AST node in a program, along with a pointer to its parent expression (if any). -#[derive(Debug)] -struct ExpressionWithParent<'a> { - /// A pointer to the AST node. - node: &'a Expr, - /// The ID of the parent of this node, if any. - parent: Option, -} - -/// The nodes of a program indexed by [`ExpressionId`] -#[derive(Debug, Default)] -pub struct Expressions<'a>(IndexVec>); - -impl<'a> Expressions<'a> { - /// Inserts a new expression into the node tree and returns its unique id. - pub(crate) fn insert(&mut self, node: &'a Expr, parent: Option) -> ExpressionId { - self.0.push(ExpressionWithParent { node, parent }) - } - - /// Return the [`ExpressionId`] of the parent node. - #[inline] - pub fn parent_id(&self, node_id: ExpressionId) -> Option { - self.0[node_id].parent - } - - /// Returns an iterator over all [`ExpressionId`] ancestors, starting from the given [`ExpressionId`]. 
- pub(crate) fn ancestor_ids( - &self, - node_id: ExpressionId, - ) -> impl Iterator + '_ { - std::iter::successors(Some(node_id), |&node_id| self.0[node_id].parent) - } -} - -impl<'a> Index for Expressions<'a> { - type Output = &'a Expr; - - #[inline] - fn index(&self, index: ExpressionId) -> &Self::Output { - &self.0[index].node - } -} diff --git a/crates/ruff_python_semantic/src/lib.rs b/crates/ruff_python_semantic/src/lib.rs index 6b75e56430..ce45050239 100644 --- a/crates/ruff_python_semantic/src/lib.rs +++ b/crates/ruff_python_semantic/src/lib.rs @@ -3,22 +3,20 @@ mod binding; mod branches; mod context; mod definition; -mod expressions; mod globals; mod model; +mod nodes; mod reference; mod scope; mod star_import; -mod statements; pub use binding::*; pub use branches::*; pub use context::*; pub use definition::*; -pub use expressions::*; pub use globals::*; pub use model::*; +pub use nodes::*; pub use reference::*; pub use scope::*; pub use star_import::*; -pub use statements::*; diff --git a/crates/ruff_python_semantic/src/model.rs b/crates/ruff_python_semantic/src/model.rs index d40ac8ed56..ed01d5382d 100644 --- a/crates/ruff_python_semantic/src/model.rs +++ b/crates/ruff_python_semantic/src/model.rs @@ -18,14 +18,13 @@ use crate::binding::{ use crate::branches::{BranchId, Branches}; use crate::context::ExecutionContext; use crate::definition::{Definition, DefinitionId, Definitions, Member, Module}; -use crate::expressions::{ExpressionId, Expressions}; use crate::globals::{Globals, GlobalsArena}; +use crate::nodes::{NodeId, NodeRef, Nodes}; use crate::reference::{ ResolvedReference, ResolvedReferenceId, ResolvedReferences, UnresolvedReference, UnresolvedReferenceFlags, UnresolvedReferences, }; use crate::scope::{Scope, ScopeId, ScopeKind, Scopes}; -use crate::statements::{StatementId, Statements}; use crate::Imported; /// A semantic model for a Python module, to enable querying the module's semantic information. @@ -33,17 +32,11 @@ pub struct SemanticModel<'a> { typing_modules: &'a [String], module_path: Option<&'a [String]>, - /// Stack of statements in the program. - statements: Statements<'a>, + /// Stack of all AST nodes in the program. + nodes: Nodes<'a>, - /// The ID of the current statement. - statement_id: Option, - - /// Stack of expressions in the program. - expressions: Expressions<'a>, - - /// The ID of the current expression. - expression_id: Option, + /// The ID of the current AST node. + node_id: Option, /// Stack of all branches in the program. 
branches: Branches, @@ -141,12 +134,10 @@ impl<'a> SemanticModel<'a> { Self { typing_modules, module_path: module.path(), - statements: Statements::default(), - statement_id: None, - expressions: Expressions::default(), - expression_id: None, - branch_id: None, + nodes: Nodes::default(), + node_id: None, branches: Branches::default(), + branch_id: None, scopes: Scopes::default(), scope_id: ScopeId::global(), definitions: Definitions::for_module(module), @@ -236,7 +227,7 @@ impl<'a> SemanticModel<'a> { flags, references: Vec::new(), scope: self.scope_id, - source: self.statement_id, + source: self.node_id, context: self.execution_context(), exceptions: self.exceptions(), }) @@ -728,7 +719,7 @@ impl<'a> SemanticModel<'a> { { return Some(ImportedName { name: format!("{name}.{member}"), - range: self.statements[source].range(), + range: self.nodes[source].range(), context: binding.context, }); } @@ -752,7 +743,7 @@ impl<'a> SemanticModel<'a> { { return Some(ImportedName { name: (*name).to_string(), - range: self.statements[source].range(), + range: self.nodes[source].range(), context: binding.context, }); } @@ -773,7 +764,7 @@ impl<'a> SemanticModel<'a> { { return Some(ImportedName { name: format!("{name}.{member}"), - range: self.statements[source].range(), + range: self.nodes[source].range(), context: binding.context, }); } @@ -788,33 +779,15 @@ impl<'a> SemanticModel<'a> { }) } - /// Push a [`Stmt`] onto the stack. - pub fn push_statement(&mut self, stmt: &'a Stmt) { - self.statement_id = Some( - self.statements - .insert(stmt, self.statement_id, self.branch_id), - ); + /// Push an AST node [`NodeRef`] onto the stack. + pub fn push_node>>(&mut self, node: T) { + self.node_id = Some(self.nodes.insert(node.into(), self.node_id, self.branch_id)); } - /// Pop the current [`Stmt`] off the stack. - pub fn pop_statement(&mut self) { - let node_id = self - .statement_id - .expect("Attempted to pop without statement"); - self.statement_id = self.statements.parent_id(node_id); - } - - /// Push a [`Expr`] onto the stack. - pub fn push_expression(&mut self, expr: &'a Expr) { - self.expression_id = Some(self.expressions.insert(expr, self.expression_id)); - } - - /// Pop the current [`Expr`] off the stack. - pub fn pop_expression(&mut self) { - let node_id = self - .expression_id - .expect("Attempted to pop without expression"); - self.expression_id = self.expressions.parent_id(node_id); + /// Pop the current AST node [`NodeRef`] off the stack. + pub fn pop_node(&mut self) { + let node_id = self.node_id.expect("Attempted to pop without node"); + self.node_id = self.nodes.parent_id(node_id); } /// Push a [`Scope`] with the given [`ScopeKind`] onto the stack. @@ -860,34 +833,20 @@ impl<'a> SemanticModel<'a> { self.branch_id = branch_id; } - /// Returns an [`Iterator`] over the current statement hierarchy represented as [`StatementId`], - /// from the current [`StatementId`] through to any parents. - pub fn current_statement_ids(&self) -> impl Iterator + '_ { - self.statement_id - .iter() - .flat_map(|id| self.statements.ancestor_ids(*id)) - } - /// Returns an [`Iterator`] over the current statement hierarchy, from the current [`Stmt`] /// through to any parents. pub fn current_statements(&self) -> impl Iterator + '_ { - self.current_statement_ids().map(|id| self.statements[id]) - } - - /// Return the [`StatementId`] of the current [`Stmt`]. 
- pub fn current_statement_id(&self) -> StatementId { - self.statement_id.expect("No current statement") - } - - /// Return the [`StatementId`] of the current [`Stmt`] parent, if any. - pub fn current_statement_parent_id(&self) -> Option { - self.current_statement_ids().nth(1) + let id = self.node_id.expect("No current node"); + self.nodes + .ancestor_ids(id) + .filter_map(move |id| self.nodes[id].as_statement()) } /// Return the current [`Stmt`]. pub fn current_statement(&self) -> &'a Stmt { - let node_id = self.statement_id.expect("No current statement"); - self.statements[node_id] + self.current_statements() + .next() + .expect("No current statement") } /// Return the parent [`Stmt`] of the current [`Stmt`], if any. @@ -895,24 +854,18 @@ impl<'a> SemanticModel<'a> { self.current_statements().nth(1) } - /// Returns an [`Iterator`] over the current expression hierarchy represented as - /// [`ExpressionId`], from the current [`Expr`] through to any parents. - pub fn current_expression_ids(&self) -> impl Iterator + '_ { - self.expression_id - .iter() - .flat_map(|id| self.expressions.ancestor_ids(*id)) - } - /// Returns an [`Iterator`] over the current expression hierarchy, from the current [`Expr`] /// through to any parents. pub fn current_expressions(&self) -> impl Iterator + '_ { - self.current_expression_ids().map(|id| self.expressions[id]) + let id = self.node_id.expect("No current node"); + self.nodes + .ancestor_ids(id) + .filter_map(move |id| self.nodes[id].as_expression()) } /// Return the current [`Expr`]. pub fn current_expression(&self) -> Option<&'a Expr> { - let node_id = self.expression_id?; - Some(self.expressions[node_id]) + self.current_expressions().next() } /// Return the parent [`Expr`] of the current [`Expr`], if any. @@ -925,6 +878,27 @@ impl<'a> SemanticModel<'a> { self.current_expressions().nth(2) } + /// Returns an [`Iterator`] over the current statement hierarchy represented as [`NodeId`], + /// from the current [`NodeId`] through to any parents. + pub fn current_statement_ids(&self) -> impl Iterator + '_ { + self.node_id + .iter() + .flat_map(|id| self.nodes.ancestor_ids(*id)) + .filter(|id| self.nodes[*id].is_statement()) + } + + /// Return the [`NodeId`] of the current [`Stmt`]. + pub fn current_statement_id(&self) -> NodeId { + self.current_statement_ids() + .next() + .expect("No current statement") + } + + /// Return the [`NodeId`] of the current [`Stmt`] parent, if any. + pub fn current_statement_parent_id(&self) -> Option { + self.current_statement_ids().nth(1) + } + /// Returns a reference to the global [`Scope`]. pub fn global_scope(&self) -> &Scope<'a> { self.scopes.global() @@ -973,24 +947,36 @@ impl<'a> SemanticModel<'a> { None } - /// Return the [`Stmt]` corresponding to the given [`StatementId`]. + /// Return the [`Stmt`] corresponding to the given [`NodeId`]. #[inline] - pub fn statement(&self, statement_id: StatementId) -> &'a Stmt { - self.statements[statement_id] + pub fn node(&self, node_id: NodeId) -> &NodeRef<'a> { + &self.nodes[node_id] + } + + /// Return the [`Stmt`] corresponding to the given [`NodeId`]. + #[inline] + pub fn statement(&self, node_id: NodeId) -> &'a Stmt { + self.nodes + .ancestor_ids(node_id) + .find_map(|id| self.nodes[id].as_statement()) + .expect("No statement found") } /// Given a [`Stmt`], return its parent, if any. 
#[inline] - pub fn parent_statement(&self, statement_id: StatementId) -> Option<&'a Stmt> { - self.statements - .parent_id(statement_id) - .map(|id| self.statements[id]) + pub fn parent_statement(&self, node_id: NodeId) -> Option<&'a Stmt> { + self.nodes + .ancestor_ids(node_id) + .filter_map(|id| self.nodes[id].as_statement()) + .nth(1) } - /// Given a [`StatementId`], return the ID of its parent statement, if any. - #[inline] - pub fn parent_statement_id(&self, statement_id: StatementId) -> Option { - self.statements.parent_id(statement_id) + /// Given a [`NodeId`], return the [`NodeId`] of the parent statement, if any. + pub fn parent_statement_id(&self, node_id: NodeId) -> Option { + self.nodes + .ancestor_ids(node_id) + .filter(|id| self.nodes[*id].is_statement()) + .nth(1) } /// Set the [`Globals`] for the current [`Scope`]. @@ -1007,7 +993,7 @@ impl<'a> SemanticModel<'a> { range: *range, references: Vec::new(), scope: self.scope_id, - source: self.statement_id, + source: self.node_id, context: self.execution_context(), exceptions: self.exceptions(), flags: BindingFlags::empty(), @@ -1053,10 +1039,7 @@ impl<'a> SemanticModel<'a> { /// Return `true` if the model is at the top level of the module (i.e., in the module scope, /// and not nested within any statements). pub fn at_top_level(&self) -> bool { - self.scope_id.is_global() - && self - .statement_id - .map_or(true, |stmt_id| self.statements.parent_id(stmt_id).is_none()) + self.scope_id.is_global() && self.current_statement_parent_id().is_none() } /// Return `true` if the model is in an async context. @@ -1101,10 +1084,10 @@ impl<'a> SemanticModel<'a> { /// `try` statement. /// /// This implementation assumes that the statements are in the same scope. - pub fn different_branches(&self, left: StatementId, right: StatementId) -> bool { + pub fn different_branches(&self, left: NodeId, right: NodeId) -> bool { // Collect the branch path for the left statement. let left = self - .statements + .nodes .branch_id(left) .iter() .flat_map(|branch_id| self.branches.ancestor_ids(*branch_id)) @@ -1112,7 +1095,7 @@ impl<'a> SemanticModel<'a> { // Collect the branch path for the right statement. 
let right = self - .statements + .nodes .branch_id(right) .iter() .flat_map(|branch_id| self.branches.ancestor_ids(*branch_id)) @@ -1191,8 +1174,7 @@ impl<'a> SemanticModel<'a> { pub fn snapshot(&self) -> Snapshot { Snapshot { scope_id: self.scope_id, - stmt_id: self.statement_id, - expr_id: self.expression_id, + node_id: self.node_id, branch_id: self.branch_id, definition_id: self.definition_id, flags: self.flags, @@ -1203,15 +1185,13 @@ impl<'a> SemanticModel<'a> { pub fn restore(&mut self, snapshot: Snapshot) { let Snapshot { scope_id, - stmt_id, - expr_id, + node_id, branch_id, definition_id, flags, } = snapshot; self.scope_id = scope_id; - self.statement_id = stmt_id; - self.expression_id = expr_id; + self.node_id = node_id; self.branch_id = branch_id; self.definition_id = definition_id; self.flags = flags; @@ -1625,8 +1605,7 @@ impl SemanticModelFlags { #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct Snapshot { scope_id: ScopeId, - stmt_id: Option, - expr_id: Option, + node_id: Option, branch_id: Option, definition_id: DefinitionId, flags: SemanticModelFlags, diff --git a/crates/ruff_python_semantic/src/nodes.rs b/crates/ruff_python_semantic/src/nodes.rs new file mode 100644 index 0000000000..506623c57e --- /dev/null +++ b/crates/ruff_python_semantic/src/nodes.rs @@ -0,0 +1,136 @@ +use std::ops::Index; + +use ruff_index::{newtype_index, IndexVec}; +use ruff_python_ast::{Expr, Ranged, Stmt}; +use ruff_text_size::TextRange; + +use crate::BranchId; + +/// Id uniquely identifying an AST node in a program. +/// +/// Using a `u32` is sufficient because Ruff only supports parsing documents with a size of max +/// `u32::max` and it is impossible to have more nodes than characters in the file. We use a +/// `NonZeroU32` to take advantage of memory layout optimizations. +#[newtype_index] +#[derive(Ord, PartialOrd)] +pub struct NodeId; + +/// An AST node in a program, along with a pointer to its parent node (if any). +#[derive(Debug)] +struct NodeWithParent<'a> { + /// A pointer to the AST node. + node: NodeRef<'a>, + /// The ID of the parent of this node, if any. + parent: Option, + /// The branch ID of this node, if any. + branch: Option, +} + +/// The nodes of a program indexed by [`NodeId`] +#[derive(Debug, Default)] +pub struct Nodes<'a> { + nodes: IndexVec>, +} + +impl<'a> Nodes<'a> { + /// Inserts a new AST node into the tree and returns its unique ID. + pub(crate) fn insert( + &mut self, + node: NodeRef<'a>, + parent: Option, + branch: Option, + ) -> NodeId { + self.nodes.push(NodeWithParent { + node, + parent, + branch, + }) + } + + /// Return the [`NodeId`] of the parent node. + #[inline] + pub fn parent_id(&self, node_id: NodeId) -> Option { + self.nodes[node_id].parent + } + + /// Return the [`BranchId`] of the branch node. + #[inline] + pub(crate) fn branch_id(&self, node_id: NodeId) -> Option { + self.nodes[node_id].branch + } + + /// Returns an iterator over all [`NodeId`] ancestors, starting from the given [`NodeId`]. + pub(crate) fn ancestor_ids(&self, node_id: NodeId) -> impl Iterator + '_ { + std::iter::successors(Some(node_id), |&node_id| self.nodes[node_id].parent) + } +} + +impl<'a> Index for Nodes<'a> { + type Output = NodeRef<'a>; + + #[inline] + fn index(&self, index: NodeId) -> &Self::Output { + &self.nodes[index].node + } +} + +/// A reference to an AST node. Like [`ruff_python_ast::node::AnyNodeRef`], but wraps the node +/// itself (like [`Stmt`]) rather than the narrowed type (like [`ruff_python_ast::StmtAssign`]). 
+///
+/// TODO(charlie): Replace with [`ruff_python_ast::node::AnyNodeRef`]. This requires migrating
+/// the rest of the codebase to use [`ruff_python_ast::node::AnyNodeRef`] and related abstractions,
+/// like [`ruff_python_ast::ExpressionRef`] instead of [`Expr`].
+#[derive(Copy, Clone, Debug, PartialEq)]
+pub enum NodeRef<'a> {
+    Stmt(&'a Stmt),
+    Expr(&'a Expr),
+}
+
+impl<'a> NodeRef<'a> {
+    /// Returns the [`Stmt`] if this is a statement, or `None` if the reference is to another
+    /// kind of AST node.
+    pub fn as_statement(&self) -> Option<&'a Stmt> {
+        match self {
+            NodeRef::Stmt(stmt) => Some(stmt),
+            NodeRef::Expr(_) => None,
+        }
+    }
+
+    /// Returns the [`Expr`] if this is an expression, or `None` if the reference is to another
+    /// kind of AST node.
+    pub fn as_expression(&self) -> Option<&'a Expr> {
+        match self {
+            NodeRef::Stmt(_) => None,
+            NodeRef::Expr(expr) => Some(expr),
+        }
+    }
+
+    pub fn is_statement(&self) -> bool {
+        self.as_statement().is_some()
+    }
+
+    pub fn is_expression(&self) -> bool {
+        self.as_expression().is_some()
+    }
+}
+
+impl Ranged for NodeRef<'_> {
+    fn range(&self) -> TextRange {
+        match self {
+            NodeRef::Stmt(stmt) => stmt.range(),
+            NodeRef::Expr(expr) => expr.range(),
+        }
+    }
+}
+
+impl<'a> From<&'a Expr> for NodeRef<'a> {
+    fn from(expr: &'a Expr) -> Self {
+        NodeRef::Expr(expr)
+    }
+}
+
+impl<'a> From<&'a Stmt> for NodeRef<'a> {
+    fn from(stmt: &'a Stmt) -> Self {
+        NodeRef::Stmt(stmt)
+    }
+}
diff --git a/crates/ruff_python_semantic/src/statements.rs b/crates/ruff_python_semantic/src/statements.rs
deleted file mode 100644
index 0b3f054aed..0000000000
--- a/crates/ruff_python_semantic/src/statements.rs
+++ /dev/null
@@ -1,72 +0,0 @@
-use std::ops::Index;
-
-use ruff_index::{newtype_index, IndexVec};
-use ruff_python_ast::Stmt;
-
-use crate::branches::BranchId;
-
-/// Id uniquely identifying a statement AST node.
-///
-/// Using a `u32` is sufficient because Ruff only supports parsing documents with a size of max
-/// `u32::max` and it is impossible to have more nodes than characters in the file. We use a
-/// `NonZeroU32` to take advantage of memory layout optimizations.
-#[newtype_index]
-#[derive(Ord, PartialOrd)]
-pub struct StatementId;
-
-/// A [`Stmt`] AST node, along with a pointer to its parent statement (if any).
-#[derive(Debug)]
-struct StatementWithParent<'a> {
-    /// A pointer to the AST node.
-    statement: &'a Stmt,
-    /// The ID of the parent of this node, if any.
-    parent: Option<StatementId>,
-    /// The branch ID of this node, if any.
-    branch: Option<BranchId>,
-}
-
-/// The statements of a program indexed by [`StatementId`]
-#[derive(Debug, Default)]
-pub struct Statements<'a>(IndexVec<StatementId, StatementWithParent<'a>>);
-
-impl<'a> Statements<'a> {
-    /// Inserts a new statement into the statement vector and returns its unique ID.
-    pub(crate) fn insert(
-        &mut self,
-        statement: &'a Stmt,
-        parent: Option<StatementId>,
-        branch: Option<BranchId>,
-    ) -> StatementId {
-        self.0.push(StatementWithParent {
-            statement,
-            parent,
-            branch,
-        })
-    }
-
-    /// Return the [`StatementId`] of the parent statement.
-    #[inline]
-    pub(crate) fn parent_id(&self, statement_id: StatementId) -> Option<StatementId> {
-        self.0[statement_id].parent
-    }
-
-    /// Return the [`StatementId`] of the parent statement.
-    #[inline]
-    pub(crate) fn branch_id(&self, statement_id: StatementId) -> Option<BranchId> {
-        self.0[statement_id].branch
-    }
-
-    /// Returns an iterator over all [`StatementId`] ancestors, starting from the given [`StatementId`].
-    pub(crate) fn ancestor_ids(&self, id: StatementId) -> impl Iterator<Item = StatementId> + '_ {
-        std::iter::successors(Some(id), |&id| self.0[id].parent)
-    }
-}
-
-impl<'a> Index<StatementId> for Statements<'a> {
-    type Output = &'a Stmt;
-
-    #[inline]
-    fn index(&self, index: StatementId) -> &Self::Output {
-        &self.0[index].statement
-    }
-}
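For context, the pattern introduced in nodes.rs can be sketched in isolation as follows. This is a simplified stand-in, not the crate's code: Stmt and Expr are placeholder structs, a plain usize index and Vec replace NodeId and ruff_index::IndexVec, and BranchId tracking is omitted. It shows how a single parent-linked arena serves both statements and expressions, and how current_statement now resolves by walking ancestor_ids and filtering with as_statement.

// Standalone sketch of the unified node stack (placeholder types, not the real
// ruff_python_ast::{Stmt, Expr} or ruff_index::IndexVec).
#[derive(Debug)]
struct Stmt(&'static str);
#[derive(Debug)]
struct Expr(&'static str);

// A reference to either kind of AST node, mirroring the NodeRef enum above.
#[derive(Copy, Clone)]
enum NodeRef<'a> {
    Stmt(&'a Stmt),
    Expr(&'a Expr),
}

impl<'a> NodeRef<'a> {
    fn as_statement(self) -> Option<&'a Stmt> {
        match self {
            NodeRef::Stmt(stmt) => Some(stmt),
            NodeRef::Expr(_) => None,
        }
    }
}

// One arena entry: the node plus a pointer to its parent, if any.
struct NodeWithParent<'a> {
    node: NodeRef<'a>,
    parent: Option<usize>, // NodeId in the real crate
}

#[derive(Default)]
struct Nodes<'a>(Vec<NodeWithParent<'a>>);

impl<'a> Nodes<'a> {
    /// Push a node onto the arena, recording its parent, and return its id.
    fn insert(&mut self, node: NodeRef<'a>, parent: Option<usize>) -> usize {
        self.0.push(NodeWithParent { node, parent });
        self.0.len() - 1
    }

    /// Walk from `id` up through its ancestors (inclusive).
    fn ancestor_ids(&self, id: usize) -> impl Iterator<Item = usize> + '_ {
        std::iter::successors(Some(id), |&id| self.0[id].parent)
    }
}

fn main() {
    // Mimic visiting `x = f(1)`: the statement is pushed first, then the call
    // expression inside it, just as push_node does during traversal.
    let assign = Stmt("x = f(1)");
    let call = Expr("f(1)");

    let mut nodes = Nodes::default();
    let stmt_id = nodes.insert(NodeRef::Stmt(&assign), None);
    let expr_id = nodes.insert(NodeRef::Expr(&call), Some(stmt_id));

    // "Current statement" = first ancestor of the current node that is a statement.
    let current_statement = nodes
        .ancestor_ids(expr_id)
        .find_map(|id| nodes.0[id].node.as_statement());
    println!("{current_statement:?}"); // Some(Stmt("x = f(1)"))
}

Because statement and expression ids now share one sequence, a Binding's source field can point at whichever node created it, and helpers such as parent_statement and different_branches operate on the same NodeId type throughout.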