From f666d79cd76b8d60ebba5e66f19bb08330cb2879 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Tue, 18 Jun 2024 14:10:45 +0100 Subject: [PATCH] red-knot: Symbol table (#11860) --- Cargo.lock | 2 + crates/ruff_db/src/parsed.rs | 10 +- crates/ruff_python_semantic/Cargo.toml | 6 +- crates/ruff_python_semantic/src/db.rs | 22 +- crates/ruff_python_semantic/src/lib.rs | 3 + crates/ruff_python_semantic/src/name.rs | 56 ++ .../src/red_knot/ast_node_ref.rs | 162 +++++ .../ruff_python_semantic/src/red_knot/mod.rs | 3 + .../src/red_knot/node_key.rs | 24 + .../src/red_knot/semantic_index.rs | 655 ++++++++++++++++++ .../src/red_knot/semantic_index/ast_ids.rs | 384 ++++++++++ .../src/red_knot/semantic_index/builder.rs | 398 +++++++++++ .../src/red_knot/semantic_index/definition.rs | 76 ++ .../src/red_knot/semantic_index/symbol.rs | 362 ++++++++++ 14 files changed, 2153 insertions(+), 10 deletions(-) create mode 100644 crates/ruff_python_semantic/src/name.rs create mode 100644 crates/ruff_python_semantic/src/red_knot/ast_node_ref.rs create mode 100644 crates/ruff_python_semantic/src/red_knot/mod.rs create mode 100644 crates/ruff_python_semantic/src/red_knot/node_key.rs create mode 100644 crates/ruff_python_semantic/src/red_knot/semantic_index.rs create mode 100644 crates/ruff_python_semantic/src/red_knot/semantic_index/ast_ids.rs create mode 100644 crates/ruff_python_semantic/src/red_knot/semantic_index/builder.rs create mode 100644 crates/ruff_python_semantic/src/red_knot/semantic_index/definition.rs create mode 100644 crates/ruff_python_semantic/src/red_knot/semantic_index/symbol.rs diff --git a/Cargo.lock b/Cargo.lock index 0f9ecd2db3..e95c63f0fc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2470,6 +2470,7 @@ version = "0.0.0" dependencies = [ "anyhow", "bitflags 2.5.0", + "hashbrown 0.14.5", "is-macro", "ruff_db", "ruff_index", @@ -2480,6 +2481,7 @@ dependencies = [ "ruff_text_size", "rustc-hash", "salsa-2022", + "smallvec", "smol_str", "tempfile", "tracing", diff --git a/crates/ruff_db/src/parsed.rs b/crates/ruff_db/src/parsed.rs index 72e3b2c9f4..1241f0628c 100644 --- a/crates/ruff_db/src/parsed.rs +++ b/crates/ruff_db/src/parsed.rs @@ -32,9 +32,7 @@ pub fn parsed_module(db: &dyn Db, file: VfsFile) -> ParsedModule { VfsPath::Vendored(_) => PySourceType::Stub, }; - ParsedModule { - inner: Arc::new(parse_unchecked_source(&source, ty)), - } + ParsedModule::new(parse_unchecked_source(&source, ty)) } /// Cheap cloneable wrapper around the parsed module. @@ -44,6 +42,12 @@ pub struct ParsedModule { } impl ParsedModule { + pub fn new(parsed: Parsed) -> Self { + Self { + inner: Arc::new(parsed), + } + } + /// Consumes `self` and returns the Arc storing the parsed module. 
pub fn into_arc(self) -> Arc> { self.inner diff --git a/crates/ruff_python_semantic/Cargo.toml b/crates/ruff_python_semantic/Cargo.toml index b0650195e0..1b767e8e95 100644 --- a/crates/ruff_python_semantic/Cargo.toml +++ b/crates/ruff_python_semantic/Cargo.toml @@ -21,9 +21,11 @@ ruff_text_size = { workspace = true } bitflags = { workspace = true } is-macro = { workspace = true } salsa = { workspace = true, optional = true } -smol_str = { workspace = true, optional = true } +smallvec = { workspace = true, optional = true } +smol_str = { workspace = true } tracing = { workspace = true, optional = true } rustc-hash = { workspace = true } +hashbrown = { workspace = true, optional = true } [dev-dependencies] anyhow = { workspace = true } @@ -34,4 +36,4 @@ tempfile = { workspace = true } workspace = true [features] -red_knot = ["dep:salsa", "dep:smol_str", "dep:tracing"] +red_knot = ["dep:salsa", "dep:tracing", "dep:hashbrown", "dep:smallvec"] diff --git a/crates/ruff_python_semantic/src/db.rs b/crates/ruff_python_semantic/src/db.rs index 4765eb5c48..ae3d71ea01 100644 --- a/crates/ruff_python_semantic/src/db.rs +++ b/crates/ruff_python_semantic/src/db.rs @@ -1,16 +1,25 @@ +use salsa::DbWithJar; + +use ruff_db::{Db as SourceDb, Upcast}; + use crate::module::resolver::{ file_to_module, internal::ModuleNameIngredient, internal::ModuleResolverSearchPaths, resolve_module_query, }; -use ruff_db::{Db as SourceDb, Upcast}; -use salsa::DbWithJar; + +use crate::red_knot::semantic_index::symbol::ScopeId; +use crate::red_knot::semantic_index::{scopes_map, semantic_index, symbol_table}; #[salsa::jar(db=Db)] pub struct Jar( ModuleNameIngredient, ModuleResolverSearchPaths, + ScopeId, + symbol_table, resolve_module_query, file_to_module, + scopes_map, + semantic_index, ); /// Database giving access to semantic information about a Python program. 
@@ -18,12 +27,15 @@ pub trait Db: SourceDb + DbWithJar + Upcast {} #[cfg(test)] pub(crate) mod tests { - use super::{Db, Jar}; + use std::sync::Arc; + + use salsa::DebugWithDb; + use ruff_db::file_system::{FileSystem, MemoryFileSystem, OsFileSystem}; use ruff_db::vfs::Vfs; use ruff_db::{Db as SourceDb, Jar as SourceJar, Upcast}; - use salsa::DebugWithDb; - use std::sync::Arc; + + use super::{Db, Jar}; #[salsa::db(Jar, SourceJar)] pub(crate) struct TestDb { diff --git a/crates/ruff_python_semantic/src/lib.rs b/crates/ruff_python_semantic/src/lib.rs index 65f5ae1b3c..4f30103e39 100644 --- a/crates/ruff_python_semantic/src/lib.rs +++ b/crates/ruff_python_semantic/src/lib.rs @@ -9,7 +9,10 @@ mod globals; mod model; #[cfg(feature = "red_knot")] pub mod module; +pub mod name; mod nodes; +#[cfg(feature = "red_knot")] +pub mod red_knot; mod reference; mod scope; mod star_import; diff --git a/crates/ruff_python_semantic/src/name.rs b/crates/ruff_python_semantic/src/name.rs new file mode 100644 index 0000000000..78a9e4cfc2 --- /dev/null +++ b/crates/ruff_python_semantic/src/name.rs @@ -0,0 +1,56 @@ +use std::ops::Deref; + +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +pub struct Name(smol_str::SmolStr); + +impl Name { + #[inline] + pub fn new(name: &str) -> Self { + Self(smol_str::SmolStr::new(name)) + } + + #[inline] + pub fn new_static(name: &'static str) -> Self { + Self(smol_str::SmolStr::new_static(name)) + } + + pub fn as_str(&self) -> &str { + self.0.as_str() + } +} + +impl Deref for Name { + type Target = str; + + #[inline] + fn deref(&self) -> &Self::Target { + self.as_str() + } +} + +impl From for Name +where + T: Into, +{ + fn from(value: T) -> Self { + Self(value.into()) + } +} + +impl std::fmt::Display for Name { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.as_str()) + } +} + +impl PartialEq for Name { + fn eq(&self, other: &str) -> bool { + self.as_str() == other + } +} + +impl PartialEq for str { + fn eq(&self, other: &Name) -> bool { + other == self + } +} diff --git a/crates/ruff_python_semantic/src/red_knot/ast_node_ref.rs b/crates/ruff_python_semantic/src/red_knot/ast_node_ref.rs new file mode 100644 index 0000000000..b3e58e2237 --- /dev/null +++ b/crates/ruff_python_semantic/src/red_knot/ast_node_ref.rs @@ -0,0 +1,162 @@ +use std::hash::Hash; +use std::ops::Deref; + +use ruff_db::parsed::ParsedModule; + +/// Ref-counted owned reference to an AST node. +/// +/// The type holds an owned reference to the node's ref-counted [`ParsedModule`]. +/// Holding on to the node's [`ParsedModule`] guarantees that the reference to the +/// node must still be valid. +/// +/// Holding on to any [`AstNodeRef`] prevents the [`ParsedModule`] from being released. +/// +/// ## Equality +/// Two `AstNodeRef` are considered equal if their wrapped nodes are equal. +#[derive(Clone)] +pub struct AstNodeRef { + /// Owned reference to the node's [`ParsedModule`]. + /// + /// The node's reference is guaranteed to remain valid as long as it's enclosing + /// [`ParsedModule`] is alive. + _parsed: ParsedModule, + + /// Pointer to the referenced node. + node: std::ptr::NonNull, +} + +#[allow(unsafe_code)] +impl AstNodeRef { + /// Creates a new `AstNodeRef` that reference `node`. The `parsed` is the [`ParsedModule`] to which + /// the `AstNodeRef` belongs. + /// + /// ## Safety + /// Dereferencing the `node` can result in undefined behavior if `parsed` isn't the [`ParsedModule`] to + /// which `node` belongs. 
It's the caller's responsibility to ensure that the invariant `node belongs to parsed` is upheld. + + pub(super) unsafe fn new(parsed: ParsedModule, node: &T) -> Self { + Self { + _parsed: parsed, + node: std::ptr::NonNull::from(node), + } + } + + /// Returns a reference to the wrapped node. + pub fn node(&self) -> &T { + // SAFETY: Holding on to `parsed` ensures that the AST to which `node` belongs is still alive + // and not moved. + unsafe { self.node.as_ref() } + } +} + +impl Deref for AstNodeRef { + type Target = T; + + fn deref(&self) -> &Self::Target { + self.node() + } +} + +impl std::fmt::Debug for AstNodeRef +where + T: std::fmt::Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("AstNodeRef").field(&self.node()).finish() + } +} + +impl PartialEq for AstNodeRef +where + T: PartialEq, +{ + fn eq(&self, other: &Self) -> bool { + self.node().eq(other.node()) + } +} + +impl Eq for AstNodeRef where T: Eq {} + +impl Hash for AstNodeRef +where + T: Hash, +{ + fn hash(&self, state: &mut H) { + self.node().hash(state); + } +} + +#[allow(unsafe_code)] +unsafe impl Send for AstNodeRef where T: Send {} +#[allow(unsafe_code)] +unsafe impl Sync for AstNodeRef where T: Sync {} + +#[cfg(test)] +mod tests { + use crate::red_knot::ast_node_ref::AstNodeRef; + use ruff_db::parsed::ParsedModule; + use ruff_python_ast::PySourceType; + use ruff_python_parser::parse_unchecked_source; + + #[test] + #[allow(unsafe_code)] + fn equality() { + let parsed_raw = parse_unchecked_source("1 + 2", PySourceType::Python); + let parsed = ParsedModule::new(parsed_raw.clone()); + + let stmt = &parsed.syntax().body[0]; + + let node1 = unsafe { AstNodeRef::new(parsed.clone(), stmt) }; + let node2 = unsafe { AstNodeRef::new(parsed.clone(), stmt) }; + + assert_eq!(node1, node2); + + // Compare from different trees + let cloned = ParsedModule::new(parsed_raw); + let stmt_cloned = &cloned.syntax().body[0]; + let cloned_node = unsafe { AstNodeRef::new(cloned.clone(), stmt_cloned) }; + + assert_eq!(node1, cloned_node); + + let other_raw = parse_unchecked_source("2 + 2", PySourceType::Python); + let other = ParsedModule::new(other_raw); + + let other_stmt = &other.syntax().body[0]; + let other_node = unsafe { AstNodeRef::new(other.clone(), other_stmt) }; + + assert_ne!(node1, other_node); + } + + #[allow(unsafe_code)] + #[test] + fn inequality() { + let parsed_raw = parse_unchecked_source("1 + 2", PySourceType::Python); + let parsed = ParsedModule::new(parsed_raw.clone()); + + let stmt = &parsed.syntax().body[0]; + let node = unsafe { AstNodeRef::new(parsed.clone(), stmt) }; + + let other_raw = parse_unchecked_source("2 + 2", PySourceType::Python); + let other = ParsedModule::new(other_raw); + + let other_stmt = &other.syntax().body[0]; + let other_node = unsafe { AstNodeRef::new(other.clone(), other_stmt) }; + + assert_ne!(node, other_node); + } + + #[test] + #[allow(unsafe_code)] + fn debug() { + let parsed_raw = parse_unchecked_source("1 + 2", PySourceType::Python); + let parsed = ParsedModule::new(parsed_raw.clone()); + + let stmt = &parsed.syntax().body[0]; + + let stmt_node = unsafe { AstNodeRef::new(parsed.clone(), stmt) }; + + let debug = format!("{stmt_node:?}"); + + assert_eq!(debug, format!("AstNodeRef({stmt:?})")); + } +} diff --git a/crates/ruff_python_semantic/src/red_knot/mod.rs b/crates/ruff_python_semantic/src/red_knot/mod.rs new file mode 100644 index 0000000000..9a21b4c4cf --- /dev/null +++ b/crates/ruff_python_semantic/src/red_knot/mod.rs @@ -0,0 +1,3 @@ +pub mod 
ast_node_ref; +mod node_key; +pub mod semantic_index; diff --git a/crates/ruff_python_semantic/src/red_knot/node_key.rs b/crates/ruff_python_semantic/src/red_knot/node_key.rs new file mode 100644 index 0000000000..0935a1f839 --- /dev/null +++ b/crates/ruff_python_semantic/src/red_knot/node_key.rs @@ -0,0 +1,24 @@ +use ruff_python_ast::{AnyNodeRef, NodeKind}; +use ruff_text_size::{Ranged, TextRange}; + +/// Compact key for a node for use in a hash map. +/// +/// Compares two nodes by their kind and text range. +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] +pub(super) struct NodeKey { + kind: NodeKind, + range: TextRange, +} + +impl NodeKey { + pub(super) fn from_node<'a, N>(node: N) -> Self + where + N: Into>, + { + let node = node.into(); + NodeKey { + kind: node.kind(), + range: node.range(), + } + } +} diff --git a/crates/ruff_python_semantic/src/red_knot/semantic_index.rs b/crates/ruff_python_semantic/src/red_knot/semantic_index.rs new file mode 100644 index 0000000000..8764a00c7b --- /dev/null +++ b/crates/ruff_python_semantic/src/red_knot/semantic_index.rs @@ -0,0 +1,655 @@ +use std::iter::FusedIterator; +use std::sync::Arc; + +use rustc_hash::FxHashMap; + +use ruff_db::parsed::parsed_module; +use ruff_db::vfs::VfsFile; +use ruff_index::{IndexSlice, IndexVec}; +use ruff_python_ast as ast; + +use crate::red_knot::node_key::NodeKey; +use crate::red_knot::semantic_index::ast_ids::AstIds; +use crate::red_knot::semantic_index::builder::SemanticIndexBuilder; +use crate::red_knot::semantic_index::symbol::{ + FileScopeId, PublicSymbolId, Scope, ScopeId, ScopeSymbolId, ScopesMap, SymbolTable, +}; +use crate::Db; + +pub mod ast_ids; +mod builder; +pub mod definition; +pub mod symbol; + +type SymbolMap = hashbrown::HashMap; + +/// Returns the semantic index for `file`. +/// +/// Prefer using [`symbol_table`] when working with symbols from a single scope. +#[salsa::tracked(return_ref, no_eq)] +pub(crate) fn semantic_index(db: &dyn Db, file: VfsFile) -> SemanticIndex { + let parsed = parsed_module(db.upcast(), file); + + SemanticIndexBuilder::new(parsed).build() +} + +/// Returns the symbol table for a specific `scope`. +/// +/// Using [`symbol_table`] over [`semantic_index`] has the advantage that +/// Salsa can avoid invalidating dependent queries if this scope's symbol table +/// is unchanged. +#[salsa::tracked] +pub(crate) fn symbol_table(db: &dyn Db, scope: ScopeId) -> Arc { + let index = semantic_index(db, scope.file(db)); + + index.symbol_table(scope.scope_id(db)) +} + +/// Returns a mapping from file specific [`FileScopeId`] to a program-wide unique [`ScopeId`]. +#[salsa::tracked(return_ref)] +pub(crate) fn scopes_map(db: &dyn Db, file: VfsFile) -> ScopesMap { + let index = semantic_index(db, file); + + let scopes: IndexVec<_, _> = index + .scopes + .indices() + .map(|id| ScopeId::new(db, file, id)) + .collect(); + + ScopesMap::new(scopes) +} + +/// Returns the root scope of `file`. +pub fn root_scope(db: &dyn Db, file: VfsFile) -> ScopeId { + FileScopeId::root().to_scope_id(db, file) +} + +/// Returns the symbol with the given name in `file`'s public scope or `None` if +/// no symbol with the given name exists. +pub fn global_symbol(db: &dyn Db, file: VfsFile, name: &str) -> Option { + let root_scope = root_scope(db, file); + root_scope.symbol(db, name) +} + +/// The symbol tables for an entire file. +#[derive(Debug)] +pub struct SemanticIndex { + /// List of all symbol tables in this file, indexed by scope. + symbol_tables: IndexVec>, + + /// List of all scopes in this file. 
+ scopes: IndexVec, + + /// Maps expressions to their corresponding scope. + /// We can't use [`ExpressionId`] here, because the challenge is how to get from + /// an [`ast::Expr`] to an [`ExpressionId`] (which requires knowing the scope). + expression_scopes: FxHashMap, + + /// Lookup table to map between node ids and AST nodes. + /// + /// Note: We should not depend on this map when analyzing other files, because + /// changing a file invalidates all of its dependents. + ast_ids: IndexVec, +} + +impl SemanticIndex { + /// Returns the symbol table for a specific scope. + /// + /// Use the Salsa cached [`symbol_table`] query if you only need the + /// symbol table for a single scope. + fn symbol_table(&self, scope_id: FileScopeId) -> Arc { + self.symbol_tables[scope_id].clone() + } + + pub(crate) fn ast_ids(&self, scope_id: FileScopeId) -> &AstIds { + &self.ast_ids[scope_id] + } + + /// Returns the ID of the `expression`'s enclosing scope. + #[allow(unused)] + pub(crate) fn expression_scope_id(&self, expression: &ast::Expr) -> FileScopeId { + self.expression_scopes[&NodeKey::from_node(expression)] + } + + /// Returns the [`Scope`] of the `expression`'s enclosing scope. + #[allow(unused)] + pub(crate) fn expression_scope(&self, expression: &ast::Expr) -> &Scope { + &self.scopes[self.expression_scope_id(expression)] + } + + /// Returns the [`Scope`] with the given id. + #[allow(unused)] + pub(crate) fn scope(&self, id: FileScopeId) -> &Scope { + &self.scopes[id] + } + + /// Returns the id of the parent scope. + pub(crate) fn parent_scope_id(&self, scope_id: FileScopeId) -> Option { + let scope = self.scope(scope_id); + scope.parent + } + + /// Returns the parent scope of `scope_id`. + #[allow(unused)] + pub(crate) fn parent_scope(&self, scope_id: FileScopeId) -> Option<&Scope> { + Some(&self.scopes[self.parent_scope_id(scope_id)?]) + } + + /// Returns an iterator over the descendent scopes of `scope`. + #[allow(unused)] + pub(crate) fn descendent_scopes(&self, scope: FileScopeId) -> DescendentsIter { + DescendentsIter::new(self, scope) + } + + /// Returns an iterator over the direct child scopes of `scope`. + #[allow(unused)] + pub(crate) fn child_scopes(&self, scope: FileScopeId) -> ChildrenIter { + ChildrenIter::new(self, scope) + } + + /// Returns an iterator over all ancestors of `scope`, starting with `scope` itself. + #[allow(unused)] + pub(crate) fn ancestor_scopes(&self, scope: FileScopeId) -> AncestorsIter { + AncestorsIter::new(self, scope) + } +} + +/// ID that uniquely identifies an expression inside a [`Scope`].
+ +pub struct AncestorsIter<'a> { + scopes: &'a IndexSlice, + next_id: Option, +} + +impl<'a> AncestorsIter<'a> { + fn new(module_symbol_table: &'a SemanticIndex, start: FileScopeId) -> Self { + Self { + scopes: &module_symbol_table.scopes, + next_id: Some(start), + } + } +} + +impl<'a> Iterator for AncestorsIter<'a> { + type Item = (FileScopeId, &'a Scope); + + fn next(&mut self) -> Option { + let current_id = self.next_id?; + let current = &self.scopes[current_id]; + self.next_id = current.parent; + + Some((current_id, current)) + } +} + +impl FusedIterator for AncestorsIter<'_> {} + +pub struct DescendentsIter<'a> { + next_id: FileScopeId, + descendents: std::slice::Iter<'a, Scope>, +} + +impl<'a> DescendentsIter<'a> { + fn new(symbol_table: &'a SemanticIndex, scope_id: FileScopeId) -> Self { + let scope = &symbol_table.scopes[scope_id]; + let scopes = &symbol_table.scopes[scope.descendents.clone()]; + + Self { + next_id: scope_id + 1, + descendents: scopes.iter(), + } + } +} + +impl<'a> Iterator for DescendentsIter<'a> { + type Item = (FileScopeId, &'a Scope); + + fn next(&mut self) -> Option { + let descendent = self.descendents.next()?; + let id = self.next_id; + self.next_id = self.next_id + 1; + + Some((id, descendent)) + } + + fn size_hint(&self) -> (usize, Option) { + self.descendents.size_hint() + } +} + +impl FusedIterator for DescendentsIter<'_> {} + +impl ExactSizeIterator for DescendentsIter<'_> {} + +pub struct ChildrenIter<'a> { + parent: FileScopeId, + descendents: DescendentsIter<'a>, +} + +impl<'a> ChildrenIter<'a> { + fn new(module_symbol_table: &'a SemanticIndex, parent: FileScopeId) -> Self { + let descendents = DescendentsIter::new(module_symbol_table, parent); + + Self { + parent, + descendents, + } + } +} + +impl<'a> Iterator for ChildrenIter<'a> { + type Item = (FileScopeId, &'a Scope); + + fn next(&mut self) -> Option { + self.descendents + .find(|(_, scope)| scope.parent == Some(self.parent)) + } +} + +impl FusedIterator for ChildrenIter<'_> {} + +#[cfg(test)] +mod tests { + use ruff_db::parsed::parsed_module; + use ruff_db::vfs::{system_path_to_file, VfsFile}; + + use crate::db::tests::TestDb; + use crate::red_knot::semantic_index::symbol::{FileScopeId, ScopeKind, SymbolTable}; + use crate::red_knot::semantic_index::{root_scope, semantic_index, symbol_table}; + + struct TestCase { + db: TestDb, + file: VfsFile, + } + + fn test_case(content: impl ToString) -> TestCase { + let db = TestDb::new(); + db.memory_file_system() + .write_file("test.py", content) + .unwrap(); + + let file = system_path_to_file(&db, "test.py").unwrap(); + + TestCase { db, file } + } + + fn names(table: &SymbolTable) -> Vec<&str> { + table + .symbols() + .map(|symbol| symbol.name().as_str()) + .collect() + } + + #[test] + fn empty() { + let TestCase { db, file } = test_case(""); + let root_table = symbol_table(&db, root_scope(&db, file)); + + assert_eq!(names(&root_table), Vec::<&str>::new()); + } + + #[test] + fn simple() { + let TestCase { db, file } = test_case("x"); + let root_table = symbol_table(&db, root_scope(&db, file)); + + assert_eq!(names(&root_table), vec!["x"]); + } + + #[test] + fn annotation_only() { + let TestCase { db, file } = test_case("x: int"); + let root_table = symbol_table(&db, root_scope(&db, file)); + + assert_eq!(names(&root_table), vec!["int", "x"]); + // TODO record definition + } + + #[test] + fn import() { + let TestCase { db, file } = test_case("import foo"); + let root_table = symbol_table(&db, root_scope(&db, file)); + + assert_eq!(names(&root_table), 
vec!["foo"]); + let foo = root_table.symbol_by_name("foo").unwrap(); + + assert_eq!(foo.definitions().len(), 1); + } + + #[test] + fn import_sub() { + let TestCase { db, file } = test_case("import foo.bar"); + let root_table = symbol_table(&db, root_scope(&db, file)); + + assert_eq!(names(&root_table), vec!["foo"]); + } + + #[test] + fn import_as() { + let TestCase { db, file } = test_case("import foo.bar as baz"); + let root_table = symbol_table(&db, root_scope(&db, file)); + + assert_eq!(names(&root_table), vec!["baz"]); + } + + #[test] + fn import_from() { + let TestCase { db, file } = test_case("from bar import foo"); + let root_table = symbol_table(&db, root_scope(&db, file)); + + assert_eq!(names(&root_table), vec!["foo"]); + assert_eq!( + root_table + .symbol_by_name("foo") + .unwrap() + .definitions() + .len(), + 1 + ); + assert!( + root_table + .symbol_by_name("foo") + .is_some_and(|symbol| { symbol.is_defined() || !symbol.is_used() }), + "symbols that are defined get the defined flag" + ); + } + + #[test] + fn assign() { + let TestCase { db, file } = test_case("x = foo"); + let root_table = symbol_table(&db, root_scope(&db, file)); + + assert_eq!(names(&root_table), vec!["foo", "x"]); + assert_eq!( + root_table.symbol_by_name("x").unwrap().definitions().len(), + 1 + ); + assert!( + root_table + .symbol_by_name("foo") + .is_some_and(|symbol| { !symbol.is_defined() && symbol.is_used() }), + "a symbol used but not defined in a scope should have only the used flag" + ); + } + + #[test] + fn class_scope() { + let TestCase { db, file } = test_case( + " +class C: + x = 1 +y = 2 +", + ); + let root_table = symbol_table(&db, root_scope(&db, file)); + + assert_eq!(names(&root_table), vec!["C", "y"]); + + let index = semantic_index(&db, file); + + let scopes: Vec<_> = index.child_scopes(FileScopeId::root()).collect(); + assert_eq!(scopes.len(), 1); + + let (class_scope_id, class_scope) = scopes[0]; + assert_eq!(class_scope.kind(), ScopeKind::Class); + assert_eq!(class_scope.name(), "C"); + + let class_table = index.symbol_table(class_scope_id); + assert_eq!(names(&class_table), vec!["x"]); + assert_eq!( + class_table.symbol_by_name("x").unwrap().definitions().len(), + 1 + ); + } + + #[test] + fn function_scope() { + let TestCase { db, file } = test_case( + " +def func(): + x = 1 +y = 2 +", + ); + let index = semantic_index(&db, file); + let root_table = index.symbol_table(FileScopeId::root()); + + assert_eq!(names(&root_table), vec!["func", "y"]); + + let scopes = index.child_scopes(FileScopeId::root()).collect::>(); + assert_eq!(scopes.len(), 1); + + let (function_scope_id, function_scope) = scopes[0]; + assert_eq!(function_scope.kind(), ScopeKind::Function); + assert_eq!(function_scope.name(), "func"); + + let function_table = index.symbol_table(function_scope_id); + assert_eq!(names(&function_table), vec!["x"]); + assert_eq!( + function_table + .symbol_by_name("x") + .unwrap() + .definitions() + .len(), + 1 + ); + } + + #[test] + fn dupes() { + let TestCase { db, file } = test_case( + " +def func(): + x = 1 +def func(): + y = 2 +", + ); + let index = semantic_index(&db, file); + let root_table = index.symbol_table(FileScopeId::root()); + + assert_eq!(names(&root_table), vec!["func"]); + let scopes: Vec<_> = index.child_scopes(FileScopeId::root()).collect(); + assert_eq!(scopes.len(), 2); + + let (func_scope1_id, func_scope_1) = scopes[0]; + let (func_scope2_id, func_scope_2) = scopes[1]; + + assert_eq!(func_scope_1.kind(), ScopeKind::Function); + assert_eq!(func_scope_1.name(), "func"); 
+ assert_eq!(func_scope_2.kind(), ScopeKind::Function); + assert_eq!(func_scope_2.name(), "func"); + + let func1_table = index.symbol_table(func_scope1_id); + let func2_table = index.symbol_table(func_scope2_id); + assert_eq!(names(&func1_table), vec!["x"]); + assert_eq!(names(&func2_table), vec!["y"]); + assert_eq!( + root_table + .symbol_by_name("func") + .unwrap() + .definitions() + .len(), + 2 + ); + } + + #[test] + fn generic_function() { + let TestCase { db, file } = test_case( + " +def func[T](): + x = 1 +", + ); + + let index = semantic_index(&db, file); + let root_table = index.symbol_table(FileScopeId::root()); + + assert_eq!(names(&root_table), vec!["func"]); + + let scopes: Vec<_> = index.child_scopes(FileScopeId::root()).collect(); + assert_eq!(scopes.len(), 1); + let (ann_scope_id, ann_scope) = scopes[0]; + + assert_eq!(ann_scope.kind(), ScopeKind::Annotation); + assert_eq!(ann_scope.name(), "func"); + let ann_table = index.symbol_table(ann_scope_id); + assert_eq!(names(&ann_table), vec!["T"]); + + let scopes: Vec<_> = index.child_scopes(ann_scope_id).collect(); + assert_eq!(scopes.len(), 1); + let (func_scope_id, func_scope) = scopes[0]; + assert_eq!(func_scope.kind(), ScopeKind::Function); + assert_eq!(func_scope.name(), "func"); + let func_table = index.symbol_table(func_scope_id); + assert_eq!(names(&func_table), vec!["x"]); + } + + #[test] + fn generic_class() { + let TestCase { db, file } = test_case( + " +class C[T]: + x = 1 +", + ); + + let index = semantic_index(&db, file); + let root_table = index.symbol_table(FileScopeId::root()); + + assert_eq!(names(&root_table), vec!["C"]); + + let scopes: Vec<_> = index.child_scopes(FileScopeId::root()).collect(); + + assert_eq!(scopes.len(), 1); + let (ann_scope_id, ann_scope) = scopes[0]; + assert_eq!(ann_scope.kind(), ScopeKind::Annotation); + assert_eq!(ann_scope.name(), "C"); + let ann_table = index.symbol_table(ann_scope_id); + assert_eq!(names(&ann_table), vec!["T"]); + assert!( + ann_table + .symbol_by_name("T") + .is_some_and(|s| s.is_defined() && !s.is_used()), + "type parameters are defined by the scope that introduces them" + ); + + let scopes: Vec<_> = index.child_scopes(ann_scope_id).collect(); + assert_eq!(scopes.len(), 1); + let (func_scope_id, func_scope) = scopes[0]; + + assert_eq!(func_scope.kind(), ScopeKind::Class); + assert_eq!(func_scope.name(), "C"); + assert_eq!(names(&index.symbol_table(func_scope_id)), vec!["x"]); + } + + // TODO: After porting the control flow graph. + // #[test] + // fn reachability_trivial() { + // let parsed = parse("x = 1; x"); + // let ast = parsed.syntax(); + // let index = SemanticIndex::from_ast(ast); + // let table = &index.symbol_table; + // let x_sym = table + // .root_symbol_id_by_name("x") + // .expect("x symbol should exist"); + // let ast::Stmt::Expr(ast::StmtExpr { value: x_use, .. }) = &ast.body[1] else { + // panic!("should be an expr") + // }; + // let x_defs: Vec<_> = index + // .reachable_definitions(x_sym, x_use) + // .map(|constrained_definition| constrained_definition.definition) + // .collect(); + // assert_eq!(x_defs.len(), 1); + // let Definition::Assignment(node_key) = &x_defs[0] else { + // panic!("def should be an assignment") + // }; + // let Some(def_node) = node_key.resolve(ast.into()) else { + // panic!("node key should resolve") + // }; + // let ast::Expr::NumberLiteral(ast::ExprNumberLiteral { + // value: ast::Number::Int(num), + // .. 
+ // }) = &*def_node.value + // else { + // panic!("should be a number literal") + // }; + // assert_eq!(*num, 1); + // } + + #[test] + fn expression_scope() { + let TestCase { db, file } = test_case("x = 1;\ndef test():\n y = 4"); + + let index = semantic_index(&db, file); + let root_table = index.symbol_table(FileScopeId::root()); + let parsed = parsed_module(&db, file); + let ast = parsed.syntax(); + + let x_sym = root_table + .symbol_by_name("x") + .expect("x symbol should exist"); + + let x_stmt = ast.body[0].as_assign_stmt().unwrap(); + let x = &x_stmt.targets[0]; + + assert_eq!(index.expression_scope(x).kind(), ScopeKind::Module); + assert_eq!(index.expression_scope_id(x), x_sym.scope()); + + let def = ast.body[1].as_function_def_stmt().unwrap(); + let y_stmt = def.body[0].as_assign_stmt().unwrap(); + let y = &y_stmt.targets[0]; + + assert_eq!(index.expression_scope(y).kind(), ScopeKind::Function); + } + + #[test] + fn scope_iterators() { + let TestCase { db, file } = test_case( + r#" +class Test: + def foo(): + def bar(): + ... + def baz(): + pass + +def x(): + pass"#, + ); + + let index = semantic_index(&db, file); + + let descendents: Vec<_> = index + .descendent_scopes(FileScopeId::root()) + .map(|(_, scope)| scope.name().as_str()) + .collect(); + assert_eq!(descendents, vec!["Test", "foo", "bar", "baz", "x"]); + + let children: Vec<_> = index + .child_scopes(FileScopeId::root()) + .map(|(_, scope)| scope.name.as_str()) + .collect(); + assert_eq!(children, vec!["Test", "x"]); + + let test_class = index.child_scopes(FileScopeId::root()).next().unwrap().0; + let test_child_scopes: Vec<_> = index + .child_scopes(test_class) + .map(|(_, scope)| scope.name.as_str()) + .collect(); + assert_eq!(test_child_scopes, vec!["foo", "baz"]); + + let bar_scope = index + .descendent_scopes(FileScopeId::root()) + .nth(2) + .unwrap() + .0; + let ancestors: Vec<_> = index + .ancestor_scopes(bar_scope) + .map(|(_, scope)| scope.name()) + .collect(); + + assert_eq!(ancestors, vec!["bar", "foo", "Test", ""]); + } +} diff --git a/crates/ruff_python_semantic/src/red_knot/semantic_index/ast_ids.rs b/crates/ruff_python_semantic/src/red_knot/semantic_index/ast_ids.rs new file mode 100644 index 0000000000..9d1fd1a989 --- /dev/null +++ b/crates/ruff_python_semantic/src/red_knot/semantic_index/ast_ids.rs @@ -0,0 +1,384 @@ +use rustc_hash::FxHashMap; + +use ruff_db::parsed::ParsedModule; +use ruff_db::vfs::VfsFile; +use ruff_index::{newtype_index, IndexVec}; +use ruff_python_ast as ast; +use ruff_python_ast::AnyNodeRef; + +use crate::red_knot::ast_node_ref::AstNodeRef; +use crate::red_knot::node_key::NodeKey; +use crate::red_knot::semantic_index::semantic_index; +use crate::red_knot::semantic_index::symbol::{FileScopeId, ScopeId}; +use crate::Db; + +/// AST ids for a single scope. +/// +/// The motivation for building the AST ids per scope isn't about reducing invalidation because +/// the struct changes whenever the parsed AST changes. Instead, it's mainly that we can +/// build the AST ids struct when building the symbol table and also keep the property that +/// IDs of outer scopes are unaffected by changes in inner scopes. +/// +/// For example, we don't want that adding new statements to `foo` changes the statement id of `x = foo()` in: +/// +/// ```python +/// def foo(): +/// return 5 +/// +/// x = foo() +/// ``` +pub(crate) struct AstIds { + /// Maps expression ids to their expressions. + expressions: IndexVec>, + + /// Maps expressions to their expression id. 
Uses `NodeKey` because it avoids cloning [`Parsed`]. + expressions_map: FxHashMap, + + statements: IndexVec>, + + statements_map: FxHashMap, +} + +impl AstIds { + fn statement_id<'a, N>(&self, node: N) -> ScopeStatementId + where + N: Into>, + { + self.statements_map[&NodeKey::from_node(node.into())] + } + + fn expression_id<'a, N>(&self, node: N) -> ScopeExpressionId + where + N: Into>, + { + self.expressions_map[&NodeKey::from_node(node.into())] + } +} + +#[allow(clippy::missing_fields_in_debug)] +impl std::fmt::Debug for AstIds { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AstIds") + .field("expressions", &self.expressions) + .field("statements", &self.statements) + .finish() + } +} + +fn ast_ids(db: &dyn Db, scope: ScopeId) -> &AstIds { + semantic_index(db, scope.file(db)).ast_ids(scope.scope_id(db)) +} + +/// Node that can be uniquely identified by an id in a [`FileScopeId`]. +pub trait ScopeAstIdNode { + /// The type of the ID uniquely identifying the node. + type Id; + + /// Returns the ID that uniquely identifies the node in `scope`. + /// + /// ## Panics + /// Panics if the node doesn't belong to `file` or is outside `scope`. + fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, scope: FileScopeId) -> Self::Id; + + /// Looks up the AST node by its ID. + /// + /// ## Panics + /// May panic if the `id` does not belong to the AST of `file`, or is outside `scope`. + fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self + where + Self: Sized; +} + +/// Extension trait for AST nodes that can be resolved by an `AstId`. +pub trait AstIdNode { + type ScopeId; + + /// Resolves the AST id of the node. + /// + /// ## Panics + /// May panic if the node does not belong to `file`'s AST or is outside of `scope`. It may also + /// return an incorrect node in that case. + + fn ast_id(&self, db: &dyn Db, file: VfsFile, scope: FileScopeId) -> AstId; + + /// Resolves the AST node for `id`. + /// + /// ## Panics + /// May panic if the `id` does not belong to the AST of `file`, or it may return an incorrect node. + + fn lookup(db: &dyn Db, file: VfsFile, id: AstId) -> &Self + where + Self: Sized; +} + +impl AstIdNode for T +where + T: ScopeAstIdNode, +{ + type ScopeId = T::Id; + + fn ast_id(&self, db: &dyn Db, file: VfsFile, scope: FileScopeId) -> AstId { + let in_scope_id = self.scope_ast_id(db, file, scope); + AstId { scope, in_scope_id } + } + + fn lookup(db: &dyn Db, file: VfsFile, id: AstId) -> &Self + where + Self: Sized, + { + let scope = id.scope; + Self::lookup_in_scope(db, file, scope, id.in_scope_id) + } +} + +/// Uniquely identifies an AST node in a file. +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] +pub struct AstId { + /// The node's scope. + scope: FileScopeId, + + /// The ID of the node inside [`Self::scope`]. + in_scope_id: L, +} + +/// Uniquely identifies an [`ast::Expr`] in a [`FileScopeId`].
+#[newtype_index] +pub struct ScopeExpressionId; + +impl ScopeAstIdNode for ast::Expr { + type Id = ScopeExpressionId; + + fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id { + let scope = file_scope.to_scope_id(db, file); + let ast_ids = ast_ids(db, scope); + ast_ids.expressions_map[&NodeKey::from_node(self)] + } + + fn lookup_in_scope(db: &dyn Db, file: VfsFile, file_scope: FileScopeId, id: Self::Id) -> &Self { + let scope = file_scope.to_scope_id(db, file); + let ast_ids = ast_ids(db, scope); + ast_ids.expressions[id].node() + } +} + +/// Uniquely identifies an [`ast::Stmt`] in a [`FileScopeId`]. +#[newtype_index] +pub struct ScopeStatementId; + +impl ScopeAstIdNode for ast::Stmt { + type Id = ScopeStatementId; + + fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id { + let scope = file_scope.to_scope_id(db, file); + let ast_ids = ast_ids(db, scope); + ast_ids.statement_id(self) + } + + fn lookup_in_scope(db: &dyn Db, file: VfsFile, file_scope: FileScopeId, id: Self::Id) -> &Self { + let scope = file_scope.to_scope_id(db, file); + let ast_ids = ast_ids(db, scope); + + ast_ids.statements[id].node() + } +} + +#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)] +pub struct ScopeFunctionId(pub(super) ScopeStatementId); + +impl ScopeAstIdNode for ast::StmtFunctionDef { + type Id = ScopeFunctionId; + + fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id { + let scope = file_scope.to_scope_id(db, file); + let ast_ids = ast_ids(db, scope); + ScopeFunctionId(ast_ids.statement_id(self)) + } + + fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self { + ast::Stmt::lookup_in_scope(db, file, scope, id.0) + .as_function_def_stmt() + .unwrap() + } +} + +#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)] +pub struct ScopeClassId(pub(super) ScopeStatementId); + +impl ScopeAstIdNode for ast::StmtClassDef { + type Id = ScopeClassId; + + fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id { + let scope = file_scope.to_scope_id(db, file); + let ast_ids = ast_ids(db, scope); + ScopeClassId(ast_ids.statement_id(self)) + } + + fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self { + let statement = ast::Stmt::lookup_in_scope(db, file, scope, id.0); + statement.as_class_def_stmt().unwrap() + } +} + +#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)] +pub struct ScopeAssignmentId(pub(super) ScopeStatementId); + +impl ScopeAstIdNode for ast::StmtAssign { + type Id = ScopeAssignmentId; + + fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id { + let scope = file_scope.to_scope_id(db, file); + let ast_ids = ast_ids(db, scope); + ScopeAssignmentId(ast_ids.statement_id(self)) + } + + fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self { + let statement = ast::Stmt::lookup_in_scope(db, file, scope, id.0); + statement.as_assign_stmt().unwrap() + } +} + +#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)] +pub struct ScopeAnnotatedAssignmentId(ScopeStatementId); + +impl ScopeAstIdNode for ast::StmtAnnAssign { + type Id = ScopeAnnotatedAssignmentId; + + fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id { + let scope = file_scope.to_scope_id(db, file); + let ast_ids = ast_ids(db, scope); + ScopeAnnotatedAssignmentId(ast_ids.statement_id(self)) + } + + fn lookup_in_scope(db: &dyn Db, file: VfsFile, 
scope: FileScopeId, id: Self::Id) -> &Self { + let statement = ast::Stmt::lookup_in_scope(db, file, scope, id.0); + statement.as_ann_assign_stmt().unwrap() + } +} + +#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)] +pub struct ScopeImportId(pub(super) ScopeStatementId); + +impl ScopeAstIdNode for ast::StmtImport { + type Id = ScopeImportId; + + fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id { + let scope = file_scope.to_scope_id(db, file); + let ast_ids = ast_ids(db, scope); + ScopeImportId(ast_ids.statement_id(self)) + } + + fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self { + let statement = ast::Stmt::lookup_in_scope(db, file, scope, id.0); + statement.as_import_stmt().unwrap() + } +} + +#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)] +pub struct ScopeImportFromId(pub(super) ScopeStatementId); + +impl ScopeAstIdNode for ast::StmtImportFrom { + type Id = ScopeImportFromId; + + fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id { + let scope = file_scope.to_scope_id(db, file); + let ast_ids = ast_ids(db, scope); + ScopeImportFromId(ast_ids.statement_id(self)) + } + + fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self { + let statement = ast::Stmt::lookup_in_scope(db, file, scope, id.0); + statement.as_import_from_stmt().unwrap() + } +} + +#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)] +pub struct ScopeNamedExprId(pub(super) ScopeExpressionId); + +impl ScopeAstIdNode for ast::ExprNamed { + type Id = ScopeNamedExprId; + + fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id { + let scope = file_scope.to_scope_id(db, file); + let ast_ids = ast_ids(db, scope); + ScopeNamedExprId(ast_ids.expression_id(self)) + } + + fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self + where + Self: Sized, + { + let expression = ast::Expr::lookup_in_scope(db, file, scope, id.0); + expression.as_named_expr().unwrap() + } +} + +#[derive(Debug)] +pub(super) struct AstIdsBuilder { + expressions: IndexVec>, + expressions_map: FxHashMap, + statements: IndexVec>, + statements_map: FxHashMap, +} + +impl AstIdsBuilder { + pub(super) fn new() -> Self { + Self { + expressions: IndexVec::default(), + expressions_map: FxHashMap::default(), + statements: IndexVec::default(), + statements_map: FxHashMap::default(), + } + } + + /// Adds `stmt` to the AST ids map and returns its id. + /// + /// ## Safety + /// The function is marked as unsafe because it calls [`AstNodeRef::new`] which requires + /// that `stmt` is a child of `parsed`. + #[allow(unsafe_code)] + pub(super) unsafe fn record_statement( + &mut self, + stmt: &ast::Stmt, + parsed: &ParsedModule, + ) -> ScopeStatementId { + let statement_id = self.statements.push(AstNodeRef::new(parsed.clone(), stmt)); + + self.statements_map + .insert(NodeKey::from_node(stmt), statement_id); + + statement_id + } + + /// Adds `expr` to the AST ids map and returns its id. + /// + /// ## Safety + /// The function is marked as unsafe because it calls [`AstNodeRef::new`] which requires + /// that `expr` is a child of `parsed`. 
+ #[allow(unsafe_code)] + pub(super) unsafe fn record_expression( + &mut self, + expr: &ast::Expr, + parsed: &ParsedModule, + ) -> ScopeExpressionId { + let expression_id = self.expressions.push(AstNodeRef::new(parsed.clone(), expr)); + + self.expressions_map + .insert(NodeKey::from_node(expr), expression_id); + + expression_id + } + + pub(super) fn finish(mut self) -> AstIds { + self.expressions.shrink_to_fit(); + self.expressions_map.shrink_to_fit(); + self.statements.shrink_to_fit(); + self.statements_map.shrink_to_fit(); + + AstIds { + expressions: self.expressions, + expressions_map: self.expressions_map, + statements: self.statements, + statements_map: self.statements_map, + } + } +} diff --git a/crates/ruff_python_semantic/src/red_knot/semantic_index/builder.rs b/crates/ruff_python_semantic/src/red_knot/semantic_index/builder.rs new file mode 100644 index 0000000000..ef1237b70e --- /dev/null +++ b/crates/ruff_python_semantic/src/red_knot/semantic_index/builder.rs @@ -0,0 +1,398 @@ +use std::sync::Arc; + +use rustc_hash::FxHashMap; + +use ruff_db::parsed::ParsedModule; +use ruff_index::IndexVec; +use ruff_python_ast as ast; +use ruff_python_ast::visitor::{walk_expr, walk_stmt, Visitor}; + +use crate::name::Name; +use crate::red_knot::node_key::NodeKey; +use crate::red_knot::semantic_index::ast_ids::{ + AstIdsBuilder, ScopeAssignmentId, ScopeClassId, ScopeFunctionId, ScopeImportFromId, + ScopeImportId, ScopeNamedExprId, +}; +use crate::red_knot::semantic_index::definition::{ + Definition, ImportDefinition, ImportFromDefinition, +}; +use crate::red_knot::semantic_index::symbol::{ + FileScopeId, FileSymbolId, Scope, ScopeKind, ScopeSymbolId, SymbolFlags, SymbolTableBuilder, +}; +use crate::red_knot::semantic_index::SemanticIndex; + +pub(super) struct SemanticIndexBuilder<'a> { + // Builder state + module: &'a ParsedModule, + scope_stack: Vec, + /// the definition whose target(s) we are currently walking + current_definition: Option, + + // Semantic Index fields + scopes: IndexVec, + symbol_tables: IndexVec, + ast_ids: IndexVec, + expression_scopes: FxHashMap, +} + +impl<'a> SemanticIndexBuilder<'a> { + pub(super) fn new(parsed: &'a ParsedModule) -> Self { + let mut builder = Self { + module: parsed, + scope_stack: Vec::new(), + current_definition: None, + + scopes: IndexVec::new(), + symbol_tables: IndexVec::new(), + ast_ids: IndexVec::new(), + expression_scopes: FxHashMap::default(), + }; + + builder.push_scope_with_parent( + ScopeKind::Module, + &Name::new_static(""), + None, + None, + None, + ); + + builder + } + + fn current_scope(&self) -> FileScopeId { + *self + .scope_stack + .last() + .expect("Always to have a root scope") + } + + fn push_scope( + &mut self, + scope_kind: ScopeKind, + name: &Name, + defining_symbol: Option, + definition: Option, + ) { + let parent = self.current_scope(); + self.push_scope_with_parent(scope_kind, name, defining_symbol, definition, Some(parent)); + } + + fn push_scope_with_parent( + &mut self, + scope_kind: ScopeKind, + name: &Name, + defining_symbol: Option, + definition: Option, + parent: Option, + ) { + let children_start = self.scopes.next_index() + 1; + + let scope = Scope { + name: name.clone(), + parent, + defining_symbol, + definition, + kind: scope_kind, + descendents: children_start..children_start, + }; + + let scope_id = self.scopes.push(scope); + self.symbol_tables.push(SymbolTableBuilder::new()); + self.ast_ids.push(AstIdsBuilder::new()); + self.scope_stack.push(scope_id); + } + + fn pop_scope(&mut self) -> FileScopeId { + let id = 
self.scope_stack.pop().expect("Root scope to be present"); + let children_end = self.scopes.next_index(); + let scope = &mut self.scopes[id]; + scope.descendents = scope.descendents.start..children_end; + id + } + + fn current_symbol_table(&mut self) -> &mut SymbolTableBuilder { + let scope_id = self.current_scope(); + &mut self.symbol_tables[scope_id] + } + + fn current_ast_ids(&mut self) -> &mut AstIdsBuilder { + let scope_id = self.current_scope(); + &mut self.ast_ids[scope_id] + } + + fn add_or_update_symbol(&mut self, name: Name, flags: SymbolFlags) -> ScopeSymbolId { + let scope = self.current_scope(); + let symbol_table = self.current_symbol_table(); + + symbol_table.add_or_update_symbol(name, scope, flags, None) + } + + fn add_or_update_symbol_with_definition( + &mut self, + name: Name, + + definition: Definition, + ) -> ScopeSymbolId { + let scope = self.current_scope(); + let symbol_table = self.current_symbol_table(); + + symbol_table.add_or_update_symbol(name, scope, SymbolFlags::IS_DEFINED, Some(definition)) + } + + fn with_type_params( + &mut self, + name: &Name, + params: &Option>, + definition: Option, + defining_symbol: FileSymbolId, + nested: impl FnOnce(&mut Self) -> FileScopeId, + ) -> FileScopeId { + if let Some(type_params) = params { + self.push_scope( + ScopeKind::Annotation, + name, + Some(defining_symbol), + definition, + ); + for type_param in &type_params.type_params { + let name = match type_param { + ast::TypeParam::TypeVar(ast::TypeParamTypeVar { name, .. }) => name, + ast::TypeParam::ParamSpec(ast::TypeParamParamSpec { name, .. }) => name, + ast::TypeParam::TypeVarTuple(ast::TypeParamTypeVarTuple { name, .. }) => name, + }; + self.add_or_update_symbol(Name::new(name), SymbolFlags::IS_DEFINED); + } + } + let nested_scope = nested(self); + + if params.is_some() { + self.pop_scope(); + } + + nested_scope + } + + pub(super) fn build(mut self) -> SemanticIndex { + let module = self.module; + self.visit_body(module.suite()); + + // Pop the root scope + self.pop_scope(); + assert!(self.scope_stack.is_empty()); + + assert!(self.current_definition.is_none()); + + let mut symbol_tables: IndexVec<_, _> = self + .symbol_tables + .into_iter() + .map(|builder| Arc::new(builder.finish())) + .collect(); + + let mut ast_ids: IndexVec<_, _> = self + .ast_ids + .into_iter() + .map(super::ast_ids::AstIdsBuilder::finish) + .collect(); + + self.scopes.shrink_to_fit(); + ast_ids.shrink_to_fit(); + symbol_tables.shrink_to_fit(); + self.expression_scopes.shrink_to_fit(); + + SemanticIndex { + symbol_tables, + scopes: self.scopes, + ast_ids, + expression_scopes: self.expression_scopes, + } + } +} + +impl Visitor<'_> for SemanticIndexBuilder<'_> { + fn visit_stmt(&mut self, stmt: &ast::Stmt) { + let module = self.module; + #[allow(unsafe_code)] + let statement_id = unsafe { + // SAFETY: The builder only visits nodes that are part of `module`. This guarantees that + // the current statement must be a child of `module`. 
+ self.current_ast_ids().record_statement(stmt, module) + }; + match stmt { + ast::Stmt::FunctionDef(function_def) => { + for decorator in &function_def.decorator_list { + self.visit_decorator(decorator); + } + let name = Name::new(&function_def.name.id); + let definition = Definition::FunctionDef(ScopeFunctionId(statement_id)); + let scope = self.current_scope(); + let symbol = FileSymbolId::new( + scope, + self.add_or_update_symbol_with_definition(name.clone(), definition), + ); + + self.with_type_params( + &name, + &function_def.type_params, + Some(definition), + symbol, + |builder| { + builder.visit_parameters(&function_def.parameters); + for expr in &function_def.returns { + builder.visit_annotation(expr); + } + + builder.push_scope( + ScopeKind::Function, + &name, + Some(symbol), + Some(definition), + ); + builder.visit_body(&function_def.body); + builder.pop_scope() + }, + ); + } + ast::Stmt::ClassDef(class) => { + for decorator in &class.decorator_list { + self.visit_decorator(decorator); + } + + let name = Name::new(&class.name.id); + let definition = Definition::from(ScopeClassId(statement_id)); + let id = FileSymbolId::new( + self.current_scope(), + self.add_or_update_symbol_with_definition(name.clone(), definition), + ); + self.with_type_params(&name, &class.type_params, Some(definition), id, |builder| { + if let Some(arguments) = &class.arguments { + builder.visit_arguments(arguments); + } + + builder.push_scope(ScopeKind::Class, &name, Some(id), Some(definition)); + builder.visit_body(&class.body); + + builder.pop_scope() + }); + } + ast::Stmt::Import(ast::StmtImport { names, .. }) => { + for (i, alias) in names.iter().enumerate() { + let symbol_name = if let Some(asname) = &alias.asname { + asname.id.as_str() + } else { + alias.name.id.split('.').next().unwrap() + }; + + let def = Definition::Import(ImportDefinition { + import_id: ScopeImportId(statement_id), + alias: u32::try_from(i).unwrap(), + }); + self.add_or_update_symbol_with_definition(Name::new(symbol_name), def); + } + } + ast::Stmt::ImportFrom(ast::StmtImportFrom { + module: _, + names, + level: _, + .. + }) => { + for (i, alias) in names.iter().enumerate() { + let symbol_name = if let Some(asname) = &alias.asname { + asname.id.as_str() + } else { + alias.name.id.as_str() + }; + let def = Definition::ImportFrom(ImportFromDefinition { + import_id: ScopeImportFromId(statement_id), + name: u32::try_from(i).unwrap(), + }); + self.add_or_update_symbol_with_definition(Name::new(symbol_name), def); + } + } + ast::Stmt::Assign(node) => { + debug_assert!(self.current_definition.is_none()); + self.visit_expr(&node.value); + self.current_definition = + Some(Definition::Assignment(ScopeAssignmentId(statement_id))); + for target in &node.targets { + self.visit_expr(target); + } + self.current_definition = None; + } + _ => { + walk_stmt(self, stmt); + } + } + } + + fn visit_expr(&mut self, expr: &'_ ast::Expr) { + let module = self.module; + #[allow(unsafe_code)] + let expression_id = unsafe { + // SAFETY: The builder only visits nodes that are part of `module`. This guarantees that + // the current expression must be a child of `module`. + self.current_ast_ids().record_expression(expr, module) + }; + + self.expression_scopes + .insert(NodeKey::from_node(expr), self.current_scope()); + + match expr { + ast::Expr::Name(ast::ExprName { id, ctx, .. 
}) => { + let flags = match ctx { + ast::ExprContext::Load => SymbolFlags::IS_USED, + ast::ExprContext::Store => SymbolFlags::IS_DEFINED, + ast::ExprContext::Del => SymbolFlags::IS_DEFINED, + ast::ExprContext::Invalid => SymbolFlags::empty(), + }; + match self.current_definition { + Some(definition) if flags.contains(SymbolFlags::IS_DEFINED) => { + self.add_or_update_symbol_with_definition(Name::new(id), definition); + } + _ => { + self.add_or_update_symbol(Name::new(id), flags); + } + } + + walk_expr(self, expr); + } + ast::Expr::Named(node) => { + debug_assert!(self.current_definition.is_none()); + self.current_definition = + Some(Definition::NamedExpr(ScopeNamedExprId(expression_id))); + // TODO walrus in comprehensions is implicitly nonlocal + self.visit_expr(&node.target); + self.current_definition = None; + self.visit_expr(&node.value); + } + ast::Expr::If(ast::ExprIf { + body, test, orelse, .. + }) => { + // TODO detect statically known truthy or falsy test (via type inference, not naive + // AST inspection, so we can't simplify here, need to record test expression in CFG + // for later checking) + + self.visit_expr(test); + + // let if_branch = self.flow_graph_builder.add_branch(self.current_flow_node()); + + // self.set_current_flow_node(if_branch); + // self.insert_constraint(test); + self.visit_expr(body); + + // let post_body = self.current_flow_node(); + + // self.set_current_flow_node(if_branch); + self.visit_expr(orelse); + + // let post_else = self + // .flow_graph_builder + // .add_phi(self.current_flow_node(), post_body); + + // self.set_current_flow_node(post_else); + } + _ => { + walk_expr(self, expr); + } + } + } +} diff --git a/crates/ruff_python_semantic/src/red_knot/semantic_index/definition.rs b/crates/ruff_python_semantic/src/red_knot/semantic_index/definition.rs new file mode 100644 index 0000000000..97170b9e27 --- /dev/null +++ b/crates/ruff_python_semantic/src/red_knot/semantic_index/definition.rs @@ -0,0 +1,76 @@ +use crate::red_knot::semantic_index::ast_ids::{ + ScopeAnnotatedAssignmentId, ScopeAssignmentId, ScopeClassId, ScopeFunctionId, + ScopeImportFromId, ScopeImportId, ScopeNamedExprId, +}; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Definition { + Import(ImportDefinition), + ImportFrom(ImportFromDefinition), + ClassDef(ScopeClassId), + FunctionDef(ScopeFunctionId), + Assignment(ScopeAssignmentId), + AnnotatedAssignment(ScopeAnnotatedAssignmentId), + NamedExpr(ScopeNamedExprId), + /// represents the implicit initial definition of every name as "unbound" + Unbound, + // TODO with statements, except handlers, function args... 
+} + +impl From for Definition { + fn from(value: ImportDefinition) -> Self { + Self::Import(value) + } +} + +impl From for Definition { + fn from(value: ImportFromDefinition) -> Self { + Self::ImportFrom(value) + } +} + +impl From for Definition { + fn from(value: ScopeClassId) -> Self { + Self::ClassDef(value) + } +} + +impl From for Definition { + fn from(value: ScopeFunctionId) -> Self { + Self::FunctionDef(value) + } +} + +impl From for Definition { + fn from(value: ScopeAssignmentId) -> Self { + Self::Assignment(value) + } +} + +impl From for Definition { + fn from(value: ScopeAnnotatedAssignmentId) -> Self { + Self::AnnotatedAssignment(value) + } +} + +impl From for Definition { + fn from(value: ScopeNamedExprId) -> Self { + Self::NamedExpr(value) + } +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub struct ImportDefinition { + pub(super) import_id: ScopeImportId, + + /// Index into [`ruff_python_ast::StmtImport::names`]. + pub(super) alias: u32, +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub struct ImportFromDefinition { + pub(super) import_id: ScopeImportFromId, + + /// Index into [`ruff_python_ast::StmtImportFrom::names`]. + pub(super) name: u32, +} diff --git a/crates/ruff_python_semantic/src/red_knot/semantic_index/symbol.rs b/crates/ruff_python_semantic/src/red_knot/semantic_index/symbol.rs new file mode 100644 index 0000000000..b543742d3f --- /dev/null +++ b/crates/ruff_python_semantic/src/red_knot/semantic_index/symbol.rs @@ -0,0 +1,362 @@ +// Allow unused underscore violations generated by the salsa macro +// TODO(micha): Contribute fix upstream +#![allow(clippy::used_underscore_binding)] + +use std::hash::{Hash, Hasher}; +use std::ops::Range; + +use bitflags::bitflags; +use hashbrown::hash_map::RawEntryMut; +use rustc_hash::FxHasher; +use smallvec::SmallVec; + +use ruff_db::vfs::VfsFile; +use ruff_index::{newtype_index, IndexVec}; + +use crate::name::Name; +use crate::red_knot::semantic_index::definition::Definition; +use crate::red_knot::semantic_index::{scopes_map, symbol_table, SymbolMap}; +use crate::Db; + +#[derive(Eq, PartialEq, Debug)] +pub struct Symbol { + name: Name, + flags: SymbolFlags, + scope: FileScopeId, + + /// The nodes that define this symbol, in source order. + definitions: SmallVec<[Definition; 4]>, +} + +impl Symbol { + fn new(name: Name, scope: FileScopeId, definition: Option) -> Self { + Self { + name, + scope, + flags: SymbolFlags::empty(), + definitions: definition.into_iter().collect(), + } + } + + fn push_definition(&mut self, definition: Definition) { + self.definitions.push(definition); + } + + fn insert_flags(&mut self, flags: SymbolFlags) { + self.flags.insert(flags); + } + + /// The symbol's name. + pub fn name(&self) -> &Name { + &self.name + } + + /// The scope in which this symbol is defined. + pub fn scope(&self) -> FileScopeId { + self.scope + } + + /// Is the symbol used in its containing scope? + pub fn is_used(&self) -> bool { + self.flags.contains(SymbolFlags::IS_USED) + } + + /// Is the symbol defined in its containing scope? + pub fn is_defined(&self) -> bool { + self.flags.contains(SymbolFlags::IS_DEFINED) + } + + pub fn definitions(&self) -> &[Definition] { + &self.definitions + } +} + +bitflags! 
{ + #[derive(Copy, Clone, Debug, Eq, PartialEq)] + pub(super) struct SymbolFlags: u8 { + const IS_USED = 1 << 0; + const IS_DEFINED = 1 << 1; + /// TODO: This flag is not yet set by anything + const MARKED_GLOBAL = 1 << 2; + /// TODO: This flag is not yet set by anything + const MARKED_NONLOCAL = 1 << 3; + } +} + +/// ID that uniquely identifies a public symbol defined in a module's root scope. +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] +pub struct PublicSymbolId { + scope: ScopeId, + symbol: ScopeSymbolId, +} + +impl PublicSymbolId { + pub(crate) fn new(scope: ScopeId, symbol: ScopeSymbolId) -> Self { + Self { scope, symbol } + } + + pub fn scope(self) -> ScopeId { + self.scope + } + + pub(crate) fn scope_symbol(self) -> ScopeSymbolId { + self.symbol + } +} + +impl From for ScopeSymbolId { + fn from(val: PublicSymbolId) -> Self { + val.scope_symbol() + } +} + +/// ID that uniquely identifies a symbol in a file. +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] +pub struct FileSymbolId { + scope: FileScopeId, + symbol: ScopeSymbolId, +} + +impl FileSymbolId { + pub(super) fn new(scope: FileScopeId, symbol: ScopeSymbolId) -> Self { + Self { scope, symbol } + } + + pub fn scope(self) -> FileScopeId { + self.scope + } + + pub(crate) fn symbol(self) -> ScopeSymbolId { + self.symbol + } +} + +impl From for ScopeSymbolId { + fn from(val: FileSymbolId) -> Self { + val.symbol() + } +} + +/// Symbol ID that uniquely identifies a symbol inside a [`Scope`]. +#[newtype_index] +pub(crate) struct ScopeSymbolId; + +/// Maps from the file specific [`FileScopeId`] to the global [`ScopeId`] that can be used as a Salsa query parameter. +/// +/// The [`SemanticIndex`] uses [`FileScopeId`] on a per-file level to identify scopes +/// because they allow for more efficient storage of associated data +/// (use of an [`IndexVec`] keyed by [`FileScopeId`] over an [`FxHashMap`] keyed by [`ScopeId`]). +#[derive(Eq, PartialEq, Debug)] +pub(crate) struct ScopesMap { + scopes: IndexVec, +} + +impl ScopesMap { + pub(super) fn new(scopes: IndexVec) -> Self { + Self { scopes } + } + + /// Gets the program-wide unique scope id for the given file specific `scope_id`. + fn get(&self, scope_id: FileScopeId) -> ScopeId { + self.scopes[scope_id] + } +} + +/// A cross-module identifier of a scope that can be used as a salsa query parameter. +#[salsa::tracked] +pub struct ScopeId { + #[allow(clippy::used_underscore_binding)] + pub file: VfsFile, + pub scope_id: FileScopeId, +} + +impl ScopeId { + /// Resolves the symbol named `name` in this scope. + pub fn symbol(self, db: &dyn Db, name: &str) -> Option { + let symbol_table = symbol_table(db, self); + let in_scope_id = symbol_table.symbol_id_by_name(name)?; + + Some(PublicSymbolId::new(self, in_scope_id)) + } +} + +/// ID that uniquely identifies a scope inside of a module. +#[newtype_index] +pub struct FileScopeId; + +impl FileScopeId { + /// Returns the scope id of the Root scope. 
+ pub fn root() -> Self { + FileScopeId::from_u32(0) + } + + pub fn to_scope_id(self, db: &dyn Db, file: VfsFile) -> ScopeId { + scopes_map(db, file).get(self) + } +} + +#[derive(Debug, Eq, PartialEq)] +pub struct Scope { + pub(super) name: Name, + pub(super) parent: Option, + pub(super) definition: Option, + pub(super) defining_symbol: Option, + pub(super) kind: ScopeKind, + pub(super) descendents: Range, +} + +impl Scope { + pub fn name(&self) -> &Name { + &self.name + } + + pub fn definition(&self) -> Option { + self.definition + } + + pub fn defining_symbol(&self) -> Option { + self.defining_symbol + } + + pub fn parent(self) -> Option { + self.parent + } + + pub fn kind(&self) -> ScopeKind { + self.kind + } +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum ScopeKind { + Module, + Annotation, + Class, + Function, +} + +/// Symbol table for a specific [`Scope`]. +#[derive(Debug)] +pub struct SymbolTable { + /// The symbols in this scope. + symbols: IndexVec, + + /// The symbols indexed by name. + symbols_by_name: SymbolMap, +} + +impl SymbolTable { + fn new() -> Self { + Self { + symbols: IndexVec::new(), + symbols_by_name: SymbolMap::default(), + } + } + + fn shrink_to_fit(&mut self) { + self.symbols.shrink_to_fit(); + } + + pub(crate) fn symbol(&self, symbol_id: impl Into) -> &Symbol { + &self.symbols[symbol_id.into()] + } + + #[allow(unused)] + pub(crate) fn symbol_ids(&self) -> impl Iterator { + self.symbols.indices() + } + + pub fn symbols(&self) -> impl Iterator { + self.symbols.iter() + } + + /// Returns the symbol named `name`. + #[allow(unused)] + pub(crate) fn symbol_by_name(&self, name: &str) -> Option<&Symbol> { + let id = self.symbol_id_by_name(name)?; + Some(self.symbol(id)) + } + + /// Returns the [`ScopeSymbolId`] of the symbol named `name`. + pub(crate) fn symbol_id_by_name(&self, name: &str) -> Option { + let (id, ()) = self + .symbols_by_name + .raw_entry() + .from_hash(Self::hash_name(name), |id| { + self.symbol(*id).name().as_str() == name + })?; + + Some(*id) + } + + fn hash_name(name: &str) -> u64 { + let mut hasher = FxHasher::default(); + name.hash(&mut hasher); + hasher.finish() + } +} + +impl PartialEq for SymbolTable { + fn eq(&self, other: &Self) -> bool { + // We don't need to compare the symbols_by_name because the name is already captured in `Symbol`. 
+ self.symbols == other.symbols + } +} + +impl Eq for SymbolTable {} + +#[derive(Debug)] +pub(super) struct SymbolTableBuilder { + table: SymbolTable, +} + +impl SymbolTableBuilder { + pub(super) fn new() -> Self { + Self { + table: SymbolTable::new(), + } + } + + pub(super) fn add_or_update_symbol( + &mut self, + name: Name, + scope: FileScopeId, + flags: SymbolFlags, + definition: Option, + ) -> ScopeSymbolId { + let hash = SymbolTable::hash_name(&name); + let entry = self + .table + .symbols_by_name + .raw_entry_mut() + .from_hash(hash, |id| self.table.symbols[*id].name() == &name); + + match entry { + RawEntryMut::Occupied(entry) => { + let symbol = &mut self.table.symbols[*entry.key()]; + symbol.insert_flags(flags); + + if let Some(definition) = definition { + symbol.push_definition(definition); + } + + *entry.key() + } + RawEntryMut::Vacant(entry) => { + let mut symbol = Symbol::new(name, scope, definition); + symbol.insert_flags(flags); + + let id = self.table.symbols.push(symbol); + entry.insert_with_hasher(hash, id, (), |id| { + SymbolTable::hash_name(self.table.symbols[*id].name().as_str()) + }); + id + } + } + } + + pub(super) fn finish(mut self) -> SymbolTable { + self.table.shrink_to_fit(); + self.table + } +}
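
A rough sketch of how the queries added in this patch compose, mirroring the crate-internal tests above (`semantic_index`, `symbol_table`, and `child_scopes` are `pub(crate)`, so this only works from inside `ruff_python_semantic`); the file name, its contents, and the `walk_module` test are illustrative only and not part of the change:

```rust
use ruff_db::vfs::system_path_to_file;

use crate::db::tests::TestDb;
use crate::red_knot::semantic_index::symbol::FileScopeId;
use crate::red_knot::semantic_index::{global_symbol, root_scope, semantic_index, symbol_table};

#[test]
fn walk_module() {
    // Same setup as the `test_case` helper in the new tests.
    let db = TestDb::new();
    db.memory_file_system()
        .write_file("example.py", "x = 1\n\nclass C:\n    y = 2\n")
        .unwrap();
    let file = system_path_to_file(&db, "example.py").unwrap();

    // Per-scope symbol table: a dedicated Salsa query, so dependent queries are
    // only re-run when this scope's symbols actually change.
    let module_table = symbol_table(&db, root_scope(&db, file));
    assert!(module_table
        .symbol_by_name("x")
        .is_some_and(|symbol| symbol.is_defined()));

    // Module-level ("public") symbol lookup by name.
    assert!(global_symbol(&db, file, "C").is_some());

    // Walking the scope tree goes through the full semantic index.
    let index = semantic_index(&db, file);
    for (_scope_id, scope) in index.child_scopes(FileScopeId::root()) {
        // The class body is the only child scope of the module scope here.
        assert_eq!(scope.name().as_str(), "C");
    }
}
```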