From b0b4706e2da64ea80c4deccea7f0a01f2bed5d5c Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Wed, 5 Jun 2024 17:53:26 +0200 Subject: [PATCH] Red-knot: Track scopes per expression (#11754) --- crates/red_knot/src/semantic.rs | 56 ++++++++++++++++++-- crates/red_knot/src/semantic/flow_graph.rs | 23 ++++---- crates/red_knot/src/semantic/symbol_table.rs | 42 +++++++++++---- crates/ruff_index/src/slice.rs | 18 ++++++- crates/ruff_index/src/vec.rs | 5 ++ 5 files changed, 118 insertions(+), 26 deletions(-) diff --git a/crates/red_knot/src/semantic.rs b/crates/red_knot/src/semantic.rs index 1ba9d99e60..b02f982641 100644 --- a/crates/red_knot/src/semantic.rs +++ b/crates/red_knot/src/semantic.rs @@ -12,6 +12,8 @@ use crate::module::ModuleName; use crate::parse::parse; use crate::Name; use flow_graph::{FlowGraph, FlowGraphBuilder, FlowNodeId, ReachableDefinitionsIterator}; +use ruff_index::newtype_index; +use rustc_hash::FxHashMap; use std::ops::{Deref, DerefMut}; use std::sync::Arc; pub(crate) use symbol_table::{Definition, Dependency, SymbolId}; @@ -49,6 +51,9 @@ pub fn resolve_global_symbol( Ok(Some(GlobalSymbolId { file_id, symbol_id })) } +#[newtype_index] +pub struct ExpressionId; + #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub struct GlobalSymbolId { pub(crate) file_id: FileId, @@ -59,6 +64,7 @@ pub struct GlobalSymbolId { pub struct SemanticIndex { symbol_table: SymbolTable, flow_graph: FlowGraph, + expressions: FxHashMap, } impl SemanticIndex { @@ -71,6 +77,7 @@ impl SemanticIndex { scope_id: root_scope_id, current_flow_node_id: FlowGraph::start(), }], + expressions: FxHashMap::default(), current_definition: None, }; indexer.visit_body(&module.body); @@ -85,13 +92,18 @@ impl SemanticIndex { symbol_id: SymbolId, use_expr: &ast::Expr, ) -> ReachableDefinitionsIterator { + let expression_id = self.expression_id(use_expr); ReachableDefinitionsIterator::new( &self.flow_graph, symbol_id, - self.flow_graph.for_expr(use_expr), + self.flow_graph.for_expr(expression_id), ) } + pub fn expression_id(&self, expression: &ast::Expr) -> ExpressionId { + self.expressions[&NodeKey::from_node(expression.into())] + } + pub fn symbol_table(&self) -> &SymbolTable { &self.symbol_table } @@ -110,6 +122,7 @@ struct SemanticIndexer { scopes: Vec, /// the definition whose target(s) we are currently walking current_definition: Option, + expressions: FxHashMap, } impl SemanticIndexer { @@ -122,6 +135,7 @@ impl SemanticIndexer { SemanticIndex { flow_graph: flow_graph_builder.finish(), symbol_table: symbol_table_builder.finish(), + expressions: self.expressions, } } @@ -240,8 +254,19 @@ impl PreorderVisitor<'_> for SemanticIndexer { } } } - self.flow_graph_builder - .record_expr(expr, self.current_flow_node()); + + let expression_id = self + .flow_graph_builder + .record_expr(self.current_flow_node()); + + debug_assert_eq!( + expression_id, + self.symbol_table_builder + .record_expression(self.cur_scope()) + ); + + self.expressions + .insert(NodeKey::from_node(expr.into()), expression_id); ast::visitor::preorder::walk_expr(self, expr); } @@ -744,4 +769,29 @@ mod tests { }; assert_eq!(*num, 1); } + + #[test] + fn expression_scope() { + let parsed = parse("x = 1;\ndef test():\n y = 4"); + let ast = parsed.syntax(); + let index = SemanticIndex::from_ast(ast); + let table = &index.symbol_table; + + let x_sym = table + .root_symbol_by_name("x") + .expect("x symbol should exist"); + + let x_stmt = ast.body[0].as_assign_stmt().unwrap(); + + let x_id = index.expression_id(&x_stmt.targets[0]); + + assert_eq!(table.scope_of_expression(x_id).kind(), ScopeKind::Module); + assert_eq!(table.scope_id_of_expression(x_id), x_sym.scope_id()); + + let def = ast.body[1].as_function_def_stmt().unwrap(); + let y_stmt = def.body[0].as_assign_stmt().unwrap(); + let y_id = index.expression_id(&y_stmt.targets[0]); + + assert_eq!(table.scope_of_expression(y_id).kind(), ScopeKind::Function); + } } diff --git a/crates/red_knot/src/semantic/flow_graph.rs b/crates/red_knot/src/semantic/flow_graph.rs index d04a4ece76..85b1931cfc 100644 --- a/crates/red_knot/src/semantic/flow_graph.rs +++ b/crates/red_knot/src/semantic/flow_graph.rs @@ -1,8 +1,6 @@ use super::symbol_table::{Definition, SymbolId}; -use crate::ast_ids::NodeKey; +use crate::semantic::ExpressionId; use ruff_index::{newtype_index, IndexVec}; -use ruff_python_ast as ast; -use rustc_hash::FxHashMap; use std::iter::FusedIterator; #[newtype_index] @@ -40,7 +38,7 @@ pub(crate) struct PhiFlowNode { #[derive(Debug)] pub struct FlowGraph { flow_nodes_by_id: IndexVec, - ast_to_flow: FxHashMap, + expression_map: IndexVec, } impl FlowGraph { @@ -48,9 +46,8 @@ impl FlowGraph { FlowNodeId::from_usize(0) } - pub fn for_expr(&self, expr: &ast::Expr) -> FlowNodeId { - let node_key = NodeKey::from_node(expr.into()); - self.ast_to_flow[&node_key] + pub fn for_expr(&self, expr: ExpressionId) -> FlowNodeId { + self.expression_map[expr] } } @@ -63,7 +60,7 @@ impl FlowGraphBuilder { pub(crate) fn new() -> Self { let mut graph = FlowGraph { flow_nodes_by_id: IndexVec::default(), - ast_to_flow: FxHashMap::default(), + expression_map: IndexVec::default(), }; graph.flow_nodes_by_id.push(FlowNode::Start); Self { flow_graph: graph } @@ -101,13 +98,13 @@ impl FlowGraphBuilder { })) } - pub(crate) fn record_expr(&mut self, expr: &ast::Expr, node_id: FlowNodeId) { - self.flow_graph - .ast_to_flow - .insert(NodeKey::from_node(expr.into()), node_id); + pub(super) fn record_expr(&mut self, node_id: FlowNodeId) -> ExpressionId { + self.flow_graph.expression_map.push(node_id) } - pub(crate) fn finish(self) -> FlowGraph { + pub(super) fn finish(mut self) -> FlowGraph { + self.flow_graph.flow_nodes_by_id.shrink_to_fit(); + self.flow_graph.expression_map.shrink_to_fit(); self.flow_graph } } diff --git a/crates/red_knot/src/semantic/symbol_table.rs b/crates/red_knot/src/semantic/symbol_table.rs index ece4e95758..67780ceeed 100644 --- a/crates/red_knot/src/semantic/symbol_table.rs +++ b/crates/red_knot/src/semantic/symbol_table.rs @@ -13,6 +13,7 @@ use ruff_python_ast as ast; use crate::ast_ids::{NodeKey, TypedNodeKey}; use crate::module::ModuleName; +use crate::semantic::ExpressionId; use crate::Name; type Map = hashbrown::HashMap; @@ -192,6 +193,8 @@ pub struct SymbolTable { defs: FxHashMap>, /// map of AST node (e.g. class/function def) to sub-scope it creates scopes_by_node: FxHashMap, + /// Maps expressions to their enclosing scope. + expression_scopes: IndexVec, /// dependencies of this module dependencies: Vec, } @@ -283,6 +286,14 @@ impl SymbolTable { &self.scopes_by_id[self.scope_id_of_symbol(symbol_id)] } + pub fn scope_id_of_expression(&self, expression: ExpressionId) -> ScopeId { + self.expression_scopes[expression] + } + + pub fn scope_of_expression(&self, expr_id: ExpressionId) -> &Scope { + &self.scopes_by_id[self.scope_id_of_expression(expr_id)] + } + pub fn parent_scopes( &self, scope_id: ScopeId, @@ -393,17 +404,18 @@ where } #[derive(Debug)] -pub(crate) struct SymbolTableBuilder { +pub(super) struct SymbolTableBuilder { symbol_table: SymbolTable, } impl SymbolTableBuilder { - pub(crate) fn new() -> Self { + pub(super) fn new() -> Self { let mut table = SymbolTable { scopes_by_id: IndexVec::new(), symbols_by_id: IndexVec::new(), defs: FxHashMap::default(), scopes_by_node: FxHashMap::default(), + expression_scopes: IndexVec::new(), dependencies: Vec::new(), }; table.scopes_by_id.push(Scope { @@ -420,11 +432,18 @@ impl SymbolTableBuilder { } } - pub(crate) fn finish(self) -> SymbolTable { - self.symbol_table + pub(super) fn finish(self) -> SymbolTable { + let mut symbol_table = self.symbol_table; + symbol_table.scopes_by_id.shrink_to_fit(); + symbol_table.symbols_by_id.shrink_to_fit(); + symbol_table.defs.shrink_to_fit(); + symbol_table.scopes_by_node.shrink_to_fit(); + symbol_table.expression_scopes.shrink_to_fit(); + symbol_table.dependencies.shrink_to_fit(); + symbol_table } - pub(crate) fn add_or_update_symbol( + pub(super) fn add_or_update_symbol( &mut self, scope_id: ScopeId, name: &str, @@ -462,7 +481,7 @@ impl SymbolTableBuilder { } } - pub(crate) fn add_definition(&mut self, symbol_id: SymbolId, definition: Definition) { + pub(super) fn add_definition(&mut self, symbol_id: SymbolId, definition: Definition) { self.symbol_table .defs .entry(symbol_id) @@ -470,7 +489,7 @@ impl SymbolTableBuilder { .push(definition); } - pub(crate) fn add_child_scope( + pub(super) fn add_child_scope( &mut self, parent_scope_id: ScopeId, name: &str, @@ -492,13 +511,18 @@ impl SymbolTableBuilder { new_scope_id } - pub(crate) fn record_scope_for_node(&mut self, node_key: NodeKey, scope_id: ScopeId) { + pub(super) fn record_scope_for_node(&mut self, node_key: NodeKey, scope_id: ScopeId) { self.symbol_table.scopes_by_node.insert(node_key, scope_id); } - pub(crate) fn add_dependency(&mut self, dependency: Dependency) { + pub(super) fn add_dependency(&mut self, dependency: Dependency) { self.symbol_table.dependencies.push(dependency); } + + /// Records the scope for the current expression + pub(super) fn record_expression(&mut self, scope: ScopeId) -> ExpressionId { + self.symbol_table.expression_scopes.push(scope) + } } #[cfg(test)] diff --git a/crates/ruff_index/src/slice.rs b/crates/ruff_index/src/slice.rs index a6d3b033df..804aa1fbda 100644 --- a/crates/ruff_index/src/slice.rs +++ b/crates/ruff_index/src/slice.rs @@ -2,7 +2,7 @@ use crate::vec::IndexVec; use crate::Idx; use std::fmt::{Debug, Formatter}; use std::marker::PhantomData; -use std::ops::{Index, IndexMut}; +use std::ops::{Index, IndexMut, Range}; /// A view into contiguous `T`s, indexed by `I` rather than by `usize`. #[derive(PartialEq, Eq, Hash)] @@ -131,6 +131,15 @@ impl Index for IndexSlice { } } +impl Index> for IndexSlice { + type Output = [T]; + + #[inline] + fn index(&self, range: Range) -> &[T] { + &self.raw[range.start.index()..range.end.index()] + } +} + impl IndexMut for IndexSlice { #[inline] fn index_mut(&mut self, index: I) -> &mut T { @@ -138,6 +147,13 @@ impl IndexMut for IndexSlice { } } +impl IndexMut> for IndexSlice { + #[inline] + fn index_mut(&mut self, range: Range) -> &mut [T] { + &mut self.raw[range.start.index()..range.end.index()] + } +} + impl<'a, I: Idx, T> IntoIterator for &'a IndexSlice { type IntoIter = std::slice::Iter<'a, T>; type Item = &'a T; diff --git a/crates/ruff_index/src/vec.rs b/crates/ruff_index/src/vec.rs index 516e53487f..795f8315d4 100644 --- a/crates/ruff_index/src/vec.rs +++ b/crates/ruff_index/src/vec.rs @@ -69,6 +69,11 @@ impl IndexVec { pub fn next_index(&self) -> I { I::new(self.raw.len()) } + + #[inline] + pub fn shrink_to_fit(&mut self) { + self.raw.shrink_to_fit(); + } } impl Debug for IndexVec