From 7cd065e4a24942d256a37430cbd73db3a0f45113 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Sat, 27 Apr 2024 10:34:00 +0200 Subject: [PATCH] Kick off Red-knot (#10849) Co-authored-by: Carl Meyer Co-authored-by: Carl Meyer --- Cargo.lock | 135 +++ Cargo.toml | 4 + crates/red_knot/Cargo.toml | 45 + crates/red_knot/src/ast_ids.rs | 415 ++++++++ crates/red_knot/src/cache.rs | 152 +++ crates/red_knot/src/cancellation.rs | 65 ++ crates/red_knot/src/db.rs | 171 ++++ crates/red_knot/src/files.rs | 148 +++ crates/red_knot/src/hir.rs | 67 ++ crates/red_knot/src/hir/definition.rs | 556 +++++++++++ crates/red_knot/src/lib.rs | 83 ++ crates/red_knot/src/lint.rs | 124 +++ crates/red_knot/src/main.rs | 399 ++++++++ crates/red_knot/src/module.rs | 911 ++++++++++++++++++ crates/red_knot/src/parse.rs | 95 ++ crates/red_knot/src/program/mod.rs | 154 +++ crates/red_knot/src/source.rs | 98 ++ crates/red_knot/src/symbols.rs | 765 +++++++++++++++ crates/red_knot/src/types.rs | 519 ++++++++++ crates/red_knot/src/types/infer.rs | 141 +++ crates/red_knot/src/watch.rs | 78 ++ crates/ruff_cache/Cargo.toml | 2 +- crates/ruff_python_ast/src/node.rs | 1249 +++++++++++++++++++++++-- crates/ruff_python_ast/src/nodes.rs | 9 +- 24 files changed, 6282 insertions(+), 103 deletions(-) create mode 100644 crates/red_knot/Cargo.toml create mode 100644 crates/red_knot/src/ast_ids.rs create mode 100644 crates/red_knot/src/cache.rs create mode 100644 crates/red_knot/src/cancellation.rs create mode 100644 crates/red_knot/src/db.rs create mode 100644 crates/red_knot/src/files.rs create mode 100644 crates/red_knot/src/hir.rs create mode 100644 crates/red_knot/src/hir/definition.rs create mode 100644 crates/red_knot/src/lib.rs create mode 100644 crates/red_knot/src/lint.rs create mode 100644 crates/red_knot/src/main.rs create mode 100644 crates/red_knot/src/module.rs create mode 100644 crates/red_knot/src/parse.rs create mode 100644 crates/red_knot/src/program/mod.rs create mode 100644 crates/red_knot/src/source.rs create mode 100644 crates/red_knot/src/symbols.rs create mode 100644 crates/red_knot/src/types.rs create mode 100644 crates/red_knot/src/types/infer.rs create mode 100644 crates/red_knot/src/watch.rs diff --git a/Cargo.lock b/Cargo.lock index f6a14143e6..047802e7ea 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -36,6 +36,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "allocator-api2" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5" + [[package]] name = "android-tzdata" version = "0.1.1" @@ -535,6 +541,16 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" +[[package]] +name = "ctrlc" +version = "3.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "672465ae37dc1bc6380a6547a8883d5dd397b0f1faaad4f265726cc7042a5345" +dependencies = [ + "nix", + "windows-sys 0.52.0", +] + [[package]] name = "darling" version = "0.20.8" @@ -570,6 +586,19 @@ dependencies = [ "syn 2.0.60", ] +[[package]] +name = "dashmap" +version = "5.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" +dependencies = [ + "cfg-if", + "hashbrown 0.14.3", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "diff" version = "0.1.13" @@ -812,6 +841,10 @@ name = "hashbrown" version = 
"0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" +dependencies = [ + "ahash", + "allocator-api2", +] [[package]] name = "heck" @@ -1214,6 +1247,16 @@ version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" +[[package]] +name = "lock_api" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45" +dependencies = [ + "autocfg", + "scopeguard", +] + [[package]] name = "log" version = "0.4.21" @@ -1441,6 +1484,29 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +[[package]] +name = "parking_lot" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets 0.48.5", +] + [[package]] name = "paste" version = "1.0.14" @@ -1732,6 +1798,37 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "red_knot" +version = "0.1.0" +dependencies = [ + "anyhow", + "bitflags 2.5.0", + "crossbeam-channel", + "ctrlc", + "dashmap", + "hashbrown 0.14.3", + "indexmap", + "log", + "notify", + "parking_lot", + "rayon", + "ruff_index", + "ruff_notebook", + "ruff_python_ast", + "ruff_python_parser", + "ruff_python_trivia", + "ruff_text_size", + "rustc-hash", + "smallvec", + "smol_str", + "tempfile", + "textwrap", + "tracing", + "tracing-subscriber", + "tracing-tree", +] + [[package]] name = "redox_syscall" version = "0.4.1" @@ -2475,6 +2572,12 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + [[package]] name = "seahash" version = "4.1.0" @@ -2628,6 +2731,21 @@ version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +[[package]] +name = "smawk" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c388c1b5e93756d0c740965c41e8822f866621d41acbdf6336a6a168f8840c" + +[[package]] +name = "smol_str" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6845563ada680337a52d43bb0b29f396f2d911616f6573012645b9e3d048a49" +dependencies = [ + "serde", +] + [[package]] name = "spin" version = "0.9.8" @@ -2779,6 +2897,17 @@ dependencies = [ "test-case-core", ] +[[package]] +name = "textwrap" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23d434d3f8967a09480fb04132ebe0a3e088c173e6d0ee7897abbdf4eab0f8b9" +dependencies = [ + "smawk", + "unicode-linebreak", + "unicode-width", +] + [[package]] name = 
"thiserror" version = "1.0.59" @@ -3034,6 +3163,12 @@ version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +[[package]] +name = "unicode-linebreak" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b09c83c3c29d37506a3e260c08c03743a6bb66a9cd432c6934ab501a190571f" + [[package]] name = "unicode-normalization" version = "0.1.23" diff --git a/Cargo.toml b/Cargo.toml index 2dff4ad192..95681c8b13 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,7 @@ console_log = { version = "1.0.0" } countme = { version = "3.0.1" } criterion = { version = "0.5.1", default-features = false } crossbeam-channel = { version = "0.5.12" } +dashmap = { version = "5.5.3" } dirs = { version = "5.0.0" } drop_bomb = { version = "0.1.5" } env_logger = { version = "0.11.0" } @@ -39,10 +40,12 @@ filetime = { version = "0.2.23" } fs-err = { version = "2.11.0" } glob = { version = "0.3.1" } globset = { version = "0.4.14" } +hashbrown = "0.14.3" hexf-parse = { version = "0.2.1" } ignore = { version = "0.4.22" } imara-diff = { version = "0.1.5" } imperative = { version = "1.0.4" } +indexmap = { version = "2.2.6" } indicatif = { version = "0.17.8" } indoc = { version = "2.0.4" } insta = { version = "1.35.1", feature = ["filters", "glob"] } @@ -68,6 +71,7 @@ once_cell = { version = "1.19.0" } path-absolutize = { version = "3.1.1" } path-slash = { version = "0.2.1" } pathdiff = { version = "0.2.1" } +parking_lot = "0.12.1" pep440_rs = { version = "0.6.0", features = ["serde"] } pretty_assertions = "1.3.0" proc-macro2 = { version = "1.0.79" } diff --git a/crates/red_knot/Cargo.toml b/crates/red_knot/Cargo.toml new file mode 100644 index 0000000000..7907c8340a --- /dev/null +++ b/crates/red_knot/Cargo.toml @@ -0,0 +1,45 @@ +[package] +name = "red_knot" +version = "0.1.0" +edition.workspace = true +rust-version.workspace = true +homepage.workspace = true +documentation.workspace = true +repository.workspace = true +authors.workspace = true +license.workspace = true + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +ruff_python_parser = { path = "../ruff_python_parser" } +ruff_python_ast = { path = "../ruff_python_ast" } +ruff_python_trivia = { path = "../ruff_python_trivia" } +ruff_text_size = { path = "../ruff_text_size" } +ruff_index = { path = "../ruff_index" } +ruff_notebook = { path = "../ruff_notebook" } + +anyhow = { workspace = true } +bitflags = { workspace = true } +ctrlc = "3.4.4" +crossbeam-channel = { workspace = true } +dashmap = { workspace = true } +hashbrown = { workspace = true } +indexmap = { workspace = true } +log = { workspace = true } +notify = { workspace = true } +parking_lot = { workspace = true } +rayon = { workspace = true } +rustc-hash = { workspace = true } +smallvec = { workspace = true } +smol_str = "0.2.1" +tracing = { workspace = true } +tracing-subscriber = { workspace = true } +tracing-tree = { workspace = true } + +[dev-dependencies] +textwrap = "0.16.1" +tempfile = { workspace = true } + +[lints] +workspace = true diff --git a/crates/red_knot/src/ast_ids.rs b/crates/red_knot/src/ast_ids.rs new file mode 100644 index 0000000000..784e44b22d --- /dev/null +++ b/crates/red_knot/src/ast_ids.rs @@ -0,0 +1,415 @@ +use std::any::type_name; +use std::fmt::{Debug, Formatter}; +use std::hash::{Hash, Hasher}; +use std::marker::PhantomData; + +use rustc_hash::FxHashMap; + +use 
ruff_index::{Idx, IndexVec}; +use ruff_python_ast::visitor::preorder; +use ruff_python_ast::visitor::preorder::{PreorderVisitor, TraversalSignal}; +use ruff_python_ast::{ + AnyNodeRef, AstNode, ExceptHandler, ExceptHandlerExceptHandler, Expr, MatchCase, ModModule, + NodeKind, Parameter, Stmt, StmtAnnAssign, StmtAssign, StmtAugAssign, StmtClassDef, + StmtFunctionDef, StmtGlobal, StmtImport, StmtImportFrom, StmtNonlocal, StmtTypeAlias, + TypeParam, TypeParamParamSpec, TypeParamTypeVar, TypeParamTypeVarTuple, WithItem, +}; +use ruff_text_size::{Ranged, TextRange}; + +/// A type agnostic ID that uniquely identifies an AST node in a file. +#[ruff_index::newtype_index] +pub struct AstId; + +/// A typed ID that uniquely identifies an AST node in a file. +/// +/// This is different from [`AstId`] in that it is a combination of ID and the type of the node the ID identifies. +/// Typing the ID prevents mixing IDs of different node types and allows to restrict the API to only accept +/// nodes for which an ID has been created (not all AST nodes get an ID). +pub struct TypedAstId { + erased: AstId, + _marker: PhantomData N>, +} + +impl TypedAstId { + /// Upcasts this ID from a more specific node type to a more general node type. + pub fn upcast(self) -> TypedAstId + where + N: Into, + { + TypedAstId { + erased: self.erased, + _marker: PhantomData, + } + } +} + +impl Copy for TypedAstId {} +impl Clone for TypedAstId { + fn clone(&self) -> Self { + *self + } +} + +impl PartialEq for TypedAstId { + fn eq(&self, other: &Self) -> bool { + self.erased == other.erased + } +} + +impl Eq for TypedAstId {} +impl Hash for TypedAstId { + fn hash(&self, state: &mut H) { + self.erased.hash(state); + } +} + +impl Debug for TypedAstId { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("TypedAstId") + .field(&self.erased) + .field(&type_name::()) + .finish() + } +} + +pub struct AstIds { + ids: IndexVec, + reverse: FxHashMap, +} + +impl AstIds { + // TODO rust analyzer doesn't allocate an ID for every node. It only allocates ids for + // nodes with a corresponding HIR element, that is nodes that are definitions. + pub fn from_module(module: &ModModule) -> Self { + let mut visitor = AstIdsVisitor::default(); + + // TODO: visit_module? + // Make sure we visit the root + visitor.create_id(module); + visitor.visit_body(&module.body); + + while let Some(deferred) = visitor.deferred.pop() { + match deferred { + DeferredNode::FunctionDefinition(def) => { + def.visit_preorder(&mut visitor); + } + DeferredNode::ClassDefinition(def) => def.visit_preorder(&mut visitor), + } + } + + AstIds { + ids: visitor.ids, + reverse: visitor.reverse, + } + } + + /// Returns the ID to the root node. + pub fn root(&self) -> NodeKey { + self.ids[AstId::new(0)] + } + + /// Returns the [`TypedAstId`] for a node. + pub fn ast_id(&self, node: &N) -> TypedAstId { + let key = node.syntax_node_key(); + TypedAstId { + erased: self.reverse.get(&key).copied().unwrap(), + _marker: PhantomData, + } + } + + /// Returns the [`TypedAstId`] for the node identified with the given [`TypedNodeKey`]. + pub fn ast_id_for_key(&self, node: &TypedNodeKey) -> TypedAstId { + let ast_id = self.ast_id_for_node_key(node.inner); + + TypedAstId { + erased: ast_id, + _marker: PhantomData, + } + } + + /// Returns the untyped [`AstId`] for the node identified by the given `node` key. 
+ pub fn ast_id_for_node_key(&self, node: NodeKey) -> AstId { + self.reverse + .get(&node) + .copied() + .expect("Can't find node in AstIds map.") + } + + /// Returns the [`TypedNodeKey`] for the node identified by the given [`TypedAstId`]. + pub fn key(&self, id: TypedAstId) -> TypedNodeKey { + let syntax_key = self.ids[id.erased]; + + TypedNodeKey::new(syntax_key).unwrap() + } + + pub fn node_key(&self, id: TypedAstId) -> NodeKey { + self.ids[id.erased] + } +} + +impl std::fmt::Debug for AstIds { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let mut map = f.debug_map(); + for (key, value) in self.ids.iter_enumerated() { + map.entry(&key, &value); + } + + map.finish() + } +} + +impl PartialEq for AstIds { + fn eq(&self, other: &Self) -> bool { + self.ids == other.ids + } +} + +impl Eq for AstIds {} + +#[derive(Default)] +struct AstIdsVisitor<'a> { + ids: IndexVec, + reverse: FxHashMap, + deferred: Vec>, +} + +impl<'a> AstIdsVisitor<'a> { + fn create_id(&mut self, node: &A) { + let node_key = node.syntax_node_key(); + + let id = self.ids.push(node_key); + self.reverse.insert(node_key, id); + } +} + +impl<'a> PreorderVisitor<'a> for AstIdsVisitor<'a> { + fn visit_stmt(&mut self, stmt: &'a Stmt) { + match stmt { + Stmt::FunctionDef(def) => { + self.create_id(def); + self.deferred.push(DeferredNode::FunctionDefinition(def)); + return; + } + // TODO defer visiting the assignment body, type alias parameters etc? + Stmt::ClassDef(def) => { + self.create_id(def); + self.deferred.push(DeferredNode::ClassDefinition(def)); + return; + } + Stmt::Expr(_) => { + // Skip + return; + } + Stmt::Return(_) => {} + Stmt::Delete(_) => {} + Stmt::Assign(assignment) => self.create_id(assignment), + Stmt::AugAssign(assignment) => { + self.create_id(assignment); + } + Stmt::AnnAssign(assignment) => self.create_id(assignment), + Stmt::TypeAlias(assignment) => self.create_id(assignment), + Stmt::For(_) => {} + Stmt::While(_) => {} + Stmt::If(_) => {} + Stmt::With(_) => {} + Stmt::Match(_) => {} + Stmt::Raise(_) => {} + Stmt::Try(_) => {} + Stmt::Assert(_) => {} + Stmt::Import(import) => self.create_id(import), + Stmt::ImportFrom(import_from) => self.create_id(import_from), + Stmt::Global(global) => self.create_id(global), + Stmt::Nonlocal(non_local) => self.create_id(non_local), + Stmt::Pass(_) => {} + Stmt::Break(_) => {} + Stmt::Continue(_) => {} + Stmt::IpyEscapeCommand(_) => {} + } + + preorder::walk_stmt(self, stmt); + } + + fn visit_expr(&mut self, _expr: &'a Expr) {} + + fn visit_parameter(&mut self, parameter: &'a Parameter) { + self.create_id(parameter); + preorder::walk_parameter(self, parameter); + } + + fn visit_except_handler(&mut self, except_handler: &'a ExceptHandler) { + match except_handler { + ExceptHandler::ExceptHandler(except_handler) => { + self.create_id(except_handler); + } + } + + preorder::walk_except_handler(self, except_handler); + } + + fn visit_with_item(&mut self, with_item: &'a WithItem) { + self.create_id(with_item); + preorder::walk_with_item(self, with_item); + } + + fn visit_match_case(&mut self, match_case: &'a MatchCase) { + self.create_id(match_case); + preorder::walk_match_case(self, match_case); + } + + fn visit_type_param(&mut self, type_param: &'a TypeParam) { + self.create_id(type_param); + } +} + +enum DeferredNode<'a> { + FunctionDefinition(&'a StmtFunctionDef), + ClassDefinition(&'a StmtClassDef), +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] +pub struct TypedNodeKey { + /// The type erased node key. 
+ inner: NodeKey, + _marker: PhantomData N>, +} + +impl TypedNodeKey { + pub fn from_node(node: &N) -> Self { + let inner = NodeKey { + kind: node.as_any_node_ref().kind(), + range: node.range(), + }; + Self { + inner, + _marker: PhantomData, + } + } + + pub fn new(node_key: NodeKey) -> Option { + N::can_cast(node_key.kind).then_some(TypedNodeKey { + inner: node_key, + _marker: PhantomData, + }) + } + + pub fn resolve<'a>(&self, root: AnyNodeRef<'a>) -> Option> { + let node_ref = self.inner.resolve(root)?; + + Some(N::cast_ref(node_ref).unwrap()) + } + + pub fn resolve_unwrap<'a>(&self, root: AnyNodeRef<'a>) -> N::Ref<'a> { + self.resolve(root).expect("node should resolve") + } + + pub fn erased(&self) -> &NodeKey { + &self.inner + } +} + +struct FindNodeKeyVisitor<'a> { + key: NodeKey, + result: Option>, +} + +impl<'a> PreorderVisitor<'a> for FindNodeKeyVisitor<'a> { + fn enter_node(&mut self, node: AnyNodeRef<'a>) -> TraversalSignal { + if self.result.is_some() { + return TraversalSignal::Skip; + } + + if node.range() == self.key.range && node.kind() == self.key.kind { + self.result = Some(node); + TraversalSignal::Skip + } else if node.range().contains_range(self.key.range) { + TraversalSignal::Traverse + } else { + TraversalSignal::Skip + } + } + + fn visit_body(&mut self, body: &'a [Stmt]) { + // TODO it would be more efficient to use binary search instead of linear + for stmt in body { + if stmt.range().start() > self.key.range.end() { + break; + } + + self.visit_stmt(stmt); + } + } +} + +// TODO an alternative to this is to have a `NodeId` on each node (in increasing order depending on the position). +// This would allow to reduce the size of this to a u32. +// What would be nice if we could use an `Arc::weak_ref` here but that only works if we use +// `Arc` internally +// TODO: Implement the logic to resolve a node, given a db (and the correct file). +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] +pub struct NodeKey { + kind: NodeKind, + range: TextRange, +} + +impl NodeKey { + pub fn resolve<'a>(&self, root: AnyNodeRef<'a>) -> Option> { + // We need to do a binary search here. Only traverse into a node if the range is withint the node + let mut visitor = FindNodeKeyVisitor { + key: *self, + result: None, + }; + + if visitor.enter_node(root) == TraversalSignal::Traverse { + root.visit_preorder(&mut visitor); + } + + visitor.result + } +} + +/// Marker trait implemented by AST nodes for which we extract the `AstId`. 
+pub trait HasAstId: AstNode { + fn node_key(&self) -> TypedNodeKey + where + Self: Sized, + { + TypedNodeKey { + inner: self.syntax_node_key(), + _marker: PhantomData, + } + } + + fn syntax_node_key(&self) -> NodeKey { + NodeKey { + kind: self.as_any_node_ref().kind(), + range: self.range(), + } + } +} + +impl HasAstId for StmtFunctionDef {} +impl HasAstId for StmtClassDef {} +impl HasAstId for StmtAnnAssign {} +impl HasAstId for StmtAugAssign {} +impl HasAstId for StmtAssign {} +impl HasAstId for StmtTypeAlias {} + +impl HasAstId for ModModule {} + +impl HasAstId for StmtImport {} + +impl HasAstId for StmtImportFrom {} + +impl HasAstId for Parameter {} + +impl HasAstId for TypeParam {} +impl HasAstId for Stmt {} +impl HasAstId for TypeParamTypeVar {} +impl HasAstId for TypeParamTypeVarTuple {} +impl HasAstId for TypeParamParamSpec {} +impl HasAstId for StmtGlobal {} +impl HasAstId for StmtNonlocal {} + +impl HasAstId for ExceptHandlerExceptHandler {} +impl HasAstId for WithItem {} +impl HasAstId for MatchCase {} diff --git a/crates/red_knot/src/cache.rs b/crates/red_knot/src/cache.rs new file mode 100644 index 0000000000..ac1e891aca --- /dev/null +++ b/crates/red_knot/src/cache.rs @@ -0,0 +1,152 @@ +use std::fmt::Formatter; +use std::hash::Hash; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use dashmap::mapref::entry::Entry; + +use crate::FxDashMap; + +/// Simple key value cache that locks on a per-key level. +pub struct KeyValueCache { + map: FxDashMap, + statistics: CacheStatistics, +} + +impl KeyValueCache +where + K: Eq + Hash + Clone, + V: Clone, +{ + pub fn try_get(&self, key: &K) -> Option { + if let Some(existing) = self.map.get(key) { + self.statistics.hit(); + Some(existing.clone()) + } else { + self.statistics.miss(); + None + } + } + + pub fn get(&self, key: &K, compute: F) -> V + where + F: FnOnce(&K) -> V, + { + match self.map.entry(key.clone()) { + Entry::Occupied(cached) => { + self.statistics.hit(); + + cached.get().clone() + } + Entry::Vacant(vacant) => { + self.statistics.miss(); + + let value = compute(key); + vacant.insert(value.clone()); + value + } + } + } + + pub fn set(&mut self, key: K, value: V) { + self.map.insert(key, value); + } + + pub fn remove(&mut self, key: &K) -> Option { + self.map.remove(key).map(|(_, value)| value) + } + + pub fn clear(&mut self) { + self.map.clear(); + self.map.shrink_to_fit(); + } + + pub fn statistics(&self) -> Option { + self.statistics.to_statistics() + } +} + +impl Default for KeyValueCache +where + K: Eq + Hash, + V: Clone, +{ + fn default() -> Self { + Self { + map: FxDashMap::default(), + statistics: CacheStatistics::default(), + } + } +} + +impl std::fmt::Debug for KeyValueCache +where + K: std::fmt::Debug + Eq + Hash, + V: std::fmt::Debug, +{ + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let mut debug = f.debug_map(); + + for entry in &self.map { + debug.entry(&entry.value(), &entry.key()); + } + + debug.finish() + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Statistics { + pub hits: usize, + pub misses: usize, +} + +impl Statistics { + #[allow(clippy::cast_precision_loss)] + pub fn hit_rate(&self) -> Option { + if self.hits + self.misses == 0 { + return None; + } + + Some((self.hits as f64) / (self.hits + self.misses) as f64) + } +} + +#[cfg(debug_assertions)] +pub type CacheStatistics = DebugStatistics; + +#[cfg(not(debug_assertions))] +pub type CacheStatistics = ReleaseStatistics; + +#[derive(Debug, Default)] +pub struct DebugStatistics { + hits: AtomicUsize, + misses: AtomicUsize, 
+} + +impl DebugStatistics { + // TODO figure out appropriate Ordering + pub fn hit(&self) { + self.hits.fetch_add(1, Ordering::SeqCst); + } + + pub fn miss(&self) { + self.misses.fetch_add(1, Ordering::SeqCst); + } + + pub fn to_statistics(&self) -> Option { + let hits = self.hits.load(Ordering::SeqCst); + let misses = self.misses.load(Ordering::SeqCst); + + Some(Statistics { hits, misses }) + } +} + +#[derive(Debug, Default)] +pub struct ReleaseStatistics; + +impl ReleaseStatistics { + #[inline] + pub fn to_statistics(&self) -> Option { + None + } +} diff --git a/crates/red_knot/src/cancellation.rs b/crates/red_knot/src/cancellation.rs new file mode 100644 index 0000000000..9f5e57cf56 --- /dev/null +++ b/crates/red_knot/src/cancellation.rs @@ -0,0 +1,65 @@ +use std::sync::{Arc, Condvar, Mutex}; + +#[derive(Debug, Default)] +pub struct CancellationSource { + signal: Arc<(Mutex, Condvar)>, +} + +impl CancellationSource { + pub fn new() -> Self { + Self { + signal: Arc::new((Mutex::new(false), Condvar::default())), + } + } + + pub fn cancel(&self) { + let (cancelled, condvar) = &*self.signal; + + let mut cancelled = cancelled.lock().unwrap(); + + if *cancelled { + return; + } + + *cancelled = true; + condvar.notify_all(); + } + + pub fn is_cancelled(&self) -> bool { + let (cancelled, _) = &*self.signal; + + *cancelled.lock().unwrap() + } + + pub fn token(&self) -> CancellationToken { + CancellationToken { + signal: self.signal.clone(), + } + } +} + +#[derive(Clone, Debug)] +pub struct CancellationToken { + signal: Arc<(Mutex, Condvar)>, +} + +impl CancellationToken { + /// Returns `true` if cancellation has been requested. + pub fn is_cancelled(&self) -> bool { + let (cancelled, _) = &*self.signal; + + *cancelled.lock().unwrap() + } + + pub fn wait(&self) { + let (bool, condvar) = &*self.signal; + + let lock = condvar + .wait_while(bool.lock().unwrap(), |bool| !*bool) + .unwrap(); + + debug_assert!(*lock); + + drop(lock); + } +} diff --git a/crates/red_knot/src/db.rs b/crates/red_knot/src/db.rs new file mode 100644 index 0000000000..c4b5b67394 --- /dev/null +++ b/crates/red_knot/src/db.rs @@ -0,0 +1,171 @@ +use std::path::Path; +use std::sync::Arc; + +use crate::files::FileId; +use crate::lint::{Diagnostics, LintSyntaxStorage}; +use crate::module::{Module, ModuleData, ModuleName, ModuleResolver, ModuleSearchPath}; +use crate::parse::{Parsed, ParsedStorage}; +use crate::source::{Source, SourceStorage}; +use crate::symbols::{SymbolId, SymbolTable, SymbolTablesStorage}; +use crate::types::{Type, TypeStore}; + +pub trait SourceDb { + // queries + fn file_id(&self, path: &std::path::Path) -> FileId; + + fn file_path(&self, file_id: FileId) -> Arc; + + fn source(&self, file_id: FileId) -> Source; + + fn parse(&self, file_id: FileId) -> Parsed; + + fn lint_syntax(&self, file_id: FileId) -> Diagnostics; +} + +pub trait SemanticDb: SourceDb { + // queries + fn resolve_module(&self, name: ModuleName) -> Option; + + fn symbol_table(&self, file_id: FileId) -> Arc; + + // mutations + fn path_to_module(&mut self, path: &Path) -> Option; + + fn add_module(&mut self, path: &Path) -> Option<(Module, Vec>)>; + + fn set_module_search_paths(&mut self, paths: Vec); + + fn infer_symbol_type(&mut self, file_id: FileId, symbol_id: SymbolId) -> Type; +} + +pub trait Db: SemanticDb {} + +#[derive(Debug, Default)] +pub struct SourceJar { + pub sources: SourceStorage, + pub parsed: ParsedStorage, + pub lint_syntax: LintSyntaxStorage, +} + +#[derive(Debug, Default)] +pub struct SemanticJar { + pub module_resolver: 
ModuleResolver, + pub symbol_tables: SymbolTablesStorage, + pub type_store: TypeStore, +} + +/// Gives access to a specific jar in the database. +/// +/// Nope, the terminology isn't borrowed from Java but from Salsa , +/// which is an analogy to storing the salsa in different jars. +/// +/// The basic idea is that each crate can define its own jar and the jars can be combined to a single +/// database in the top level crate. Each crate also defines its own `Database` trait. The combination of +/// `Database` trait and the jar allows to write queries in isolation without having to know how they get composed at the upper levels. +/// +/// Salsa further defines a `HasIngredient` trait which slices the jar to a specific storage (e.g. a specific cache). +/// We don't need this just jet because we write our queries by hand. We may want a similar trait if we decide +/// to use a macro to generate the queries. +pub trait HasJar { + /// Gives a read-only reference to the jar. + fn jar(&self) -> &T; + + /// Gives a mutable reference to the jar. + fn jar_mut(&mut self) -> &mut T; +} + +#[cfg(test)] +pub(crate) mod tests { + use crate::db::{HasJar, SourceDb, SourceJar}; + use crate::files::{FileId, Files}; + use crate::lint::{lint_syntax, Diagnostics}; + use crate::module::{ + add_module, path_to_module, resolve_module, set_module_search_paths, Module, ModuleData, + ModuleName, ModuleSearchPath, + }; + use crate::parse::{parse, Parsed}; + use crate::source::{source_text, Source}; + use crate::symbols::{symbol_table, SymbolId, SymbolTable}; + use crate::types::{infer_symbol_type, Type}; + use std::path::Path; + use std::sync::Arc; + + use super::{SemanticDb, SemanticJar}; + + // This can be a partial database used in a single crate for testing. + // It would hold fewer data than the full database. 
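A minimal sketch of the query pattern this trait enables (the function `source_cache_statistics` is hypothetical and not part of this patch; `lint_syntax` in lint.rs later in the patch is a real query of the same shape): a query only states which jar it needs, so it can be written inside this crate and still run against whatever database the top-level crate composes, including the `TestDb` defined next.

use crate::cache::Statistics;
use crate::db::{HasJar, SourceJar};

fn source_cache_statistics<Db>(db: &Db) -> Option<Statistics>
where
    Db: HasJar<SourceJar>,
{
    // `jar()` hands out this crate's storage; only the caller knows which
    // concrete database (the full program, or the `TestDb` below) embeds it.
    db.jar().sources.statistics()
}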
+ #[derive(Debug, Default)] + pub(crate) struct TestDb { + files: Files, + source: SourceJar, + semantic: SemanticJar, + } + + impl HasJar for TestDb { + fn jar(&self) -> &SourceJar { + &self.source + } + + fn jar_mut(&mut self) -> &mut SourceJar { + &mut self.source + } + } + + impl HasJar for TestDb { + fn jar(&self) -> &SemanticJar { + &self.semantic + } + + fn jar_mut(&mut self) -> &mut SemanticJar { + &mut self.semantic + } + } + + impl SourceDb for TestDb { + fn file_id(&self, path: &Path) -> FileId { + self.files.intern(path) + } + + fn file_path(&self, file_id: FileId) -> Arc { + self.files.path(file_id) + } + + fn source(&self, file_id: FileId) -> Source { + source_text(self, file_id) + } + + fn parse(&self, file_id: FileId) -> Parsed { + parse(self, file_id) + } + + fn lint_syntax(&self, file_id: FileId) -> Diagnostics { + lint_syntax(self, file_id) + } + } + + impl SemanticDb for TestDb { + fn resolve_module(&self, name: ModuleName) -> Option { + resolve_module(self, name) + } + + fn symbol_table(&self, file_id: FileId) -> Arc { + symbol_table(self, file_id) + } + + fn path_to_module(&mut self, path: &Path) -> Option { + path_to_module(self, path) + } + + fn add_module(&mut self, path: &Path) -> Option<(Module, Vec>)> { + add_module(self, path) + } + + fn set_module_search_paths(&mut self, paths: Vec) { + set_module_search_paths(self, paths); + } + + fn infer_symbol_type(&mut self, file_id: FileId, symbol_id: SymbolId) -> Type { + infer_symbol_type(self, file_id, symbol_id) + } + } +} diff --git a/crates/red_knot/src/files.rs b/crates/red_knot/src/files.rs new file mode 100644 index 0000000000..fc7f18115f --- /dev/null +++ b/crates/red_knot/src/files.rs @@ -0,0 +1,148 @@ +use std::fmt::{Debug, Formatter}; +use std::hash::{Hash, Hasher}; +use std::path::Path; +use std::sync::Arc; + +use hashbrown::hash_map::RawEntryMut; +use parking_lot::RwLock; +use rustc_hash::FxHasher; + +use ruff_index::{newtype_index, IndexVec}; + +type Map = hashbrown::HashMap; + +#[newtype_index] +pub struct FileId; + +// TODO we'll need a higher level virtual file system abstraction that allows testing if a file exists +// or retrieving its content (ideally lazily and in a way that the memory can be retained later) +// I suspect that we'll end up with a FileSystem trait and our own Path abstraction. +#[derive(Clone, Default)] +pub struct Files { + inner: Arc>, +} + +impl Files { + #[tracing::instrument(level = "trace", skip(path))] + pub fn intern(&self, path: &Path) -> FileId { + self.inner.write().intern(path) + } + + pub fn try_get(&self, path: &Path) -> Option { + self.inner.read().try_get(path) + } + + // TODO Can we avoid using an `Arc` here? salsa can return references for some reason. + pub fn path(&self, id: FileId) -> Arc { + self.inner.read().path(id) + } +} + +impl Debug for Files { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let files = self.inner.read(); + let mut debug = f.debug_map(); + for item in files.iter() { + debug.entry(&item.0, &item.1); + } + + debug.finish() + } +} + +impl PartialEq for Files { + fn eq(&self, other: &Self) -> bool { + self.inner.read().eq(&other.inner.read()) + } +} + +impl Eq for Files {} + +#[derive(Default)] +struct FilesInner { + by_path: Map, + // TODO should we use a map here to reclaim the space for removed files? + // TODO I think we should use our own path abstraction here to avoid having to normalize paths + // and dealing with non-utf paths everywhere. 
+ by_id: IndexVec>, +} + +impl FilesInner { + /// Inserts the path and returns a new id for it or returns the id if it is an existing path. + // TODO should this accept Path or PathBuf? + pub(crate) fn intern(&mut self, path: &Path) -> FileId { + let mut hasher = FxHasher::default(); + path.hash(&mut hasher); + let hash = hasher.finish(); + + let entry = self + .by_path + .raw_entry_mut() + .from_hash(hash, |existing_file| &*self.by_id[*existing_file] == path); + + match entry { + RawEntryMut::Occupied(entry) => *entry.key(), + RawEntryMut::Vacant(entry) => { + let id = self.by_id.push(Arc::from(path)); + entry.insert_with_hasher(hash, id, (), |_| hash); + id + } + } + } + + pub(crate) fn try_get(&self, path: &Path) -> Option { + let mut hasher = FxHasher::default(); + path.hash(&mut hasher); + let hash = hasher.finish(); + + Some( + *self + .by_path + .raw_entry() + .from_hash(hash, |existing_file| &*self.by_id[*existing_file] == path)? + .0, + ) + } + + /// Returns the path for the file with the given id. + pub(crate) fn path(&self, id: FileId) -> Arc { + self.by_id[id].clone() + } + + pub(crate) fn iter(&self) -> impl Iterator)> + '_ { + self.by_path.keys().map(|id| (*id, self.by_id[*id].clone())) + } +} + +impl PartialEq for FilesInner { + fn eq(&self, other: &Self) -> bool { + self.by_id == other.by_id + } +} + +impl Eq for FilesInner {} + +#[cfg(test)] +mod tests { + use super::*; + use std::path::PathBuf; + + #[test] + fn insert_path_twice_same_id() { + let files = Files::default(); + let path = PathBuf::from("foo/bar"); + let id1 = files.intern(&path); + let id2 = files.intern(&path); + assert_eq!(id1, id2); + } + + #[test] + fn insert_different_paths_different_ids() { + let files = Files::default(); + let path1 = PathBuf::from("foo/bar"); + let path2 = PathBuf::from("foo/bar/baz"); + let id1 = files.intern(&path1); + let id2 = files.intern(&path2); + assert_ne!(id1, id2); + } +} diff --git a/crates/red_knot/src/hir.rs b/crates/red_knot/src/hir.rs new file mode 100644 index 0000000000..030c7353d0 --- /dev/null +++ b/crates/red_knot/src/hir.rs @@ -0,0 +1,67 @@ +//! Key observations +//! +//! The HIR avoids allocations to large extends by: +//! * Using an arena per node type +//! * using ids and id ranges to reference items. +//! +//! Using separate arena per node type has the advantage that the IDs are relatively stable, because +//! they only change when a node of the same kind has been added or removed. (What's unclear is if that matters or if +//! it still triggers a re-compute because the AST-id in the node has changed). +//! +//! The HIR does not store all details. It mainly stores the *public* interface. There's a reference +//! back to the AST node to get more details. +//! +//! 
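That back reference is the typed id machinery from ast_ids.rs: given a file's `AstIds` table and the syntax root, an id can be turned back into the full AST node whenever details beyond the public interface are needed. A minimal sketch of the round trip (the helper `resolve_ast` is hypothetical, not part of this patch):

use ruff_python_ast::AnyNodeRef;

use crate::ast_ids::{AstIds, HasAstId, TypedAstId};

fn resolve_ast<'a, N: HasAstId>(
    ast_ids: &AstIds,
    id: TypedAstId<N>,
    root: AnyNodeRef<'a>,
) -> Option<AnyNodeRef<'a>> {
    // `node_key` recovers the stored (kind, range) key for the id; `resolve` then
    // walks the tree from the root, descending only into nodes whose range contains it.
    ast_ids.node_key(id).resolve(root)
}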
+ +use crate::ast_ids::{HasAstId, TypedAstId}; +use crate::files::FileId; +use std::fmt::Formatter; +use std::hash::{Hash, Hasher}; + +pub struct HirAstId { + file_id: FileId, + node_id: TypedAstId, +} + +impl Copy for HirAstId {} +impl Clone for HirAstId { + fn clone(&self) -> Self { + *self + } +} + +impl PartialEq for HirAstId { + fn eq(&self, other: &Self) -> bool { + self.file_id == other.file_id && self.node_id == other.node_id + } +} + +impl Eq for HirAstId {} + +impl std::fmt::Debug for HirAstId { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("HirAstId") + .field("file_id", &self.file_id) + .field("node_id", &self.node_id) + .finish() + } +} + +impl Hash for HirAstId { + fn hash(&self, state: &mut H) { + self.file_id.hash(state); + self.node_id.hash(state); + } +} + +impl HirAstId { + pub fn upcast(self) -> HirAstId + where + N: Into, + { + HirAstId { + file_id: self.file_id, + node_id: self.node_id.upcast(), + } + } +} diff --git a/crates/red_knot/src/hir/definition.rs b/crates/red_knot/src/hir/definition.rs new file mode 100644 index 0000000000..35b239796a --- /dev/null +++ b/crates/red_knot/src/hir/definition.rs @@ -0,0 +1,556 @@ +use std::ops::{Index, Range}; + +use ruff_index::{newtype_index, IndexVec}; +use ruff_python_ast::visitor::preorder; +use ruff_python_ast::visitor::preorder::PreorderVisitor; +use ruff_python_ast::{ + Decorator, ExceptHandler, ExceptHandlerExceptHandler, Expr, MatchCase, ModModule, Stmt, + StmtAnnAssign, StmtAssign, StmtClassDef, StmtFunctionDef, StmtGlobal, StmtImport, + StmtImportFrom, StmtNonlocal, StmtTypeAlias, TypeParam, TypeParamParamSpec, TypeParamTypeVar, + TypeParamTypeVarTuple, WithItem, +}; + +use crate::ast_ids::{AstIds, HasAstId}; +use crate::files::FileId; +use crate::hir::HirAstId; +use crate::Name; + +#[newtype_index] +pub struct FunctionId; + +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct Function { + ast_id: HirAstId, + name: Name, + parameters: Range, + type_parameters: Range, // TODO: type_parameters, return expression, decorators +} + +#[newtype_index] +pub struct ParameterId; + +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct Parameter { + kind: ParameterKind, + name: Name, + default: Option<()>, // TODO use expression HIR + ast_id: HirAstId, +} + +// TODO or should `Parameter` be an enum? +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] +pub enum ParameterKind { + PositionalOnly, + Arguments, + Vararg, + KeywordOnly, + Kwarg, +} + +#[newtype_index] +pub struct ClassId; + +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct Class { + name: Name, + ast_id: HirAstId, + // TODO type parameters, inheritance, decorators, members +} + +#[newtype_index] +pub struct AssignmentId; + +// This can have more than one name... +// but that means we can't implement `name()` on `ModuleItem`. 
+ +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct Assignment { + // TODO: Handle multiple names / targets + name: Name, + ast_id: HirAstId, +} + +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct AnnotatedAssignment { + name: Name, + ast_id: HirAstId, +} + +#[newtype_index] +pub struct AnnotatedAssignmentId; + +#[newtype_index] +pub struct TypeAliasId; + +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct TypeAlias { + name: Name, + ast_id: HirAstId, + parameters: Range, +} + +#[newtype_index] +pub struct TypeParameterId; + +#[derive(Debug, Clone, Eq, PartialEq)] +pub enum TypeParameter { + TypeVar(TypeParameterTypeVar), + ParamSpec(TypeParameterParamSpec), + TypeVarTuple(TypeParameterTypeVarTuple), +} + +impl TypeParameter { + pub fn ast_id(&self) -> HirAstId { + match self { + TypeParameter::TypeVar(type_var) => type_var.ast_id.upcast(), + TypeParameter::ParamSpec(param_spec) => param_spec.ast_id.upcast(), + TypeParameter::TypeVarTuple(type_var_tuple) => type_var_tuple.ast_id.upcast(), + } + } +} + +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct TypeParameterTypeVar { + name: Name, + ast_id: HirAstId, +} + +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct TypeParameterParamSpec { + name: Name, + ast_id: HirAstId, +} + +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct TypeParameterTypeVarTuple { + name: Name, + ast_id: HirAstId, +} + +#[newtype_index] +pub struct GlobalId; + +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct Global { + // TODO track names + ast_id: HirAstId, +} + +#[newtype_index] +pub struct NonLocalId; + +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct NonLocal { + // TODO track names + ast_id: HirAstId, +} + +pub enum DefinitionId { + Function(FunctionId), + Parameter(ParameterId), + Class(ClassId), + Assignment(AssignmentId), + AnnotatedAssignment(AnnotatedAssignmentId), + Global(GlobalId), + NonLocal(NonLocalId), + TypeParameter(TypeParameterId), + TypeAlias(TypeAlias), +} + +pub enum DefinitionItem { + Function(Function), + Parameter(Parameter), + Class(Class), + Assignment(Assignment), + AnnotatedAssignment(AnnotatedAssignment), + Global(Global), + NonLocal(NonLocal), + TypeParameter(TypeParameter), + TypeAlias(TypeAlias), +} + +// The closest is rust-analyzers item-tree. It only represents "Items" which make the public interface of a module +// (it excludes any other statement or expressions). rust-analyzer uses it as the main input to the name resolution +// algorithm +// > It is the input to the name resolution algorithm, as well as to the queries defined in `adt.rs`, +// > `data.rs`, and most things in `attr.rs`. +// +// > One important purpose of this layer is to provide an "invalidation barrier" for incremental +// > computations: when typing inside an item body, the `ItemTree` of the modified file is typically +// > unaffected, so we don't have to recompute name resolution results or item data (see `data.rs`). +// +// I haven't fully figured this out but I think that this composes the "public" interface of a module? +// But maybe that's too optimistic. 
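One way to make the "invalidation barrier" idea concrete: two revisions of a module that differ only in an expression inside a function body should lower to equal `Definitions`, since this layer stores only names and id ranges, never bodies. A rough sketch of such a check (the helper `definitions_unchanged` is hypothetical, not part of this patch):

use ruff_python_ast::ModModule;

use crate::ast_ids::AstIds;
use crate::files::FileId;

fn definitions_unchanged(file: FileId, before: &ModModule, after: &ModModule) -> bool {
    // `Definitions` derives `PartialEq`; a body-only edit that adds or removes no
    // definition node keeps every id stable, so the two lowerings compare equal.
    let old = Definitions::from_module(before, &AstIds::from_module(before), file);
    let new = Definitions::from_module(after, &AstIds::from_module(after), file);
    old == new
}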
+// +// +#[derive(Debug, Clone, Default, Eq, PartialEq)] +pub struct Definitions { + functions: IndexVec, + parameters: IndexVec, + classes: IndexVec, + assignments: IndexVec, + annotated_assignments: IndexVec, + type_aliases: IndexVec, + type_parameters: IndexVec, + globals: IndexVec, + non_locals: IndexVec, +} + +impl Definitions { + pub fn from_module(module: &ModModule, ast_ids: &AstIds, file_id: FileId) -> Self { + let mut visitor = DefinitionsVisitor { + definitions: Definitions::default(), + ast_ids, + file_id, + }; + + visitor.visit_body(&module.body); + + visitor.definitions + } +} + +impl Index for Definitions { + type Output = Function; + + fn index(&self, index: FunctionId) -> &Self::Output { + &self.functions[index] + } +} + +impl Index for Definitions { + type Output = Parameter; + + fn index(&self, index: ParameterId) -> &Self::Output { + &self.parameters[index] + } +} + +impl Index for Definitions { + type Output = Class; + + fn index(&self, index: ClassId) -> &Self::Output { + &self.classes[index] + } +} + +impl Index for Definitions { + type Output = Assignment; + + fn index(&self, index: AssignmentId) -> &Self::Output { + &self.assignments[index] + } +} + +impl Index for Definitions { + type Output = AnnotatedAssignment; + + fn index(&self, index: AnnotatedAssignmentId) -> &Self::Output { + &self.annotated_assignments[index] + } +} + +impl Index for Definitions { + type Output = TypeAlias; + + fn index(&self, index: TypeAliasId) -> &Self::Output { + &self.type_aliases[index] + } +} + +impl Index for Definitions { + type Output = Global; + + fn index(&self, index: GlobalId) -> &Self::Output { + &self.globals[index] + } +} + +impl Index for Definitions { + type Output = NonLocal; + + fn index(&self, index: NonLocalId) -> &Self::Output { + &self.non_locals[index] + } +} + +impl Index for Definitions { + type Output = TypeParameter; + + fn index(&self, index: TypeParameterId) -> &Self::Output { + &self.type_parameters[index] + } +} + +struct DefinitionsVisitor<'a> { + definitions: Definitions, + ast_ids: &'a AstIds, + file_id: FileId, +} + +impl DefinitionsVisitor<'_> { + fn ast_id(&self, node: &N) -> HirAstId { + HirAstId { + file_id: self.file_id, + node_id: self.ast_ids.ast_id(node), + } + } + + fn lower_function_def(&mut self, function: &StmtFunctionDef) -> FunctionId { + let name = Name::new(&function.name); + + let first_type_parameter_id = self.definitions.type_parameters.next_index(); + let mut last_type_parameter_id = first_type_parameter_id; + + if let Some(type_params) = &function.type_params { + for parameter in &type_params.type_params { + let id = self.lower_type_parameter(parameter); + last_type_parameter_id = id; + } + } + + let parameters = self.lower_parameters(&function.parameters); + + self.definitions.functions.push(Function { + name, + ast_id: self.ast_id(function), + parameters, + type_parameters: first_type_parameter_id..last_type_parameter_id, + }) + } + + fn lower_parameters(&mut self, parameters: &ruff_python_ast::Parameters) -> Range { + let first_parameter_id = self.definitions.parameters.next_index(); + let mut last_parameter_id = first_parameter_id; + + for parameter in ¶meters.posonlyargs { + last_parameter_id = self.definitions.parameters.push(Parameter { + kind: ParameterKind::PositionalOnly, + name: Name::new(¶meter.parameter.name), + default: None, + ast_id: self.ast_id(¶meter.parameter), + }); + } + + if let Some(vararg) = ¶meters.vararg { + last_parameter_id = self.definitions.parameters.push(Parameter { + kind: ParameterKind::Vararg, + 
name: Name::new(&vararg.name), + default: None, + ast_id: self.ast_id(vararg), + }); + } + + for parameter in ¶meters.kwonlyargs { + last_parameter_id = self.definitions.parameters.push(Parameter { + kind: ParameterKind::KeywordOnly, + name: Name::new(¶meter.parameter.name), + default: None, + ast_id: self.ast_id(¶meter.parameter), + }); + } + + if let Some(kwarg) = ¶meters.kwarg { + last_parameter_id = self.definitions.parameters.push(Parameter { + kind: ParameterKind::KeywordOnly, + name: Name::new(&kwarg.name), + default: None, + ast_id: self.ast_id(kwarg), + }); + } + + first_parameter_id..last_parameter_id + } + + fn lower_class_def(&mut self, class: &StmtClassDef) -> ClassId { + let name = Name::new(&class.name); + + self.definitions.classes.push(Class { + name, + ast_id: self.ast_id(class), + }) + } + + fn lower_assignment(&mut self, assignment: &StmtAssign) { + // FIXME handle multiple names + if let Some(Expr::Name(name)) = assignment.targets.first() { + self.definitions.assignments.push(Assignment { + name: Name::new(&name.id), + ast_id: self.ast_id(assignment), + }); + } + } + + fn lower_annotated_assignment(&mut self, annotated_assignment: &StmtAnnAssign) { + if let Expr::Name(name) = &*annotated_assignment.target { + self.definitions + .annotated_assignments + .push(AnnotatedAssignment { + name: Name::new(&name.id), + ast_id: self.ast_id(annotated_assignment), + }); + } + } + + fn lower_type_alias(&mut self, type_alias: &StmtTypeAlias) { + if let Expr::Name(name) = &*type_alias.name { + let name = Name::new(&name.id); + + let lower_parameters_id = self.definitions.type_parameters.next_index(); + let mut last_parameter_id = lower_parameters_id; + + if let Some(type_params) = &type_alias.type_params { + for type_parameter in &type_params.type_params { + let id = self.lower_type_parameter(type_parameter); + last_parameter_id = id; + } + } + + self.definitions.type_aliases.push(TypeAlias { + name, + ast_id: self.ast_id(type_alias), + parameters: lower_parameters_id..last_parameter_id, + }); + } + } + + fn lower_type_parameter(&mut self, type_parameter: &TypeParam) -> TypeParameterId { + match type_parameter { + TypeParam::TypeVar(type_var) => { + self.definitions + .type_parameters + .push(TypeParameter::TypeVar(TypeParameterTypeVar { + name: Name::new(&type_var.name), + ast_id: self.ast_id(type_var), + })) + } + TypeParam::ParamSpec(param_spec) => { + self.definitions + .type_parameters + .push(TypeParameter::ParamSpec(TypeParameterParamSpec { + name: Name::new(¶m_spec.name), + ast_id: self.ast_id(param_spec), + })) + } + TypeParam::TypeVarTuple(type_var_tuple) => { + self.definitions + .type_parameters + .push(TypeParameter::TypeVarTuple(TypeParameterTypeVarTuple { + name: Name::new(&type_var_tuple.name), + ast_id: self.ast_id(type_var_tuple), + })) + } + } + } + + fn lower_import(&mut self, _import: &StmtImport) { + // TODO + } + + fn lower_import_from(&mut self, _import_from: &StmtImportFrom) { + // TODO + } + + fn lower_global(&mut self, global: &StmtGlobal) -> GlobalId { + self.definitions.globals.push(Global { + ast_id: self.ast_id(global), + }) + } + + fn lower_non_local(&mut self, non_local: &StmtNonlocal) -> NonLocalId { + self.definitions.non_locals.push(NonLocal { + ast_id: self.ast_id(non_local), + }) + } + + fn lower_except_handler(&mut self, _except_handler: &ExceptHandlerExceptHandler) { + // TODO + } + + fn lower_with_item(&mut self, _with_item: &WithItem) { + // TODO + } + + fn lower_match_case(&mut self, _match_case: &MatchCase) { + // TODO + } +} + +impl 
PreorderVisitor<'_> for DefinitionsVisitor<'_> { + fn visit_stmt(&mut self, stmt: &Stmt) { + match stmt { + // Definition statements + Stmt::FunctionDef(definition) => { + self.lower_function_def(definition); + self.visit_body(&definition.body); + } + Stmt::ClassDef(definition) => { + self.lower_class_def(definition); + self.visit_body(&definition.body); + } + Stmt::Assign(assignment) => { + self.lower_assignment(assignment); + } + Stmt::AnnAssign(annotated_assignment) => { + self.lower_annotated_assignment(annotated_assignment); + } + Stmt::TypeAlias(type_alias) => { + self.lower_type_alias(type_alias); + } + + Stmt::Import(import) => self.lower_import(import), + Stmt::ImportFrom(import_from) => self.lower_import_from(import_from), + Stmt::Global(global) => { + self.lower_global(global); + } + Stmt::Nonlocal(non_local) => { + self.lower_non_local(non_local); + } + + // Visit the compound statement bodies because they can contain other definitions. + Stmt::For(_) + | Stmt::While(_) + | Stmt::If(_) + | Stmt::With(_) + | Stmt::Match(_) + | Stmt::Try(_) => { + preorder::walk_stmt(self, stmt); + } + + // Skip over simple statements because they can't contain any other definitions. + Stmt::Return(_) + | Stmt::Delete(_) + | Stmt::AugAssign(_) + | Stmt::Raise(_) + | Stmt::Assert(_) + | Stmt::Expr(_) + | Stmt::Pass(_) + | Stmt::Break(_) + | Stmt::Continue(_) + | Stmt::IpyEscapeCommand(_) => { + // No op + } + } + } + + fn visit_expr(&mut self, _: &'_ Expr) {} + + fn visit_decorator(&mut self, _decorator: &'_ Decorator) {} + + fn visit_except_handler(&mut self, except_handler: &'_ ExceptHandler) { + match except_handler { + ExceptHandler::ExceptHandler(except_handler) => { + self.lower_except_handler(except_handler); + } + } + } + + fn visit_with_item(&mut self, with_item: &'_ WithItem) { + self.lower_with_item(with_item); + } + + fn visit_match_case(&mut self, match_case: &'_ MatchCase) { + self.lower_match_case(match_case); + self.visit_body(&match_case.body); + } +} diff --git a/crates/red_knot/src/lib.rs b/crates/red_knot/src/lib.rs new file mode 100644 index 0000000000..995a957bed --- /dev/null +++ b/crates/red_knot/src/lib.rs @@ -0,0 +1,83 @@ +use std::hash::BuildHasherDefault; +use std::path::{Path, PathBuf}; + +use rustc_hash::{FxHashSet, FxHasher}; + +use crate::files::FileId; + +pub mod ast_ids; +pub mod cache; +pub mod cancellation; +pub mod db; +pub mod files; +pub mod hir; +pub mod lint; +pub mod module; +mod parse; +pub mod program; +pub mod source; +mod symbols; +mod types; +pub mod watch; + +pub(crate) type FxDashMap = dashmap::DashMap>; +#[allow(unused)] +pub(crate) type FxDashSet = dashmap::DashSet>; +pub(crate) type FxIndexSet = indexmap::set::IndexSet>; + +#[derive(Debug)] +pub struct Workspace { + /// TODO this should be a resolved path. We should probably use a newtype wrapper that guarantees that + /// PATH is a UTF-8 path and is normalized. + root: PathBuf, + /// The files that are open in the workspace. + /// + /// * Editor: The files that are actively being edited in the editor (the user has a tab open with the file). + /// * CLI: The resolved files passed as arguments to the CLI. + open_files: FxHashSet, +} + +impl Workspace { + pub fn new(root: PathBuf) -> Self { + Self { + root, + open_files: FxHashSet::default(), + } + } + + pub fn root(&self) -> &Path { + self.root.as_path() + } + + // TODO having the content in workspace feels wrong. 
+ pub fn open_file(&mut self, file_id: FileId) { + self.open_files.insert(file_id); + } + + pub fn close_file(&mut self, file_id: FileId) { + self.open_files.remove(&file_id); + } + + // TODO introduce an `OpenFile` type instead of using an anonymous tuple. + pub fn open_files(&self) -> impl Iterator + '_ { + self.open_files.iter().copied() + } + + pub fn is_file_open(&self, file_id: FileId) -> bool { + self.open_files.contains(&file_id) + } +} + +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +pub struct Name(smol_str::SmolStr); + +impl Name { + #[inline] + pub fn new(name: &str) -> Self { + Self(smol_str::SmolStr::new(name)) + } + + pub fn as_str(&self) -> &str { + self.0.as_str() + } +} diff --git a/crates/red_knot/src/lint.rs b/crates/red_knot/src/lint.rs new file mode 100644 index 0000000000..d42bb51a54 --- /dev/null +++ b/crates/red_knot/src/lint.rs @@ -0,0 +1,124 @@ +use std::ops::{Deref, DerefMut}; +use std::sync::Arc; + +use ruff_python_ast::visitor::Visitor; +use ruff_python_ast::StringLiteral; + +use crate::cache::KeyValueCache; +use crate::db::{HasJar, SourceDb, SourceJar}; +use crate::files::FileId; + +pub(crate) fn lint_syntax(db: &Db, file_id: FileId) -> Diagnostics +where + Db: SourceDb + HasJar, +{ + let storage = &db.jar().lint_syntax; + + storage.get(&file_id, |file_id| { + let mut diagnostics = Vec::new(); + + let source = db.source(*file_id); + lint_lines(source.text(), &mut diagnostics); + + let parsed = db.parse(*file_id); + + if parsed.errors().is_empty() { + let ast = parsed.ast(); + + let mut visitor = SyntaxLintVisitor { + diagnostics, + source: source.text(), + }; + visitor.visit_body(&ast.body); + diagnostics = visitor.diagnostics; + } else { + diagnostics.extend(parsed.errors().iter().map(std::string::ToString::to_string)); + } + + Diagnostics::from(diagnostics) + }) +} + +pub(crate) fn lint_lines(source: &str, diagnostics: &mut Vec) { + for (line_number, line) in source.lines().enumerate() { + if line.len() < 88 { + continue; + } + + let char_count = line.chars().count(); + if char_count > 88 { + diagnostics.push(format!( + "Line {} is too long ({} characters)", + line_number + 1, + char_count + )); + } + } +} + +#[derive(Debug)] +struct SyntaxLintVisitor<'a> { + diagnostics: Vec, + source: &'a str, +} + +impl Visitor<'_> for SyntaxLintVisitor<'_> { + fn visit_string_literal(&mut self, string_literal: &'_ StringLiteral) { + // A very naive implementation of use double quotes + let text = &self.source[string_literal.range]; + + if text.starts_with('\'') { + self.diagnostics + .push("Use double quotes for strings".to_string()); + } + } +} + +#[derive(Debug, Clone)] +pub enum Diagnostics { + Empty, + List(Arc>), +} + +impl Diagnostics { + pub fn as_slice(&self) -> &[String] { + match self { + Diagnostics::Empty => &[], + Diagnostics::List(list) => list.as_slice(), + } + } +} + +impl Deref for Diagnostics { + type Target = [String]; + fn deref(&self) -> &Self::Target { + self.as_slice() + } +} + +impl From> for Diagnostics { + fn from(value: Vec) -> Self { + if value.is_empty() { + Diagnostics::Empty + } else { + Diagnostics::List(Arc::new(value)) + } + } +} + +#[derive(Default, Debug)] +pub struct LintSyntaxStorage(KeyValueCache); + +impl Deref for LintSyntaxStorage { + type Target = KeyValueCache; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for LintSyntaxStorage { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} diff --git a/crates/red_knot/src/main.rs b/crates/red_knot/src/main.rs new file mode 100644 index 
0000000000..9bfb27bd0b --- /dev/null +++ b/crates/red_knot/src/main.rs @@ -0,0 +1,399 @@ +use std::collections::hash_map::Entry; +use std::num::NonZeroUsize; +use std::path::Path; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{Arc, Mutex}; + +use rustc_hash::FxHashMap; +use tracing::subscriber::Interest; +use tracing::{Level, Metadata}; +use tracing_subscriber::filter::LevelFilter; +use tracing_subscriber::layer::{Context, Filter, SubscriberExt}; +use tracing_subscriber::{Layer, Registry}; +use tracing_tree::time::Uptime; + +use red_knot::cancellation::CancellationSource; +use red_knot::db::{HasJar, SourceDb, SourceJar}; +use red_knot::files::FileId; +use red_knot::module::{ModuleSearchPath, ModuleSearchPathKind}; +use red_knot::program::{FileChange, FileChangeKind, Program}; +use red_knot::watch::FileWatcher; +use red_knot::{files, Workspace}; + +#[allow( + clippy::dbg_macro, + clippy::print_stdout, + clippy::unnecessary_wraps, + clippy::print_stderr +)] +fn main() -> anyhow::Result<()> { + setup_tracing(); + + let arguments: Vec<_> = std::env::args().collect(); + + if arguments.len() < 2 { + eprintln!("Usage: red_knot "); + return Err(anyhow::anyhow!("Invalid arguments")); + } + + let entry_point = Path::new(&arguments[1]); + + if !entry_point.exists() { + eprintln!("The entry point does not exist."); + return Err(anyhow::anyhow!("Invalid arguments")); + } + + if !entry_point.is_file() { + eprintln!("The entry point is not a file."); + return Err(anyhow::anyhow!("Invalid arguments")); + } + + let files = files::Files::default(); + let workspace_folder = entry_point.parent().unwrap(); + let mut workspace = Workspace::new(workspace_folder.to_path_buf()); + + let workspace_search_path = ModuleSearchPath::new( + workspace.root().to_path_buf(), + ModuleSearchPathKind::FirstParty, + ); + + let entry_id = files.intern(entry_point); + + let mut program = Program::new(vec![workspace_search_path], files.clone()); + + workspace.open_file(entry_id); + + let (sender, receiver) = crossbeam_channel::bounded( + std::thread::available_parallelism() + .map(NonZeroUsize::get) + .unwrap_or(50) + .max(4), // TODO: Both these numbers are very arbitrary. Pick sensible defaults. + ); + + // Listen to Ctrl+C and abort the watch mode. + let abort_sender = Mutex::new(Some(sender.clone())); + ctrlc::set_handler(move || { + let mut lock = abort_sender.lock().unwrap(); + + if let Some(sender) = lock.take() { + sender.send(Message::Exit).unwrap(); + } + })?; + + // Watch for file changes and re-trigger the analysis. + let file_changes_sender = sender.clone(); + + let mut file_watcher = FileWatcher::new( + move |changes| { + file_changes_sender + .send(Message::FileChanges(changes)) + .unwrap(); + }, + files.clone(), + )?; + + file_watcher.watch_folder(workspace_folder)?; + + let files_to_check = vec![entry_id]; + + // Main loop that runs until the user exits the program + // Runs the analysis for each changed file. Cancels the analysis if a new change is detected. + loop { + let changes = { + tracing::trace!("Main Loop: Tick"); + + // Token to cancel the analysis if a new change is detected. + let run_cancellation_token_source = CancellationSource::new(); + let run_cancellation_token = run_cancellation_token_source.token(); + + // Tracks the number of pending analysis runs. + let pending_analysis = Arc::new(AtomicUsize::new(0)); + + // Take read-only references that are copy and Send. 
+ let program = &program; + let workspace = &workspace; + + let receiver = receiver.clone(); + let started_analysis = pending_analysis.clone(); + + // Orchestration task. Ideally, we would run this on main but we should start it as soon as possible so that + // we avoid scheduling tasks when we already know that we're about to exit or cancel the analysis because of a file change. + // This uses `std::thread::spawn` because we don't want it to run inside of the thread pool + // or this code deadlocks when using a thread pool of the size 1. + let orchestration_handle = std::thread::spawn(move || { + fn consume_pending_messages( + receiver: &crossbeam_channel::Receiver, + mut aggregated_changes: AggregatedChanges, + ) -> NextTickCommand { + loop { + // Consume possibly incoming file change messages before running a new analysis, but don't wait for more than 100ms. + crossbeam_channel::select! { + recv(receiver) -> message => { + match message { + Ok(Message::Exit) => { + return NextTickCommand::Exit; + } + Ok(Message::FileChanges(file_changes)) => { + aggregated_changes.extend(file_changes); + } + + Ok(Message::AnalysisCancelled | Message::AnalysisCompleted(_)) => { + unreachable!( + "All analysis should have been completed at this time" + ); + }, + + Err(_) => { + // There are no more senders, no point in waiting for more messages + break; + } + } + }, + default(std::time::Duration::from_millis(100)) => { + break; + } + } + } + + NextTickCommand::FileChanges(aggregated_changes) + } + + let mut diagnostics = Vec::new(); + let mut aggregated_changes = AggregatedChanges::default(); + + for message in &receiver { + match message { + Message::AnalysisCompleted(file_diagnostics) => { + diagnostics.extend_from_slice(&file_diagnostics); + + if pending_analysis.fetch_sub(1, Ordering::SeqCst) == 1 { + // Analysis completed, print the diagnostics. + dbg!(&diagnostics); + } + } + + Message::AnalysisCancelled => { + if pending_analysis.fetch_sub(1, Ordering::SeqCst) == 1 { + return consume_pending_messages(&receiver, aggregated_changes); + } + } + + Message::Exit => { + run_cancellation_token_source.cancel(); + + // Don't consume any outstanding messages because we're exiting anyway. + return NextTickCommand::Exit; + } + + Message::FileChanges(changes) => { + // Request cancellation, but wait until all analysis tasks have completed to + // avoid stale messages in the next main loop. + run_cancellation_token_source.cancel(); + + aggregated_changes.extend(changes); + + if pending_analysis.load(Ordering::SeqCst) == 0 { + return consume_pending_messages(&receiver, aggregated_changes); + } + } + } + } + + // This can be reached if there's no Ctrl+C and no file watcher handler. + // In that case, assume that we don't run in watch mode and exit. + NextTickCommand::Exit + }); + + // Star the analysis task on the thread pool and wait until they complete. + rayon::scope(|scope| { + for file in &files_to_check { + let cancellation_token = run_cancellation_token.clone(); + if cancellation_token.is_cancelled() { + break; + } + + let sender = sender.clone(); + + started_analysis.fetch_add(1, Ordering::SeqCst); + + // TODO: How do we allow the host to control the number of threads used? + // Or should we just assume that each host implements its own main loop, + // I don't think that's entirely unreasonable but we should avoid + // having different main loops per host AND command (e.g. 
format vs check vs lint) + scope.spawn(move |_| { + if cancellation_token.is_cancelled() { + tracing::trace!("Exit analysis because cancellation was requested."); + sender.send(Message::AnalysisCancelled).unwrap(); + return; + } + + // TODO schedule the dependencies. + let mut diagnostics = Vec::new(); + + if workspace.is_file_open(*file) { + diagnostics.extend_from_slice(&program.lint_syntax(*file)); + } + + sender + .send(Message::AnalysisCompleted(diagnostics)) + .unwrap(); + }); + } + }); + + // Wait for the orchestration task to complete. This either returns the file changes + // or instructs the main loop to exit. + match orchestration_handle.join().unwrap() { + NextTickCommand::FileChanges(changes) => changes, + NextTickCommand::Exit => { + break; + } + } + }; + + // We have a mutable reference here and can perform all necessary invalidations. + program.apply_changes(changes.iter()); + } + + let source_jar: &SourceJar = program.jar(); + + dbg!(source_jar.parsed.statistics()); + dbg!(source_jar.sources.statistics()); + + Ok(()) +} + +enum Message { + AnalysisCompleted(Vec), + AnalysisCancelled, + Exit, + FileChanges(Vec), +} + +#[derive(Default, Debug)] +struct AggregatedChanges { + changes: FxHashMap, +} + +impl AggregatedChanges { + fn add(&mut self, change: FileChange) { + match self.changes.entry(change.file_id()) { + Entry::Occupied(mut entry) => { + let merged = entry.get_mut(); + + match (merged, change.kind()) { + (FileChangeKind::Created, FileChangeKind::Deleted) => { + // Deletion after creations means that ruff never saw the file. + entry.remove(); + } + (FileChangeKind::Created, FileChangeKind::Modified) => { + // No-op, for ruff, modifying a file that it doesn't yet know that it exists is still considered a creation. + } + + (FileChangeKind::Modified, FileChangeKind::Created) => { + // Uhh, that should probably not happen. Continue considering it a modification. + } + + (FileChangeKind::Modified, FileChangeKind::Deleted) => { + *entry.get_mut() = FileChangeKind::Deleted; + } + + (FileChangeKind::Deleted, FileChangeKind::Created) => { + *entry.get_mut() = FileChangeKind::Modified; + } + + (FileChangeKind::Deleted, FileChangeKind::Modified) => { + // That's weird, but let's consider it a modification. + *entry.get_mut() = FileChangeKind::Modified; + } + + (FileChangeKind::Created, FileChangeKind::Created) + | (FileChangeKind::Modified, FileChangeKind::Modified) + | (FileChangeKind::Deleted, FileChangeKind::Deleted) => { + // No-op transitions. Some of them should be impossible but we handle them anyway. + } + } + } + Entry::Vacant(entry) => { + entry.insert(change.kind()); + } + } + } + + fn extend(&mut self, changes: I) + where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { + let iter = changes.into_iter(); + self.changes.reserve(iter.len()); + + for change in iter { + self.add(change); + } + } + + fn iter(&self) -> impl Iterator + '_ { + self.changes + .iter() + .map(|(id, kind)| FileChange::new(*id, *kind)) + } +} + +enum NextTickCommand { + /// Exit the main loop in the next tick + Exit, + /// Apply the given changes in the next main loop tick. 
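// --- Editorial sketch (illustrative only, not part of this commit) ----------
// The merge rules in `AggregatedChanges::add` above form a small transition
// table over (previous kind, incoming kind). The non-obvious rows, restated as
// a self-contained function with stand-in types rather than red_knot's
// `FileChangeKind`:
fn _coalescing_sketch() {
    #[derive(Copy, Clone, Debug, PartialEq)]
    enum Kind {
        Created,
        Modified,
        Deleted,
    }

    // Returns the coalesced kind, or `None` when the two events cancel out.
    fn merge(previous: Kind, incoming: Kind) -> Option<Kind> {
        match (previous, incoming) {
            // Created then deleted: the analysis never needs to see the file.
            (Kind::Created, Kind::Deleted) => None,
            // Deleted then re-created or modified: observers may have seen the
            // old content, so treat the pair as a modification.
            (Kind::Deleted, Kind::Created | Kind::Modified) => Some(Kind::Modified),
            // Modifying a freshly created file is still just a creation.
            (Kind::Created, Kind::Modified) => Some(Kind::Created),
            // A later deletion wins over an earlier modification.
            (Kind::Modified, Kind::Deleted) => Some(Kind::Deleted),
            // Everything else keeps the previous kind.
            (previous, _) => Some(previous),
        }
    }

    assert_eq!(merge(Kind::Created, Kind::Deleted), None);
    assert_eq!(merge(Kind::Deleted, Kind::Created), Some(Kind::Modified));
    assert_eq!(merge(Kind::Modified, Kind::Deleted), Some(Kind::Deleted));
}
// ----------------------------------------------------------------------------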
+ FileChanges(AggregatedChanges), +} + +fn setup_tracing() { + let subscriber = Registry::default().with( + tracing_tree::HierarchicalLayer::default() + .with_indent_lines(true) + .with_indent_amount(2) + .with_bracketed_fields(true) + .with_targets(true) + .with_writer(|| Box::new(std::io::stderr())) + .with_timer(Uptime::default()) + .with_filter(LoggingFilter { + trace_level: Level::TRACE, + }), + ); + + tracing::subscriber::set_global_default(subscriber).unwrap(); +} + +struct LoggingFilter { + trace_level: Level, +} + +impl LoggingFilter { + fn is_enabled(&self, meta: &Metadata<'_>) -> bool { + let filter = if meta.target().starts_with("red_knot") || meta.target().starts_with("ruff") { + self.trace_level + } else { + Level::INFO + }; + + meta.level() <= &filter + } +} + +impl Filter for LoggingFilter { + fn enabled(&self, meta: &Metadata<'_>, _cx: &Context<'_, S>) -> bool { + self.is_enabled(meta) + } + + fn callsite_enabled(&self, meta: &'static Metadata<'static>) -> Interest { + if self.is_enabled(meta) { + Interest::always() + } else { + Interest::never() + } + } + + fn max_level_hint(&self) -> Option { + Some(LevelFilter::from_level(self.trace_level)) + } +} diff --git a/crates/red_knot/src/module.rs b/crates/red_knot/src/module.rs new file mode 100644 index 0000000000..5ce422f40e --- /dev/null +++ b/crates/red_knot/src/module.rs @@ -0,0 +1,911 @@ +use std::fmt::Formatter; +use std::path::{Path, PathBuf}; +use std::sync::atomic::AtomicU32; +use std::sync::Arc; + +use dashmap::mapref::entry::Entry; + +use crate::db::{HasJar, SemanticDb, SemanticJar}; +use crate::files::FileId; +use crate::FxDashMap; + +/// ID uniquely identifying a module. +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] +pub struct Module(u32); + +impl Module { + pub fn name(&self, db: &Db) -> ModuleName + where + Db: HasJar, + { + let modules = &db.jar().module_resolver; + + modules.modules.get(self).unwrap().name.clone() + } + + pub fn path(&self, db: &Db) -> ModulePath + where + Db: HasJar, + { + let modules = &db.jar().module_resolver; + + modules.modules.get(self).unwrap().path.clone() + } +} + +/// A module name, e.g. `foo.bar`. +/// +/// Always normalized to the absolute form (never a relative module name). +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +pub struct ModuleName(smol_str::SmolStr); + +impl ModuleName { + pub fn new(name: &str) -> Self { + debug_assert!(!name.is_empty()); + + Self(smol_str::SmolStr::new(name)) + } + + pub fn relative(_dots: u32, name: &str, _to: &Path) -> Self { + // FIXME: Take `to` and `dots` into account. + Self(smol_str::SmolStr::new(name)) + } + + pub fn from_relative_path(path: &Path) -> Option { + let path = if path.ends_with("__init__.py") || path.ends_with("__init__.pyi") { + path.parent()? + } else { + path + }; + + let name = if let Some(parent) = path.parent() { + let mut name = String::with_capacity(path.as_os_str().len()); + + for component in parent.components() { + name.push_str(component.as_os_str().to_str()?); + name.push('.'); + } + + // SAFETY: Unwrap is safe here or `parent` would have returned `None`. + name.push_str(path.file_stem().unwrap().to_str()?); + + smol_str::SmolStr::from(name) + } else { + smol_str::SmolStr::new(path.file_stem()?.to_str()?) 
+ }; + + Some(Self(name)) + } + + pub fn components(&self) -> impl DoubleEndedIterator { + self.0.split('.') + } + + pub fn parent(&self) -> Option { + let (_, parent) = self.0.rsplit_once('.')?; + + Some(Self(smol_str::SmolStr::new(parent))) + } + + pub fn starts_with(&self, other: &ModuleName) -> bool { + self.0.starts_with(other.0.as_str()) + } + + pub fn as_str(&self) -> &str { + &self.0 + } +} + +impl std::fmt::Display for ModuleName { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.0) + } +} + +/// A search path in which to search modules. +/// Corresponds to a path in [`sys.path`](https://docs.python.org/3/library/sys_path_init.html) at runtime. +/// +/// Cloning a search path is cheap because it's an `Arc`. +#[derive(Clone, PartialEq, Eq)] +pub struct ModuleSearchPath { + inner: Arc, +} + +impl ModuleSearchPath { + pub fn new(path: PathBuf, kind: ModuleSearchPathKind) -> Self { + Self { + inner: Arc::new(ModuleSearchPathInner { path, kind }), + } + } + + pub fn kind(&self) -> ModuleSearchPathKind { + self.inner.kind + } + + pub fn path(&self) -> &Path { + &self.inner.path + } +} + +impl std::fmt::Debug for ModuleSearchPath { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + self.inner.fmt(f) + } +} + +#[derive(Debug, Eq, PartialEq)] +struct ModuleSearchPathInner { + path: PathBuf, + kind: ModuleSearchPathKind, +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] +pub enum ModuleSearchPathKind { + // Project dependency + FirstParty, + + // e.g. site packages + ThirdParty, + + // e.g. built-in modules, typeshed + StandardLibrary, +} + +#[derive(Debug, Eq, PartialEq)] +pub struct ModuleData { + name: ModuleName, + path: ModulePath, +} + +////////////////////////////////////////////////////// +// Queries +////////////////////////////////////////////////////// + +/// Resolves a module name to a module id +/// TODO: This would not work with Salsa because `ModuleName` isn't an ingredient and, therefore, cannot be used as part of a query. +/// For this to work with salsa, it would be necessary to intern all `ModuleName`s. +#[tracing::instrument(level = "trace", skip(db))] +pub fn resolve_module(db: &Db, name: ModuleName) -> Option +where + Db: SemanticDb + HasJar, +{ + let jar = db.jar(); + let modules = &jar.module_resolver; + + let entry = modules.by_name.entry(name.clone()); + + match entry { + Entry::Occupied(entry) => Some(*entry.get()), + Entry::Vacant(entry) => { + let (root_path, absolute_path) = resolve_name(&name, &modules.search_paths)?; + let normalized = absolute_path.canonicalize().ok()?; + + let file_id = db.file_id(&normalized); + let path = ModulePath::new(root_path.clone(), file_id); + + let id = Module( + modules + .next_module_id + .fetch_add(1, std::sync::atomic::Ordering::Relaxed), + ); + + modules + .modules + .insert(id, Arc::from(ModuleData { name, path })); + + // A path can map to multiple modules because of symlinks: + // ``` + // foo.py + // bar.py -> foo.py + // ``` + // Here, both `foo` and `bar` resolve to the same module but through different paths. + // That's why we need to insert the absolute path and not the normalized path here. + modules.by_path.insert(absolute_path, id); + + entry.insert_entry(id); + + Some(id) + } + } +} + +////////////////////////////////////////////////////// +// Mutations +////////////////////////////////////////////////////// + +/// Changes the module search paths to `search_paths`. 
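// --- Editorial sketch (illustrative only, not part of this commit) ----------
// `ModuleName::from_relative_path` above maps a path (relative to a search
// root) to a dotted module name, with `__init__` files naming their parent
// package. Its expected behaviour, written as the kind of assertions one
// might put in a test:
fn _module_name_sketch() {
    use std::path::Path;

    assert_eq!(
        ModuleName::from_relative_path(Path::new("foo.py")),
        Some(ModuleName::new("foo"))
    );
    assert_eq!(
        ModuleName::from_relative_path(Path::new("foo/bar.py")),
        Some(ModuleName::new("foo.bar"))
    );
    // An `__init__.py` (or `.pyi`) names the package directory itself.
    assert_eq!(
        ModuleName::from_relative_path(Path::new("foo/bar/__init__.py")),
        Some(ModuleName::new("foo.bar"))
    );
}
// ----------------------------------------------------------------------------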
+pub fn set_module_search_paths(db: &mut Db, search_paths: Vec) +where + Db: SemanticDb + HasJar, +{ + let jar = db.jar_mut(); + + jar.module_resolver = ModuleResolver::new(search_paths); +} + +/// Resolves the module id for the file with the given id. +/// +/// Returns `None` if the file is not a module in `sys.path`. +pub fn file_to_module(db: &mut Db, file: FileId) -> Option +where + Db: SemanticDb + HasJar, +{ + let path = db.file_path(file); + path_to_module(db, &path) +} + +/// Resolves the module id for the given path. +/// +/// Returns `None` if the path is not a module in `sys.path`. +// WARNING!: It's important that this method takes `&mut self`. Without, the implementation is prone to race conditions. +// Note: This won't work with salsa because `Path` is not an ingredient. +pub fn path_to_module(db: &mut Db, path: &Path) -> Option +where + Db: SemanticDb + HasJar, +{ + let jar = db.jar_mut(); + let modules = &mut jar.module_resolver; + debug_assert!(path.is_absolute()); + + if let Some(existing) = modules.by_path.get(path) { + return Some(*existing); + } + + let root_path = modules + .search_paths + .iter() + .find(|root| path.starts_with(root.path()))? + .clone(); + + // SAFETY: `strip_prefix` is guaranteed to succeed because we search the root that is a prefix of the path. + let relative_path = path.strip_prefix(root_path.path()).unwrap(); + let module_name = ModuleName::from_relative_path(relative_path)?; + + // Resolve the module name to see if Python would resolve the name to the same path. + // If it doesn't, then that means that multiple modules have the same in different + // root paths, but that the module corresponding to the past path is in a lower priority path, + // in which case we ignore it. + let module_id = resolve_module(db, module_name)?; + // Note: Guaranteed to be race-free because we're holding a mutable reference of `self` here. + let module_path = module_id.path(db); + + if module_path.root() == &root_path { + let normalized = path.canonicalize().ok()?; + let interned_normalized = db.file_id(&normalized); + + if interned_normalized != module_path.file() { + // This path is for a module with the same name but with a different precedence. For example: + // ``` + // src/foo.py + // src/foo/__init__.py + // ``` + // The module name of `src/foo.py` is `foo`, but the module loaded by Python is `src/foo/__init__.py`. + // That means we need to ignore `src/foo.py` even though it resolves to the same module name. + return None; + } + + // Path has been inserted by `resolved` + Some(module_id) + } else { + // This path is for a module with the same name but in a module search path with a lower priority. + // Ignore it. + None + } +} + +/// Adds a module to the resolver. +/// +/// Returns `None` if the path doesn't resolve to a module. +/// +/// Returns `Some` with the id of the module and the ids of the modules that need re-resolving +/// because they were part of a namespace package and might now resolve differently. +/// Note: This won't work with salsa because `Path` is not an ingredient. +pub fn add_module(db: &mut Db, path: &Path) -> Option<(Module, Vec>)> +where + Db: SemanticDb + HasJar, +{ + // No locking is required because we're holding a mutable reference to `modules`. + + // TODO This needs tests + + // Note: Intentionally by-pass caching here. Module should not be in the cache yet. + let module = path_to_module(db, path)?; + + // The code below is to handle the addition of `__init__.py` files. 
+ // When an `__init__.py` file is added, we need to remove all modules that are part of the same package. + // For example, an `__init__.py` is added to `foo`, we need to remove `foo.bar`, `foo.baz`, etc. + // because they were namespace packages before and could have been from different search paths. + let Some(filename) = path.file_name() else { + return Some((module, Vec::new())); + }; + + if !matches!(filename.to_str(), Some("__init__.py" | "__init__.pyi")) { + return Some((module, Vec::new())); + } + + let Some(parent_name) = module.name(db).parent() else { + return Some((module, Vec::new())); + }; + + let mut to_remove = Vec::new(); + + let jar = db.jar_mut(); + let modules = &mut jar.module_resolver; + + modules.by_path.retain(|_, id| { + if modules + .modules + .get(id) + .unwrap() + .name + .starts_with(&parent_name) + { + to_remove.push(*id); + false + } else { + true + } + }); + + // TODO remove need for this vec + let mut removed = Vec::with_capacity(to_remove.len()); + for id in &to_remove { + removed.push(modules.remove_module_by_id(*id)); + } + + Some((module, removed)) +} + +#[derive(Default)] +pub struct ModuleResolver { + /// The search paths where modules are located (and searched). Corresponds to `sys.path` at runtime. + search_paths: Vec, + + // Locking: Locking is done by acquiring a (write) lock on `by_name`. This is because `by_name` is the primary + // lookup method. Acquiring locks in any other ordering can result in deadlocks. + /// Resolves a module name to it's module id. + by_name: FxDashMap, + + /// All known modules, indexed by the module id. + modules: FxDashMap>, + + /// Lookup from absolute path to module. + /// The same module might be reachable from different paths when symlinks are involved. + by_path: FxDashMap, + next_module_id: AtomicU32, +} + +impl ModuleResolver { + pub fn new(search_paths: Vec) -> Self { + Self { + search_paths, + modules: FxDashMap::default(), + by_name: FxDashMap::default(), + by_path: FxDashMap::default(), + next_module_id: AtomicU32::new(0), + } + } + + pub fn remove_module(&mut self, path: &Path) { + // No locking is required because we're holding a mutable reference to `self`. + let Some((_, id)) = self.by_path.remove(path) else { + return; + }; + + self.remove_module_by_id(id); + } + + fn remove_module_by_id(&mut self, id: Module) -> Arc { + let (_, module) = self.modules.remove(&id).unwrap(); + + self.by_name.remove(&module.name).unwrap(); + + // It's possible that multiple paths map to the same id. Search all other paths referencing the same module id. + self.by_path.retain(|_, current_id| *current_id != id); + + module + } +} + +#[allow(clippy::missing_fields_in_debug)] +impl std::fmt::Debug for ModuleResolver { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ModuleResolver") + .field("search_paths", &self.search_paths) + .field("modules", &self.by_name) + .finish() + } +} + +/// The resolved path of a module. +/// +/// It should be highly likely that the file still exists when accessing but it isn't 100% guaranteed +/// because the file could have been deleted between resolving the module name and accessing it. 
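// --- Editorial sketch (illustrative only, not part of this commit) ----------
// The queries and mutations above (`resolve_module`, `path_to_module`,
// `add_module`, ...) are free functions that are generic over a database type
// which can hand out a specific "jar" of storage. A self-contained miniature
// of that pattern; all types and bodies inside this sketch are stand-ins, the
// real traits live in `db.rs`:
fn _jar_pattern_sketch() {
    use std::collections::HashMap;

    struct SourceStorage(HashMap<u32, String>);

    struct SourceJar {
        sources: SourceStorage,
    }

    // A database can be asked for a reference to a particular jar.
    trait HasJar<T> {
        fn jar(&self) -> &T;
    }

    struct Database {
        source: SourceJar,
    }

    impl HasJar<SourceJar> for Database {
        fn jar(&self) -> &SourceJar {
            &self.source
        }
    }

    // A query only states which jar it needs, so it works with any database
    // that carries one.
    fn source_len<Db: HasJar<SourceJar>>(db: &Db, file_id: u32) -> usize {
        db.jar().sources.0.get(&file_id).map_or(0, String::len)
    }

    let mut sources = HashMap::new();
    sources.insert(0, "print('hello')".to_string());
    let db = Database {
        source: SourceJar {
            sources: SourceStorage(sources),
        },
    };
    assert_eq!(source_len(&db, 0), 14);
}
// ----------------------------------------------------------------------------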
+#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ModulePath { + root: ModuleSearchPath, + file_id: FileId, +} + +impl ModulePath { + pub fn new(root: ModuleSearchPath, file_id: FileId) -> Self { + Self { root, file_id } + } + + pub fn root(&self) -> &ModuleSearchPath { + &self.root + } + + pub fn file(&self) -> FileId { + self.file_id + } +} + +fn resolve_name( + name: &ModuleName, + search_paths: &[ModuleSearchPath], +) -> Option<(ModuleSearchPath, PathBuf)> { + for search_path in search_paths { + let mut components = name.components(); + let module_name = components.next_back()?; + + match resolve_package(search_path, components) { + Ok(resolved_package) => { + let mut package_path = resolved_package.path; + + package_path.push(module_name); + + // Must be a `__init__.pyi` or `__init__.py` or it isn't a package. + if package_path.is_dir() { + package_path.push("__init__"); + } + + // TODO Implement full https://peps.python.org/pep-0561/#type-checker-module-resolution-order resolution + let stub = package_path.with_extension("pyi"); + + if stub.is_file() { + return Some((search_path.clone(), stub)); + } + + let module = package_path.with_extension("py"); + + if module.is_file() { + return Some((search_path.clone(), module)); + } + + // For regular packages, don't search the next search path. All files of that + // package must be in the same location + if resolved_package.kind.is_regular_package() { + return None; + } + } + Err(parent_kind) => { + if parent_kind.is_regular_package() { + // For regular packages, don't search the next search path. + return None; + } + } + } + } + + None +} + +fn resolve_package<'a, I>( + module_search_path: &ModuleSearchPath, + components: I, +) -> Result +where + I: Iterator, +{ + let mut package_path = module_search_path.path().to_path_buf(); + + // `true` if inside a folder that is a namespace package (has no `__init__.py`). + // Namespace packages are special because they can be spread across multiple search paths. + // https://peps.python.org/pep-0420/ + let mut in_namespace_package = false; + + // `true` if resolving a sub-package. For example, `true` when resolving `bar` of `foo.bar`. + let mut in_sub_package = false; + + // For `foo.bar.baz`, test that `foo` and `baz` both contain a `__init__.py`. + for folder in components { + package_path.push(folder); + + let has_init_py = package_path.join("__init__.py").is_file() + || package_path.join("__init__.pyi").is_file(); + + if has_init_py { + in_namespace_package = false; + } else if package_path.is_dir() { + // A directory without an `__init__.py` is a namespace package, continue with the next folder. + in_namespace_package = true; + } else if in_namespace_package { + // Package not found but it is part of a namespace package. + return Err(PackageKind::Namespace); + } else if in_sub_package { + // A regular sub package wasn't found. + return Err(PackageKind::Regular); + } else { + // We couldn't find `foo` for `foo.bar.baz`, search the next search path. + return Err(PackageKind::Root); + } + + in_sub_package = true; + } + + let kind = if in_namespace_package { + PackageKind::Namespace + } else if in_sub_package { + PackageKind::Regular + } else { + PackageKind::Root + }; + + Ok(ResolvedPackage { + kind, + path: package_path, + }) +} + +#[derive(Debug)] +struct ResolvedPackage { + path: PathBuf, + kind: PackageKind, +} + +#[derive(Copy, Clone, Eq, PartialEq, Debug)] +enum PackageKind { + /// A root package or module. E.g. `foo` in `foo.bar.baz` or just `foo`. 
+ Root, + + /// A regular sub-package where the parent contains an `__init__.py`. For example `bar` in `foo.bar` when the `foo` directory contains an `__init__.py`. + Regular, + + /// A sub-package in a namespace package. A namespace package is a package without an `__init__.py`. + /// + /// For example, `bar` in `foo.bar` if the `foo` directory contains no `__init__.py`. + Namespace, +} + +impl PackageKind { + const fn is_regular_package(self) -> bool { + matches!(self, PackageKind::Regular) + } +} + +#[cfg(test)] +mod tests { + use crate::db::tests::TestDb; + use crate::db::{SemanticDb, SourceDb}; + use crate::module::{ModuleName, ModuleSearchPath, ModuleSearchPathKind}; + + struct TestCase { + temp_dir: tempfile::TempDir, + db: TestDb, + + src: ModuleSearchPath, + site_packages: ModuleSearchPath, + } + + fn create_resolver() -> std::io::Result { + let temp_dir = tempfile::tempdir()?; + + let src = temp_dir.path().join("src"); + let site_packages = temp_dir.path().join("site_packages"); + + std::fs::create_dir(&src)?; + std::fs::create_dir(&site_packages)?; + + let src = ModuleSearchPath::new(src.canonicalize()?, ModuleSearchPathKind::FirstParty); + let site_packages = ModuleSearchPath::new( + site_packages.canonicalize()?, + ModuleSearchPathKind::ThirdParty, + ); + + let roots = vec![src.clone(), site_packages.clone()]; + + let mut db = TestDb::default(); + db.set_module_search_paths(roots); + + Ok(TestCase { + temp_dir, + db, + src, + site_packages, + }) + } + + #[test] + fn first_party_module() -> std::io::Result<()> { + let TestCase { + mut db, + src, + temp_dir: _temp_dir, + .. + } = create_resolver()?; + + let foo_path = src.path().join("foo.py"); + std::fs::write(&foo_path, "print('Hello, world!')")?; + + let foo_module = db.resolve_module(ModuleName::new("foo")).unwrap(); + + assert_eq!(Some(foo_module), db.resolve_module(ModuleName::new("foo"))); + + assert_eq!(ModuleName::new("foo"), foo_module.name(&db)); + assert_eq!(&src, foo_module.path(&db).root()); + assert_eq!(&foo_path, &*db.file_path(foo_module.path(&db).file())); + + assert_eq!(Some(foo_module), db.path_to_module(&foo_path)); + + Ok(()) + } + + #[test] + fn resolve_package() -> std::io::Result<()> { + let TestCase { + src, + mut db, + temp_dir: _temp_dir, + .. + } = create_resolver()?; + + let foo_dir = src.path().join("foo"); + let foo_path = foo_dir.join("__init__.py"); + std::fs::create_dir(&foo_dir)?; + std::fs::write(&foo_path, "print('Hello, world!')")?; + + let foo_module = db.resolve_module(ModuleName::new("foo")).unwrap(); + + assert_eq!(ModuleName::new("foo"), foo_module.name(&db)); + assert_eq!(&src, foo_module.path(&db).root()); + assert_eq!(&foo_path, &*db.file_path(foo_module.path(&db).file())); + + assert_eq!(Some(foo_module), db.path_to_module(&foo_path)); + + // Resolving by directory doesn't resolve to the init file. + assert_eq!(None, db.path_to_module(&foo_dir)); + + Ok(()) + } + + #[test] + fn package_priority_over_module() -> std::io::Result<()> { + let TestCase { + mut db, + temp_dir: _temp_dir, + src, + .. 
+ } = create_resolver()?; + + let foo_dir = src.path().join("foo"); + let foo_init = foo_dir.join("__init__.py"); + std::fs::create_dir(&foo_dir)?; + std::fs::write(&foo_init, "print('Hello, world!')")?; + + let foo_py = src.path().join("foo.py"); + std::fs::write(&foo_py, "print('Hello, world!')")?; + + let foo_module = db.resolve_module(ModuleName::new("foo")).unwrap(); + + assert_eq!(&src, foo_module.path(&db).root()); + assert_eq!(&foo_init, &*db.file_path(foo_module.path(&db).file())); + + assert_eq!(Some(foo_module), db.path_to_module(&foo_init)); + assert_eq!(None, db.path_to_module(&foo_py)); + + Ok(()) + } + + #[test] + fn typing_stub_over_module() -> std::io::Result<()> { + let TestCase { + mut db, + src, + temp_dir: _temp_dir, + .. + } = create_resolver()?; + + let foo_stub = src.path().join("foo.pyi"); + let foo_py = src.path().join("foo.py"); + std::fs::write(&foo_stub, "x: int")?; + std::fs::write(&foo_py, "print('Hello, world!')")?; + + let foo = db.resolve_module(ModuleName::new("foo")).unwrap(); + + assert_eq!(&src, foo.path(&db).root()); + assert_eq!(&foo_stub, &*db.file_path(foo.path(&db).file())); + + assert_eq!(Some(foo), db.path_to_module(&foo_stub)); + assert_eq!(None, db.path_to_module(&foo_py)); + + Ok(()) + } + + #[test] + fn sub_packages() -> std::io::Result<()> { + let TestCase { + mut db, + src, + temp_dir: _temp_dir, + .. + } = create_resolver()?; + + let foo = src.path().join("foo"); + let bar = foo.join("bar"); + let baz = bar.join("baz.py"); + + std::fs::create_dir_all(&bar)?; + std::fs::write(foo.join("__init__.py"), "")?; + std::fs::write(bar.join("__init__.py"), "")?; + std::fs::write(&baz, "print('Hello, world!')")?; + + let baz_module = db.resolve_module(ModuleName::new("foo.bar.baz")).unwrap(); + + assert_eq!(&src, baz_module.path(&db).root()); + assert_eq!(&baz, &*db.file_path(baz_module.path(&db).file())); + + assert_eq!(Some(baz_module), db.path_to_module(&baz)); + + Ok(()) + } + + #[test] + fn namespace_package() -> std::io::Result<()> { + let TestCase { + mut db, + temp_dir: _, + src, + site_packages, + } = create_resolver()?; + + // From [PEP420](https://peps.python.org/pep-0420/#nested-namespace-packages). + // But uses `src` for `project1` and `site_packages2` for `project2`. + // ``` + // src + // parent + // child + // one.py + // site_packages + // parent + // child + // two.py + // ``` + + let parent1 = src.path().join("parent"); + let child1 = parent1.join("child"); + let one = child1.join("one.py"); + + std::fs::create_dir_all(child1)?; + std::fs::write(&one, "print('Hello, world!')")?; + + let parent2 = site_packages.path().join("parent"); + let child2 = parent2.join("child"); + let two = child2.join("two.py"); + + std::fs::create_dir_all(&child2)?; + std::fs::write(&two, "print('Hello, world!')")?; + + let one_module = db + .resolve_module(ModuleName::new("parent.child.one")) + .unwrap(); + + assert_eq!(Some(one_module), db.path_to_module(&one)); + + let two_module = db + .resolve_module(ModuleName::new("parent.child.two")) + .unwrap(); + assert_eq!(Some(two_module), db.path_to_module(&two)); + + Ok(()) + } + + #[test] + fn regular_package_in_namespace_package() -> std::io::Result<()> { + let TestCase { + mut db, + temp_dir: _, + src, + site_packages, + } = create_resolver()?; + + // Adopted test case from the [PEP420 examples](https://peps.python.org/pep-0420/#nested-namespace-packages). + // The `src/parent/child` package is a regular package. Therefore, `site_packages/parent/child/two.py` should not be resolved. 
+ // ``` + // src + // parent + // child + // one.py + // site_packages + // parent + // child + // two.py + // ``` + + let parent1 = src.path().join("parent"); + let child1 = parent1.join("child"); + let one = child1.join("one.py"); + + std::fs::create_dir_all(&child1)?; + std::fs::write(child1.join("__init__.py"), "print('Hello, world!')")?; + std::fs::write(&one, "print('Hello, world!')")?; + + let parent2 = site_packages.path().join("parent"); + let child2 = parent2.join("child"); + let two = child2.join("two.py"); + + std::fs::create_dir_all(&child2)?; + std::fs::write(two, "print('Hello, world!')")?; + + let one_module = db + .resolve_module(ModuleName::new("parent.child.one")) + .unwrap(); + + assert_eq!(Some(one_module), db.path_to_module(&one)); + + assert_eq!(None, db.resolve_module(ModuleName::new("parent.child.two"))); + Ok(()) + } + + #[test] + fn module_search_path_priority() -> std::io::Result<()> { + let TestCase { + mut db, + src, + site_packages, + temp_dir: _temp_dir, + } = create_resolver()?; + + let foo_src = src.path().join("foo.py"); + let foo_site_packages = site_packages.path().join("foo.py"); + + std::fs::write(&foo_src, "")?; + std::fs::write(&foo_site_packages, "")?; + + let foo_module = db.resolve_module(ModuleName::new("foo")).unwrap(); + + assert_eq!(&src, foo_module.path(&db).root()); + assert_eq!(&foo_src, &*db.file_path(foo_module.path(&db).file())); + + assert_eq!(Some(foo_module), db.path_to_module(&foo_src)); + assert_eq!(None, db.path_to_module(&foo_site_packages)); + + Ok(()) + } + + #[test] + #[cfg(target_family = "unix")] + fn symlink() -> std::io::Result<()> { + let TestCase { + mut db, + src, + temp_dir: _temp_dir, + .. + } = create_resolver()?; + + let foo = src.path().join("foo.py"); + let bar = src.path().join("bar.py"); + + std::fs::write(&foo, "")?; + std::os::unix::fs::symlink(&foo, &bar)?; + + let foo_module = db.resolve_module(ModuleName::new("foo")).unwrap(); + let bar_module = db.resolve_module(ModuleName::new("bar")).unwrap(); + + assert_ne!(foo_module, bar_module); + + assert_eq!(&src, foo_module.path(&db).root()); + assert_eq!(&foo, &*db.file_path(foo_module.path(&db).file())); + + // Bar has a different name but it should point to the same file. 
+ + assert_eq!(&src, bar_module.path(&db).root()); + assert_eq!(foo_module.path(&db).file(), bar_module.path(&db).file()); + assert_eq!(&foo, &*db.file_path(bar_module.path(&db).file())); + + assert_eq!(Some(foo_module), db.path_to_module(&foo)); + assert_eq!(Some(bar_module), db.path_to_module(&bar)); + + Ok(()) + } +} diff --git a/crates/red_knot/src/parse.rs b/crates/red_knot/src/parse.rs new file mode 100644 index 0000000000..641181fb93 --- /dev/null +++ b/crates/red_knot/src/parse.rs @@ -0,0 +1,95 @@ +use std::ops::{Deref, DerefMut}; +use std::sync::Arc; + +use ruff_python_ast as ast; +use ruff_python_parser::{Mode, ParseError}; +use ruff_text_size::{Ranged, TextRange}; + +use crate::cache::KeyValueCache; +use crate::db::{HasJar, SourceDb, SourceJar}; +use crate::files::FileId; + +#[derive(Debug, Clone, PartialEq)] +pub struct Parsed { + inner: Arc, +} + +#[derive(Debug, PartialEq)] +struct ParsedInner { + ast: ast::ModModule, + errors: Vec, +} + +impl Parsed { + fn new(ast: ast::ModModule, errors: Vec) -> Self { + Self { + inner: Arc::new(ParsedInner { ast, errors }), + } + } + + pub(crate) fn from_text(text: &str) -> Self { + let result = ruff_python_parser::parse(text, Mode::Module); + + let (module, errors) = match result { + Ok(ast::Mod::Module(module)) => (module, vec![]), + Ok(ast::Mod::Expression(expression)) => ( + ast::ModModule { + range: expression.range(), + body: vec![ast::Stmt::Expr(ast::StmtExpr { + range: expression.range(), + value: expression.body, + })], + }, + vec![], + ), + Err(errors) => ( + ast::ModModule { + range: TextRange::default(), + body: Vec::new(), + }, + vec![errors], + ), + }; + + Parsed::new(module, errors) + } + + pub fn ast(&self) -> &ast::ModModule { + &self.inner.ast + } + + pub fn errors(&self) -> &[ParseError] { + &self.inner.errors + } +} + +#[tracing::instrument(level = "trace", skip(db))] +pub(crate) fn parse(db: &Db, file_id: FileId) -> Parsed +where + Db: SourceDb + HasJar, +{ + let parsed = db.jar(); + + parsed.parsed.get(&file_id, |file_id| { + let source = db.source(*file_id); + + Parsed::from_text(source.text()) + }) +} + +#[derive(Debug, Default)] +pub struct ParsedStorage(KeyValueCache); + +impl Deref for ParsedStorage { + type Target = KeyValueCache; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for ParsedStorage { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} diff --git a/crates/red_knot/src/program/mod.rs b/crates/red_knot/src/program/mod.rs new file mode 100644 index 0000000000..e9055938dc --- /dev/null +++ b/crates/red_knot/src/program/mod.rs @@ -0,0 +1,154 @@ +use std::path::Path; +use std::sync::Arc; + +use crate::db::{Db, HasJar, SemanticDb, SemanticJar, SourceDb, SourceJar}; +use crate::files::{FileId, Files}; +use crate::lint::{lint_syntax, Diagnostics, LintSyntaxStorage}; +use crate::module::{ + add_module, path_to_module, resolve_module, set_module_search_paths, Module, ModuleData, + ModuleName, ModuleResolver, ModuleSearchPath, +}; +use crate::parse::{parse, Parsed, ParsedStorage}; +use crate::source::{source_text, Source, SourceStorage}; +use crate::symbols::{symbol_table, SymbolId, SymbolTable, SymbolTablesStorage}; +use crate::types::{infer_symbol_type, Type, TypeStore}; + +#[derive(Debug)] +pub struct Program { + files: Files, + source: SourceJar, + semantic: SemanticJar, +} + +impl Program { + pub fn new(module_search_paths: Vec, files: Files) -> Self { + Self { + source: SourceJar { + sources: SourceStorage::default(), + parsed: ParsedStorage::default(), + 
lint_syntax: LintSyntaxStorage::default(), + }, + semantic: SemanticJar { + module_resolver: ModuleResolver::new(module_search_paths), + symbol_tables: SymbolTablesStorage::default(), + type_store: TypeStore::default(), + }, + files, + } + } + + pub fn apply_changes(&mut self, changes: I) + where + I: IntoIterator, + { + for change in changes { + self.semantic + .module_resolver + .remove_module(&self.file_path(change.id)); + self.semantic.symbol_tables.remove(&change.id); + self.source.sources.remove(&change.id); + self.source.parsed.remove(&change.id); + self.source.lint_syntax.remove(&change.id); + // TODO: remove all dependent modules as well + self.semantic.type_store.remove_module(change.id); + } + } +} + +impl SourceDb for Program { + fn file_id(&self, path: &Path) -> FileId { + self.files.intern(path) + } + + fn file_path(&self, file_id: FileId) -> Arc { + self.files.path(file_id) + } + + fn source(&self, file_id: FileId) -> Source { + source_text(self, file_id) + } + + fn parse(&self, file_id: FileId) -> Parsed { + parse(self, file_id) + } + + fn lint_syntax(&self, file_id: FileId) -> Diagnostics { + lint_syntax(self, file_id) + } +} + +impl SemanticDb for Program { + fn resolve_module(&self, name: ModuleName) -> Option { + resolve_module(self, name) + } + + fn symbol_table(&self, file_id: FileId) -> Arc { + symbol_table(self, file_id) + } + + // Mutations + fn path_to_module(&mut self, path: &Path) -> Option { + path_to_module(self, path) + } + + fn add_module(&mut self, path: &Path) -> Option<(Module, Vec>)> { + add_module(self, path) + } + + fn set_module_search_paths(&mut self, paths: Vec) { + set_module_search_paths(self, paths); + } + + fn infer_symbol_type(&mut self, file_id: FileId, symbol_id: SymbolId) -> Type { + infer_symbol_type(self, file_id, symbol_id) + } +} + +impl Db for Program {} + +impl HasJar for Program { + fn jar(&self) -> &SourceJar { + &self.source + } + + fn jar_mut(&mut self) -> &mut SourceJar { + &mut self.source + } +} + +impl HasJar for Program { + fn jar(&self) -> &SemanticJar { + &self.semantic + } + + fn jar_mut(&mut self) -> &mut SemanticJar { + &mut self.semantic + } +} + +#[derive(Copy, Clone, Debug)] +pub struct FileChange { + id: FileId, + kind: FileChangeKind, +} + +impl FileChange { + pub fn new(file_id: FileId, kind: FileChangeKind) -> Self { + Self { id: file_id, kind } + } + + pub fn file_id(&self) -> FileId { + self.id + } + + pub fn kind(&self) -> FileChangeKind { + self.kind + } +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub enum FileChangeKind { + Created, + Modified, + Deleted, +} diff --git a/crates/red_knot/src/source.rs b/crates/red_knot/src/source.rs new file mode 100644 index 0000000000..7dd6ed9285 --- /dev/null +++ b/crates/red_knot/src/source.rs @@ -0,0 +1,98 @@ +use crate::cache::KeyValueCache; +use crate::db::{HasJar, SourceDb, SourceJar}; +use ruff_notebook::Notebook; +use ruff_python_ast::PySourceType; +use std::ops::{Deref, DerefMut}; +use std::sync::Arc; + +use crate::files::FileId; + +#[tracing::instrument(level = "trace", skip(db))] +pub(crate) fn source_text(db: &Db, file_id: FileId) -> Source +where + Db: SourceDb + HasJar, +{ + let sources = &db.jar().sources; + + sources.get(&file_id, |file_id| { + tracing::trace!("Reading source text for file_id={:?}.", file_id); + + let path = db.file_path(*file_id); + + let source_text = std::fs::read_to_string(&path).unwrap_or_else(|err| { + tracing::error!("Failed to read file '{path:?}: {err}'. 
Falling back to empty text"); + String::new() + }); + + let python_ty = PySourceType::from(&path); + + let kind = match python_ty { + PySourceType::Python => { + SourceKind::Python(Arc::from(source_text)) + } + PySourceType::Stub => SourceKind::Stub(Arc::from(source_text)), + PySourceType::Ipynb => { + let notebook = Notebook::from_source_code(&source_text).unwrap_or_else(|err| { + // TODO should this be changed to never fail? + // or should we instead add a diagnostic somewhere? But what would we return in this case? + tracing::error!( + "Failed to parse notebook '{path:?}: {err}'. Falling back to an empty notebook" + ); + Notebook::from_source_code("").unwrap() + }); + + SourceKind::IpyNotebook(Arc::new(notebook)) + } + }; + + Source { kind } + }) +} + +#[derive(Debug, Clone, PartialEq)] +pub enum SourceKind { + Python(Arc), + Stub(Arc), + IpyNotebook(Arc), +} + +#[derive(Debug, Clone, PartialEq)] +pub struct Source { + kind: SourceKind, +} + +impl Source { + pub fn python>>(source: T) -> Self { + Self { + kind: SourceKind::Python(source.into()), + } + } + pub fn kind(&self) -> &SourceKind { + &self.kind + } + + pub fn text(&self) -> &str { + match &self.kind { + SourceKind::Python(text) => text, + SourceKind::Stub(text) => text, + SourceKind::IpyNotebook(notebook) => notebook.source_code(), + } + } +} + +#[derive(Debug, Default)] +pub struct SourceStorage(pub(crate) KeyValueCache); + +impl Deref for SourceStorage { + type Target = KeyValueCache; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for SourceStorage { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} diff --git a/crates/red_knot/src/symbols.rs b/crates/red_knot/src/symbols.rs new file mode 100644 index 0000000000..96fa6b120f --- /dev/null +++ b/crates/red_knot/src/symbols.rs @@ -0,0 +1,765 @@ +#![allow(dead_code)] + +use std::hash::{Hash, Hasher}; +use std::iter::{Copied, DoubleEndedIterator, FusedIterator}; +use std::ops::{Deref, DerefMut}; +use std::sync::Arc; + +use hashbrown::hash_map::{Keys, RawEntryMut}; +use rustc_hash::{FxHashMap, FxHasher}; + +use ruff_index::{newtype_index, IndexVec}; +use ruff_python_ast as ast; +use ruff_python_ast::visitor::preorder::PreorderVisitor; + +use crate::ast_ids::TypedNodeKey; +use crate::cache::KeyValueCache; +use crate::db::{HasJar, SemanticDb, SemanticJar}; +use crate::files::FileId; +use crate::Name; + +#[allow(unreachable_pub)] +#[tracing::instrument(level = "trace", skip(db))] +pub fn symbol_table(db: &Db, file_id: FileId) -> Arc +where + Db: SemanticDb + HasJar, +{ + let jar = db.jar(); + + jar.symbol_tables.get(&file_id, |_| { + let parsed = db.parse(file_id); + Arc::from(SymbolTable::from_ast(parsed.ast())) + }) +} + +type Map = hashbrown::HashMap; + +#[newtype_index] +pub(crate) struct ScopeId; + +impl ScopeId { + pub(crate) fn scope(self, table: &SymbolTable) -> &Scope { + &table.scopes_by_id[self] + } +} + +#[newtype_index] +pub struct SymbolId; + +impl SymbolId { + pub(crate) fn symbol(self, table: &SymbolTable) -> &Symbol { + &table.symbols_by_id[self] + } +} + +#[derive(Copy, Clone, Debug, PartialEq)] +pub(crate) enum ScopeKind { + Module, + Annotation, + Class, + Function, +} + +#[derive(Debug)] +pub(crate) struct Scope { + name: Name, + kind: ScopeKind, + child_scopes: Vec, + // symbol IDs, hashed by symbol name + symbols_by_name: Map, +} + +impl Scope { + pub(crate) fn name(&self) -> &str { + self.name.as_str() + } + + pub(crate) fn kind(&self) -> ScopeKind { + self.kind + } +} + +#[derive(Debug)] +pub(crate) struct Symbol { 
+ name: Name, +} + +impl Symbol { + pub(crate) fn name(&self) -> &str { + self.name.as_str() + } +} + +// TODO storing TypedNodeKey for definitions means we have to search to find them again in the AST; +// this is at best O(log n). If looking up definitions is a bottleneck we should look for +// alternatives here. +#[derive(Debug)] +pub(crate) enum Definition { + // For the import cases, we don't need reference to any arbitrary AST subtrees (annotations, + // RHS), and referencing just the import statement node is imprecise (a single import statement + // can assign many symbols, we'd have to re-search for the one we care about), so we just copy + // the small amount of information we need from the AST. + Import(ImportDefinition), + ImportFrom(ImportFromDefinition), + ClassDef(TypedNodeKey), + FunctionDef(TypedNodeKey), + Assignment(TypedNodeKey), + AnnotatedAssignment(TypedNodeKey), + // TODO with statements, except handlers, function args... +} + +#[derive(Debug)] +pub(crate) struct ImportDefinition { + pub(crate) module: String, +} + +#[derive(Debug)] +pub(crate) struct ImportFromDefinition { + pub(crate) module: Option, + pub(crate) name: String, + pub(crate) level: u32, +} + +/// Table of all symbols in all scopes for a module. +#[derive(Debug)] +pub struct SymbolTable { + scopes_by_id: IndexVec, + symbols_by_id: IndexVec, + defs: FxHashMap>, +} + +impl SymbolTable { + pub(crate) fn from_ast(module: &ast::ModModule) -> Self { + let root_scope_id = SymbolTable::root_scope_id(); + let mut builder = SymbolTableBuilder { + table: SymbolTable::new(), + scopes: vec![root_scope_id], + }; + builder.visit_body(&module.body); + builder.table + } + + pub(crate) fn new() -> Self { + let mut table = SymbolTable { + scopes_by_id: IndexVec::new(), + symbols_by_id: IndexVec::new(), + defs: FxHashMap::default(), + }; + table.scopes_by_id.push(Scope { + name: Name::new(""), + kind: ScopeKind::Module, + child_scopes: Vec::new(), + symbols_by_name: Map::default(), + }); + table + } + + pub(crate) const fn root_scope_id() -> ScopeId { + ScopeId::from_usize(0) + } + + pub(crate) fn root_scope(&self) -> &Scope { + &self.scopes_by_id[SymbolTable::root_scope_id()] + } + + pub(crate) fn symbol_ids_for_scope(&self, scope_id: ScopeId) -> Copied> { + self.scopes_by_id[scope_id].symbols_by_name.keys().copied() + } + + pub(crate) fn symbols_for_scope( + &self, + scope_id: ScopeId, + ) -> SymbolIterator>> { + SymbolIterator { + table: self, + ids: self.symbol_ids_for_scope(scope_id), + } + } + + pub(crate) fn root_symbol_ids(&self) -> Copied> { + self.symbol_ids_for_scope(SymbolTable::root_scope_id()) + } + + pub(crate) fn root_symbols(&self) -> SymbolIterator>> { + self.symbols_for_scope(SymbolTable::root_scope_id()) + } + + pub(crate) fn child_scope_ids_of(&self, scope_id: ScopeId) -> &[ScopeId] { + &self.scopes_by_id[scope_id].child_scopes + } + + pub(crate) fn child_scopes_of(&self, scope_id: ScopeId) -> ScopeIterator<&[ScopeId]> { + ScopeIterator { + table: self, + ids: self.child_scope_ids_of(scope_id), + } + } + + pub(crate) fn root_child_scope_ids(&self) -> &[ScopeId] { + self.child_scope_ids_of(SymbolTable::root_scope_id()) + } + + pub(crate) fn root_child_scopes(&self) -> ScopeIterator<&[ScopeId]> { + self.child_scopes_of(SymbolTable::root_scope_id()) + } + + pub(crate) fn symbol_id_by_name(&self, scope_id: ScopeId, name: &str) -> Option { + let scope = &self.scopes_by_id[scope_id]; + let hash = SymbolTable::hash_name(name); + let name = Name::new(name); + scope + .symbols_by_name + .raw_entry() + 
.from_hash(hash, |symid| self.symbols_by_id[*symid].name == name) + .map(|(symbol_id, ())| *symbol_id) + } + + pub(crate) fn symbol_by_name(&self, scope_id: ScopeId, name: &str) -> Option<&Symbol> { + Some(&self.symbols_by_id[self.symbol_id_by_name(scope_id, name)?]) + } + + pub(crate) fn root_symbol_id_by_name(&self, name: &str) -> Option { + self.symbol_id_by_name(SymbolTable::root_scope_id(), name) + } + + pub(crate) fn root_symbol_by_name(&self, name: &str) -> Option<&Symbol> { + self.symbol_by_name(SymbolTable::root_scope_id(), name) + } + + pub(crate) fn defs(&self, symbol_id: SymbolId) -> &[Definition] { + self.defs + .get(&symbol_id) + .map(std::vec::Vec::as_slice) + .unwrap_or_default() + } + + fn add_symbol_to_scope(&mut self, scope_id: ScopeId, name: &str) -> SymbolId { + let hash = SymbolTable::hash_name(name); + let scope = &mut self.scopes_by_id[scope_id]; + let name = Name::new(name); + + let entry = scope + .symbols_by_name + .raw_entry_mut() + .from_hash(hash, |existing| self.symbols_by_id[*existing].name == name); + + match entry { + RawEntryMut::Occupied(entry) => *entry.key(), + RawEntryMut::Vacant(entry) => { + let id = self.symbols_by_id.push(Symbol { name }); + entry.insert_with_hasher(hash, id, (), |_| hash); + id + } + } + } + + fn add_child_scope( + &mut self, + parent_scope_id: ScopeId, + name: &str, + kind: ScopeKind, + ) -> ScopeId { + let new_scope_id = self.scopes_by_id.push(Scope { + name: Name::new(name), + kind, + child_scopes: Vec::new(), + symbols_by_name: Map::default(), + }); + let parent_scope = &mut self.scopes_by_id[parent_scope_id]; + parent_scope.child_scopes.push(new_scope_id); + new_scope_id + } + + fn hash_name(name: &str) -> u64 { + let mut hasher = FxHasher::default(); + name.hash(&mut hasher); + hasher.finish() + } +} + +pub(crate) struct SymbolIterator<'a, I> { + table: &'a SymbolTable, + ids: I, +} + +impl<'a, I> Iterator for SymbolIterator<'a, I> +where + I: Iterator, +{ + type Item = &'a Symbol; + + fn next(&mut self) -> Option { + let id = self.ids.next()?; + Some(&self.table.symbols_by_id[id]) + } + + fn size_hint(&self) -> (usize, Option) { + self.ids.size_hint() + } +} + +impl<'a, I> FusedIterator for SymbolIterator<'a, I> where + I: Iterator + FusedIterator +{ +} + +impl<'a, I> DoubleEndedIterator for SymbolIterator<'a, I> +where + I: Iterator + DoubleEndedIterator, +{ + fn next_back(&mut self) -> Option { + let id = self.ids.next_back()?; + Some(&self.table.symbols_by_id[id]) + } +} + +pub(crate) struct ScopeIterator<'a, I> { + table: &'a SymbolTable, + ids: I, +} + +impl<'a, I> Iterator for ScopeIterator<'a, I> +where + I: Iterator, +{ + type Item = &'a Scope; + + fn next(&mut self) -> Option { + let id = self.ids.next()?; + Some(&self.table.scopes_by_id[id]) + } + + fn size_hint(&self) -> (usize, Option) { + self.ids.size_hint() + } +} + +impl<'a, I> FusedIterator for ScopeIterator<'a, I> where I: Iterator + FusedIterator {} + +impl<'a, I> DoubleEndedIterator for ScopeIterator<'a, I> +where + I: Iterator + DoubleEndedIterator, +{ + fn next_back(&mut self) -> Option { + let id = self.ids.next_back()?; + Some(&self.table.scopes_by_id[id]) + } +} + +struct SymbolTableBuilder { + table: SymbolTable, + scopes: Vec, +} + +impl SymbolTableBuilder { + fn add_symbol(&mut self, identifier: &str) -> SymbolId { + self.table.add_symbol_to_scope(self.cur_scope(), identifier) + } + + fn add_symbol_with_def(&mut self, identifier: &str, definition: Definition) -> SymbolId { + let symbol_id = self.add_symbol(identifier); + self.table + .defs + 
.entry(symbol_id) + .or_default() + .push(definition); + symbol_id + } + + fn push_scope(&mut self, child_of: ScopeId, name: &str, kind: ScopeKind) -> ScopeId { + let scope_id = self.table.add_child_scope(child_of, name, kind); + self.scopes.push(scope_id); + scope_id + } + + fn pop_scope(&mut self) -> ScopeId { + self.scopes + .pop() + .expect("Scope stack should never be empty") + } + + fn cur_scope(&self) -> ScopeId { + *self + .scopes + .last() + .expect("Scope stack should never be empty") + } + + fn with_type_params( + &mut self, + name: &str, + params: &Option>, + nested: impl FnOnce(&mut Self), + ) { + if let Some(type_params) = params { + self.push_scope(self.cur_scope(), name, ScopeKind::Annotation); + for type_param in &type_params.type_params { + let name = match type_param { + ast::TypeParam::TypeVar(ast::TypeParamTypeVar { name, .. }) => name, + ast::TypeParam::ParamSpec(ast::TypeParamParamSpec { name, .. }) => name, + ast::TypeParam::TypeVarTuple(ast::TypeParamTypeVarTuple { name, .. }) => name, + }; + self.add_symbol(name); + } + } + nested(self); + if params.is_some() { + self.pop_scope(); + } + } +} + +impl PreorderVisitor<'_> for SymbolTableBuilder { + fn visit_expr(&mut self, expr: &ast::Expr) { + if let ast::Expr::Name(ast::ExprName { id, .. }) = expr { + self.add_symbol(id); + } + ast::visitor::preorder::walk_expr(self, expr); + } + + fn visit_stmt(&mut self, stmt: &ast::Stmt) { + // TODO need to capture more definition statements here + match stmt { + ast::Stmt::ClassDef(node) => { + let def = Definition::ClassDef(TypedNodeKey::from_node(node)); + self.add_symbol_with_def(&node.name, def); + self.with_type_params(&node.name, &node.type_params, |builder| { + builder.push_scope(builder.cur_scope(), &node.name, ScopeKind::Class); + ast::visitor::preorder::walk_stmt(builder, stmt); + builder.pop_scope(); + }); + } + ast::Stmt::FunctionDef(node) => { + let def = Definition::FunctionDef(TypedNodeKey::from_node(node)); + self.add_symbol_with_def(&node.name, def); + self.with_type_params(&node.name, &node.type_params, |builder| { + builder.push_scope(builder.cur_scope(), &node.name, ScopeKind::Function); + ast::visitor::preorder::walk_stmt(builder, stmt); + builder.pop_scope(); + }); + } + ast::Stmt::Import(ast::StmtImport { names, .. }) => { + for alias in names { + let symbol_name = if let Some(asname) = &alias.asname { + asname.id.as_str() + } else { + alias.name.id.split('.').next().unwrap() + }; + let def = Definition::Import(ImportDefinition { + module: alias.name.id.clone(), + }); + self.add_symbol_with_def(symbol_name, def); + } + } + ast::Stmt::ImportFrom(ast::StmtImportFrom { + module, + names, + level, + .. 
+ }) => { + for alias in names { + let symbol_name = if let Some(asname) = &alias.asname { + asname.id.as_str() + } else { + alias.name.id.as_str() + }; + let def = Definition::ImportFrom(ImportFromDefinition { + module: module.as_ref().map(|m| m.id.clone()), + name: alias.name.id.clone(), + level: *level, + }); + self.add_symbol_with_def(symbol_name, def); + } + } + _ => { + ast::visitor::preorder::walk_stmt(self, stmt); + } + } + } +} + +#[derive(Debug, Default)] +pub struct SymbolTablesStorage(KeyValueCache>); + +impl Deref for SymbolTablesStorage { + type Target = KeyValueCache>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for SymbolTablesStorage { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +#[cfg(test)] +mod tests { + use textwrap::dedent; + + use crate::parse::Parsed; + use crate::symbols::ScopeKind; + + use super::{SymbolId, SymbolIterator, SymbolTable}; + + mod from_ast { + use super::*; + + fn parse(code: &str) -> Parsed { + Parsed::from_text(&dedent(code)) + } + + fn names(it: SymbolIterator) -> Vec<&str> + where + I: Iterator, + { + let mut symbols: Vec<_> = it.map(|sym| sym.name.as_str()).collect(); + symbols.sort_unstable(); + symbols + } + + #[test] + fn empty() { + let parsed = parse(""); + let table = SymbolTable::from_ast(parsed.ast()); + assert_eq!(names(table.root_symbols()).len(), 0); + } + + #[test] + fn simple() { + let parsed = parse("x"); + let table = SymbolTable::from_ast(parsed.ast()); + assert_eq!(names(table.root_symbols()), vec!["x"]); + assert_eq!( + table.defs(table.root_symbol_id_by_name("x").unwrap()).len(), + 0 + ); + } + + #[test] + fn annotation_only() { + let parsed = parse("x: int"); + let table = SymbolTable::from_ast(parsed.ast()); + assert_eq!(names(table.root_symbols()), vec!["int", "x"]); + // TODO record definition + } + + #[test] + fn import() { + let parsed = parse("import foo"); + let table = SymbolTable::from_ast(parsed.ast()); + assert_eq!(names(table.root_symbols()), vec!["foo"]); + assert_eq!( + table + .defs(table.root_symbol_id_by_name("foo").unwrap()) + .len(), + 1 + ); + } + + #[test] + fn import_sub() { + let parsed = parse("import foo.bar"); + let table = SymbolTable::from_ast(parsed.ast()); + assert_eq!(names(table.root_symbols()), vec!["foo"]); + } + + #[test] + fn import_as() { + let parsed = parse("import foo.bar as baz"); + let table = SymbolTable::from_ast(parsed.ast()); + assert_eq!(names(table.root_symbols()), vec!["baz"]); + } + + #[test] + fn import_from() { + let parsed = parse("from bar import foo"); + let table = SymbolTable::from_ast(parsed.ast()); + assert_eq!(names(table.root_symbols()), vec!["foo"]); + assert_eq!( + table + .defs(table.root_symbol_id_by_name("foo").unwrap()) + .len(), + 1 + ); + } + + #[test] + fn class_scope() { + let parsed = parse( + " + class C: + x = 1 + y = 2 + ", + ); + let table = SymbolTable::from_ast(parsed.ast()); + assert_eq!(names(table.root_symbols()), vec!["C", "y"]); + let scopes = table.root_child_scope_ids(); + assert_eq!(scopes.len(), 1); + let c_scope = scopes[0].scope(&table); + assert_eq!(c_scope.kind(), ScopeKind::Class); + assert_eq!(c_scope.name(), "C"); + assert_eq!(names(table.symbols_for_scope(scopes[0])), vec!["x"]); + assert_eq!( + table.defs(table.root_symbol_id_by_name("C").unwrap()).len(), + 1 + ); + } + + #[test] + fn func_scope() { + let parsed = parse( + " + def func(): + x = 1 + y = 2 + ", + ); + let table = SymbolTable::from_ast(parsed.ast()); + assert_eq!(names(table.root_symbols()), vec!["func", "y"]); 
+ let scopes = table.root_child_scope_ids(); + assert_eq!(scopes.len(), 1); + let func_scope = scopes[0].scope(&table); + assert_eq!(func_scope.kind(), ScopeKind::Function); + assert_eq!(func_scope.name(), "func"); + assert_eq!(names(table.symbols_for_scope(scopes[0])), vec!["x"]); + assert_eq!( + table + .defs(table.root_symbol_id_by_name("func").unwrap()) + .len(), + 1 + ); + } + + #[test] + fn dupes() { + let parsed = parse( + " + def func(): + x = 1 + def func(): + y = 2 + ", + ); + let table = SymbolTable::from_ast(parsed.ast()); + assert_eq!(names(table.root_symbols()), vec!["func"]); + let scopes = table.root_child_scope_ids(); + assert_eq!(scopes.len(), 2); + let func_scope_1 = scopes[0].scope(&table); + let func_scope_2 = scopes[1].scope(&table); + assert_eq!(func_scope_1.kind(), ScopeKind::Function); + assert_eq!(func_scope_1.name(), "func"); + assert_eq!(func_scope_2.kind(), ScopeKind::Function); + assert_eq!(func_scope_2.name(), "func"); + assert_eq!(names(table.symbols_for_scope(scopes[0])), vec!["x"]); + assert_eq!(names(table.symbols_for_scope(scopes[1])), vec!["y"]); + assert_eq!( + table + .defs(table.root_symbol_id_by_name("func").unwrap()) + .len(), + 2 + ); + } + + #[test] + fn generic_func() { + let parsed = parse( + " + def func[T](): + x = 1 + ", + ); + let table = SymbolTable::from_ast(parsed.ast()); + assert_eq!(names(table.root_symbols()), vec!["func"]); + let scopes = table.root_child_scope_ids(); + assert_eq!(scopes.len(), 1); + let ann_scope_id = scopes[0]; + let ann_scope = ann_scope_id.scope(&table); + assert_eq!(ann_scope.kind(), ScopeKind::Annotation); + assert_eq!(ann_scope.name(), "func"); + assert_eq!(names(table.symbols_for_scope(ann_scope_id)), vec!["T"]); + let scopes = table.child_scope_ids_of(ann_scope_id); + assert_eq!(scopes.len(), 1); + let func_scope_id = scopes[0]; + let func_scope = func_scope_id.scope(&table); + assert_eq!(func_scope.kind(), ScopeKind::Function); + assert_eq!(func_scope.name(), "func"); + assert_eq!(names(table.symbols_for_scope(func_scope_id)), vec!["x"]); + } + + #[test] + fn generic_class() { + let parsed = parse( + " + class C[T]: + x = 1 + ", + ); + let table = SymbolTable::from_ast(parsed.ast()); + assert_eq!(names(table.root_symbols()), vec!["C"]); + let scopes = table.root_child_scope_ids(); + assert_eq!(scopes.len(), 1); + let ann_scope_id = scopes[0]; + let ann_scope = ann_scope_id.scope(&table); + assert_eq!(ann_scope.kind(), ScopeKind::Annotation); + assert_eq!(ann_scope.name(), "C"); + assert_eq!(names(table.symbols_for_scope(ann_scope_id)), vec!["T"]); + let scopes = table.child_scope_ids_of(ann_scope_id); + assert_eq!(scopes.len(), 1); + let func_scope_id = scopes[0]; + let func_scope = func_scope_id.scope(&table); + assert_eq!(func_scope.kind(), ScopeKind::Class); + assert_eq!(func_scope.name(), "C"); + assert_eq!(names(table.symbols_for_scope(func_scope_id)), vec!["x"]); + } + } + + #[test] + fn insert_same_name_symbol_twice() { + let mut table = SymbolTable::new(); + let root_scope_id = SymbolTable::root_scope_id(); + let symbol_id_1 = table.add_symbol_to_scope(root_scope_id, "foo"); + let symbol_id_2 = table.add_symbol_to_scope(root_scope_id, "foo"); + assert_eq!(symbol_id_1, symbol_id_2); + } + + #[test] + fn insert_different_named_symbols() { + let mut table = SymbolTable::new(); + let root_scope_id = SymbolTable::root_scope_id(); + let symbol_id_1 = table.add_symbol_to_scope(root_scope_id, "foo"); + let symbol_id_2 = table.add_symbol_to_scope(root_scope_id, "bar"); + assert_ne!(symbol_id_1, symbol_id_2); 
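// --- Editorial sketch (illustrative only, not part of this commit) ----------
// The interning behaviour exercised by these `insert_*` tests comes from the
// `symbols_by_name` map: it stores only `SymbolId` keys (with `()` values),
// hashed by the symbol's *name*, and `add_symbol_to_scope` probes it with
// `raw_entry_mut().from_hash`, comparing candidates through `symbols_by_id`.
// Each name string is therefore stored once, in its `Symbol`, and lookups by
// name land on the same id:
fn _interning_sketch() {
    let mut table = SymbolTable::new();
    let scope = SymbolTable::root_scope_id();

    let first = table.add_symbol_to_scope(scope, "x");
    let second = table.add_symbol_to_scope(scope, "x");
    assert_eq!(first, second);

    // Lookups hash the queried name and compare against the interned symbol.
    assert_eq!(table.symbol_id_by_name(scope, "x"), Some(first));
}
// ----------------------------------------------------------------------------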
+ } + + #[test] + fn add_child_scope_with_symbol() { + let mut table = SymbolTable::new(); + let root_scope_id = SymbolTable::root_scope_id(); + let foo_symbol_top = table.add_symbol_to_scope(root_scope_id, "foo"); + let c_scope = table.add_child_scope(root_scope_id, "C", ScopeKind::Class); + let foo_symbol_inner = table.add_symbol_to_scope(c_scope, "foo"); + assert_ne!(foo_symbol_top, foo_symbol_inner); + } + + #[test] + fn scope_from_id() { + let table = SymbolTable::new(); + let root_scope_id = SymbolTable::root_scope_id(); + let scope = root_scope_id.scope(&table); + assert_eq!(scope.name.as_str(), ""); + assert_eq!(scope.kind, ScopeKind::Module); + } + + #[test] + fn symbol_from_id() { + let mut table = SymbolTable::new(); + let root_scope_id = SymbolTable::root_scope_id(); + let foo_symbol_id = table.add_symbol_to_scope(root_scope_id, "foo"); + let symbol = foo_symbol_id.symbol(&table); + assert_eq!(symbol.name.as_str(), "foo"); + } +} diff --git a/crates/red_knot/src/types.rs b/crates/red_knot/src/types.rs new file mode 100644 index 0000000000..7c084445fc --- /dev/null +++ b/crates/red_knot/src/types.rs @@ -0,0 +1,519 @@ +#![allow(dead_code)] +use crate::ast_ids::NodeKey; +use crate::files::FileId; +use crate::symbols::SymbolId; +use crate::{FxDashMap, FxIndexSet, Name}; +use ruff_index::{newtype_index, IndexVec}; +use rustc_hash::FxHashMap; + +pub(crate) mod infer; + +pub(crate) use infer::infer_symbol_type; + +/// unique ID for a type +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub enum Type { + /// the dynamic or gradual type: a statically-unknown set of values + Any, + /// the empty set of values + Never, + /// unknown type (no annotation) + /// equivalent to Any, or to object in strict mode + Unknown, + /// name is not bound to any value + Unbound, + /// a specific function + Function(FunctionTypeId), + /// the set of Python objects with a given class in their __class__'s method resolution order + Class(ClassTypeId), + Union(UnionTypeId), + Intersection(IntersectionTypeId), + // TODO protocols, callable types, overloads, generics, type vars +} + +impl Type { + fn display<'a>(&'a self, store: &'a TypeStore) -> DisplayType<'a> { + DisplayType { ty: self, store } + } +} + +// TODO: currently calling `get_function` et al and holding on to the `FunctionTypeRef` will lock a +// shard of this dashmap, for as long as you hold the reference. This may be a problem. We could +// switch to having all the arenas hold Arc, or we could see if we can split up ModuleTypeStore, +// and/or give it inner mutability and finer-grained internal locking. +#[derive(Debug, Default)] +pub struct TypeStore { + modules: FxDashMap, +} + +impl TypeStore { + pub fn remove_module(&mut self, file_id: FileId) { + self.modules.remove(&file_id); + } + + pub fn cache_symbol_type(&mut self, file_id: FileId, symbol_id: SymbolId, ty: Type) { + self.add_or_get_module(file_id) + .symbol_types + .insert(symbol_id, ty); + } + + pub fn cache_node_type(&mut self, file_id: FileId, node_key: NodeKey, ty: Type) { + self.add_or_get_module(file_id) + .node_types + .insert(node_key, ty); + } + + pub fn get_cached_symbol_type(&self, file_id: FileId, symbol_id: SymbolId) -> Option { + self.try_get_module(file_id)? + .symbol_types + .get(&symbol_id) + .copied() + } + + pub fn get_cached_node_type(&self, file_id: FileId, node_key: &NodeKey) -> Option { + self.try_get_module(file_id)? 
+ .node_types + .get(node_key) + .copied() + } + + fn add_or_get_module(&mut self, file_id: FileId) -> ModuleStoreRefMut { + self.modules + .entry(file_id) + .or_insert_with(|| ModuleTypeStore::new(file_id)) + } + + fn get_module(&self, file_id: FileId) -> ModuleStoreRef { + self.try_get_module(file_id).expect("module should exist") + } + + fn try_get_module(&self, file_id: FileId) -> Option { + self.modules.get(&file_id) + } + + fn add_function(&mut self, file_id: FileId, name: &str) -> FunctionTypeId { + self.add_or_get_module(file_id).add_function(name) + } + + fn add_class(&mut self, file_id: FileId, name: &str) -> ClassTypeId { + self.add_or_get_module(file_id).add_class(name) + } + + fn add_union(&mut self, file_id: FileId, elems: &[Type]) -> UnionTypeId { + self.add_or_get_module(file_id).add_union(elems) + } + + fn add_intersection( + &mut self, + file_id: FileId, + positive: &[Type], + negative: &[Type], + ) -> IntersectionTypeId { + self.add_or_get_module(file_id) + .add_intersection(positive, negative) + } + + fn get_function(&self, id: FunctionTypeId) -> FunctionTypeRef { + FunctionTypeRef { + module_store: self.get_module(id.file_id), + function_id: id.func_id, + } + } + + fn get_class(&self, id: ClassTypeId) -> ClassTypeRef { + ClassTypeRef { + module_store: self.get_module(id.file_id), + class_id: id.class_id, + } + } + + fn get_union(&self, id: UnionTypeId) -> UnionTypeRef { + UnionTypeRef { + module_store: self.get_module(id.file_id), + union_id: id.union_id, + } + } + + fn get_intersection(&self, id: IntersectionTypeId) -> IntersectionTypeRef { + IntersectionTypeRef { + module_store: self.get_module(id.file_id), + intersection_id: id.intersection_id, + } + } +} + +type ModuleStoreRef<'a> = dashmap::mapref::one::Ref< + 'a, + FileId, + ModuleTypeStore, + std::hash::BuildHasherDefault, +>; + +type ModuleStoreRefMut<'a> = dashmap::mapref::one::RefMut< + 'a, + FileId, + ModuleTypeStore, + std::hash::BuildHasherDefault, +>; + +#[derive(Debug)] +pub(crate) struct FunctionTypeRef<'a> { + module_store: ModuleStoreRef<'a>, + function_id: ModuleFunctionTypeId, +} + +impl<'a> std::ops::Deref for FunctionTypeRef<'a> { + type Target = FunctionType; + + fn deref(&self) -> &Self::Target { + self.module_store.get_function(self.function_id) + } +} + +#[derive(Debug)] +pub(crate) struct ClassTypeRef<'a> { + module_store: ModuleStoreRef<'a>, + class_id: ModuleClassTypeId, +} + +impl<'a> std::ops::Deref for ClassTypeRef<'a> { + type Target = ClassType; + + fn deref(&self) -> &Self::Target { + self.module_store.get_class(self.class_id) + } +} + +#[derive(Debug)] +pub(crate) struct UnionTypeRef<'a> { + module_store: ModuleStoreRef<'a>, + union_id: ModuleUnionTypeId, +} + +impl<'a> std::ops::Deref for UnionTypeRef<'a> { + type Target = UnionType; + + fn deref(&self) -> &Self::Target { + self.module_store.get_union(self.union_id) + } +} + +#[derive(Debug)] +pub(crate) struct IntersectionTypeRef<'a> { + module_store: ModuleStoreRef<'a>, + intersection_id: ModuleIntersectionTypeId, +} + +impl<'a> std::ops::Deref for IntersectionTypeRef<'a> { + type Target = IntersectionType; + + fn deref(&self) -> &Self::Target { + self.module_store.get_intersection(self.intersection_id) + } +} + +#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)] +pub struct FunctionTypeId { + file_id: FileId, + func_id: ModuleFunctionTypeId, +} + +#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)] +pub struct ClassTypeId { + file_id: FileId, + class_id: ModuleClassTypeId, +} + +#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)] 
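// Like FunctionTypeId and ClassTypeId above, the two IDs below pair the owning
// module's FileId with a module-local arena index, so resolving one is a two-step
// lookup; roughly, in terms of the private helpers defined further down:
//
//     let module = store.get_module(id.file_id);  // read-locks one dashmap shard
//     let union = module.get_union(id.union_id);  // plain IndexVec arena access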
+pub struct UnionTypeId { + file_id: FileId, + union_id: ModuleUnionTypeId, +} + +#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)] +pub struct IntersectionTypeId { + file_id: FileId, + intersection_id: ModuleIntersectionTypeId, +} + +#[newtype_index] +struct ModuleFunctionTypeId; + +#[newtype_index] +struct ModuleClassTypeId; + +#[newtype_index] +struct ModuleUnionTypeId; + +#[newtype_index] +struct ModuleIntersectionTypeId; + +#[derive(Debug)] +struct ModuleTypeStore { + file_id: FileId, + /// arena of all function types defined in this module + functions: IndexVec, + /// arena of all class types defined in this module + classes: IndexVec, + /// arenda of all union types created in this module + unions: IndexVec, + /// arena of all intersection types created in this module + intersections: IndexVec, + /// cached types of symbols in this module + symbol_types: FxHashMap, + /// cached types of AST nodes in this module + node_types: FxHashMap, +} + +impl ModuleTypeStore { + fn new(file_id: FileId) -> Self { + Self { + file_id, + functions: IndexVec::default(), + classes: IndexVec::default(), + unions: IndexVec::default(), + intersections: IndexVec::default(), + symbol_types: FxHashMap::default(), + node_types: FxHashMap::default(), + } + } + + fn add_function(&mut self, name: &str) -> FunctionTypeId { + let func_id = self.functions.push(FunctionType { + name: Name::new(name), + }); + FunctionTypeId { + file_id: self.file_id, + func_id, + } + } + + fn add_class(&mut self, name: &str) -> ClassTypeId { + let class_id = self.classes.push(ClassType { + name: Name::new(name), + }); + ClassTypeId { + file_id: self.file_id, + class_id, + } + } + + fn add_union(&mut self, elems: &[Type]) -> UnionTypeId { + let union_id = self.unions.push(UnionType { + elements: elems.iter().copied().collect(), + }); + UnionTypeId { + file_id: self.file_id, + union_id, + } + } + + fn add_intersection(&mut self, positive: &[Type], negative: &[Type]) -> IntersectionTypeId { + let intersection_id = self.intersections.push(IntersectionType { + positive: positive.iter().copied().collect(), + negative: negative.iter().copied().collect(), + }); + IntersectionTypeId { + file_id: self.file_id, + intersection_id, + } + } + + fn get_function(&self, func_id: ModuleFunctionTypeId) -> &FunctionType { + &self.functions[func_id] + } + + fn get_class(&self, class_id: ModuleClassTypeId) -> &ClassType { + &self.classes[class_id] + } + + fn get_union(&self, union_id: ModuleUnionTypeId) -> &UnionType { + &self.unions[union_id] + } + + fn get_intersection(&self, intersection_id: ModuleIntersectionTypeId) -> &IntersectionType { + &self.intersections[intersection_id] + } +} + +#[derive(Copy, Clone, Debug)] +struct DisplayType<'a> { + ty: &'a Type, + store: &'a TypeStore, +} + +impl std::fmt::Display for DisplayType<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.ty { + Type::Any => f.write_str("Any"), + Type::Never => f.write_str("Never"), + Type::Unknown => f.write_str("Unknown"), + Type::Unbound => f.write_str("Unbound"), + Type::Class(class_id) => f.write_str(self.store.get_class(*class_id).name()), + Type::Function(func_id) => f.write_str(self.store.get_function(*func_id).name()), + Type::Union(union_id) => self + .store + .get_module(union_id.file_id) + .get_union(union_id.union_id) + .display(f, self.store), + Type::Intersection(int_id) => self + .store + .get_module(int_id.file_id) + .get_intersection(int_id.intersection_id) + .display(f, self.store), + } + } +} + +#[derive(Debug)] 
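// Display goes through the store: class and function types print their interned
// name, while unions and intersections delegate to the module-local `display`
// helpers below, yielding e.g. `(C1 | C2)` for a union and `(C1 & C2 & ~C3)` for
// an intersection with one negative element (see the tests at the end of this file).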
+pub(crate) struct ClassType { + name: Name, +} + +impl ClassType { + fn name(&self) -> &str { + self.name.as_str() + } +} + +#[derive(Debug)] +pub(crate) struct FunctionType { + name: Name, +} + +impl FunctionType { + fn name(&self) -> &str { + self.name.as_str() + } +} + +#[derive(Debug)] +pub(crate) struct UnionType { + // the union type includes values in any of these types + elements: FxIndexSet, +} + +impl UnionType { + fn display(&self, f: &mut std::fmt::Formatter<'_>, store: &TypeStore) -> std::fmt::Result { + f.write_str("(")?; + let mut first = true; + for ty in &self.elements { + if !first { + f.write_str(" | ")?; + }; + first = false; + write!(f, "{}", ty.display(store))?; + } + f.write_str(")") + } +} + +// Negation types aren't expressible in annotations, and are most likely to arise from type +// narrowing along with intersections (e.g. `if not isinstance(...)`), so we represent them +// directly in intersections rather than as a separate type. This sacrifices some efficiency in the +// case where a Not appears outside an intersection (unclear when that could even happen, but we'd +// have to represent it as a single-element intersection if it did) in exchange for better +// efficiency in the not-within-intersection case. +#[derive(Debug)] +pub(crate) struct IntersectionType { + // the intersection type includes only values in all of these types + positive: FxIndexSet, + // negated elements of the intersection, e.g. + negative: FxIndexSet, +} + +impl IntersectionType { + fn display(&self, f: &mut std::fmt::Formatter<'_>, store: &TypeStore) -> std::fmt::Result { + f.write_str("(")?; + let mut first = true; + for (neg, ty) in self + .positive + .iter() + .map(|ty| (false, ty)) + .chain(self.negative.iter().map(|ty| (true, ty))) + { + if !first { + f.write_str(" & ")?; + }; + first = false; + if neg { + f.write_str("~")?; + }; + write!(f, "{}", ty.display(store))?; + } + f.write_str(")") + } +} + +#[cfg(test)] +mod tests { + use crate::files::Files; + use crate::types::{Type, TypeStore}; + use crate::FxIndexSet; + use std::path::Path; + + #[test] + fn add_class() { + let mut store = TypeStore::default(); + let files = Files::default(); + let file_id = files.intern(Path::new("/foo")); + let id = store.add_class(file_id, "C"); + assert_eq!(store.get_class(id).name(), "C"); + let class = Type::Class(id); + assert_eq!(format!("{}", class.display(&store)), "C"); + } + + #[test] + fn add_function() { + let mut store = TypeStore::default(); + let files = Files::default(); + let file_id = files.intern(Path::new("/foo")); + let id = store.add_function(file_id, "func"); + assert_eq!(store.get_function(id).name(), "func"); + let func = Type::Function(id); + assert_eq!(format!("{}", func.display(&store)), "func"); + } + + #[test] + fn add_union() { + let mut store = TypeStore::default(); + let files = Files::default(); + let file_id = files.intern(Path::new("/foo")); + let c1 = store.add_class(file_id, "C1"); + let c2 = store.add_class(file_id, "C2"); + let elems = vec![Type::Class(c1), Type::Class(c2)]; + let id = store.add_union(file_id, &elems); + assert_eq!( + store.get_union(id).elements, + elems.into_iter().collect::>() + ); + let union = Type::Union(id); + assert_eq!(format!("{}", union.display(&store)), "(C1 | C2)"); + } + + #[test] + fn add_intersection() { + let mut store = TypeStore::default(); + let files = Files::default(); + let file_id = files.intern(Path::new("/foo")); + let c1 = store.add_class(file_id, "C1"); + let c2 = store.add_class(file_id, "C2"); + let c3 = 
store.add_class(file_id, "C3"); + let pos = vec![Type::Class(c1), Type::Class(c2)]; + let neg = vec![Type::Class(c3)]; + let id = store.add_intersection(file_id, &pos, &neg); + assert_eq!( + store.get_intersection(id).positive, + pos.into_iter().collect::>() + ); + assert_eq!( + store.get_intersection(id).negative, + neg.into_iter().collect::>() + ); + let intersection = Type::Intersection(id); + assert_eq!( + format!("{}", intersection.display(&store)), + "(C1 & C2 & ~C3)" + ); + } +} diff --git a/crates/red_knot/src/types/infer.rs b/crates/red_knot/src/types/infer.rs new file mode 100644 index 0000000000..eb44dc9cb1 --- /dev/null +++ b/crates/red_knot/src/types/infer.rs @@ -0,0 +1,141 @@ +#![allow(dead_code)] +use crate::db::{HasJar, SemanticDb, SemanticJar}; +use crate::module::ModuleName; +use crate::symbols::{Definition, ImportFromDefinition, SymbolId}; +use crate::types::Type; +use crate::FileId; +use ruff_python_ast::AstNode; + +// TODO this should not take a &mut db, it should be a query, not a mutation. This means we'll need +// to use interior mutability in TypeStore instead, and avoid races in populating the cache. +#[tracing::instrument(level = "trace", skip(db))] +pub fn infer_symbol_type(db: &mut Db, file_id: FileId, symbol_id: SymbolId) -> Type +where + Db: SemanticDb + HasJar, +{ + let symbols = db.symbol_table(file_id); + let defs = symbols.defs(symbol_id); + + if let Some(ty) = db + .jar() + .type_store + .get_cached_symbol_type(file_id, symbol_id) + { + return ty; + } + + // TODO handle multiple defs, conditional defs... + assert_eq!(defs.len(), 1); + + let ty = match &defs[0] { + Definition::ImportFrom(ImportFromDefinition { + module, + name, + level, + }) => { + // TODO relative imports + assert!(matches!(level, 0)); + let module_name = ModuleName::new(module.as_ref().expect("TODO relative imports")); + if let Some(module) = db.resolve_module(module_name) { + let remote_file_id = module.path(db).file(); + let remote_symbols = db.symbol_table(remote_file_id); + if let Some(remote_symbol_id) = remote_symbols.root_symbol_id_by_name(name) { + db.infer_symbol_type(remote_file_id, remote_symbol_id) + } else { + Type::Unknown + } + } else { + Type::Unknown + } + } + Definition::ClassDef(node_key) => { + if let Some(ty) = db + .jar() + .type_store + .get_cached_node_type(file_id, node_key.erased()) + { + ty + } else { + let parsed = db.parse(file_id); + let ast = parsed.ast(); + let node = node_key.resolve_unwrap(ast.as_any_node_ref()); + + let store = &mut db.jar_mut().type_store; + let ty = Type::Class(store.add_class(file_id, &node.name.id)); + store.cache_node_type(file_id, *node_key.erased(), ty); + ty + } + } + _ => todo!("other kinds of definitions"), + }; + + db.jar_mut() + .type_store + .cache_symbol_type(file_id, symbol_id, ty); + // TODO record dependencies + ty +} + +#[cfg(test)] +mod tests { + use crate::db::tests::TestDb; + use crate::db::{HasJar, SemanticDb, SemanticJar}; + use crate::module::{ModuleName, ModuleSearchPath, ModuleSearchPathKind}; + use crate::types::Type; + + // TODO with virtual filesystem we shouldn't have to write files to disk for these + // tests + + struct TestCase { + temp_dir: tempfile::TempDir, + db: TestDb, + + src: ModuleSearchPath, + } + + fn create_test() -> std::io::Result { + let temp_dir = tempfile::tempdir()?; + + let src = temp_dir.path().join("src"); + std::fs::create_dir(&src)?; + let src = ModuleSearchPath::new(src.canonicalize()?, ModuleSearchPathKind::FirstParty); + + let roots = vec![src.clone()]; + + let mut db = 
TestDb::default(); + db.set_module_search_paths(roots); + + Ok(TestCase { temp_dir, db, src }) + } + + #[test] + fn follow_import_to_class() -> std::io::Result<()> { + let TestCase { + src, + mut db, + temp_dir: _temp_dir, + } = create_test()?; + + let a_path = src.path().join("a.py"); + let b_path = src.path().join("b.py"); + std::fs::write(a_path, "from b import C as D")?; + std::fs::write(b_path, "class C: pass")?; + let a_file = db + .resolve_module(ModuleName::new("a")) + .expect("module should be found") + .path(&db) + .file(); + let a_syms = db.symbol_table(a_file); + let d_sym = a_syms + .root_symbol_id_by_name("D") + .expect("D symbol should be found"); + + let ty = db.infer_symbol_type(a_file, d_sym); + + let jar = HasJar::::jar(&db); + + assert!(matches!(ty, Type::Class(_))); + assert_eq!(format!("{}", ty.display(&jar.type_store)), "C"); + Ok(()) + } +} diff --git a/crates/red_knot/src/watch.rs b/crates/red_knot/src/watch.rs new file mode 100644 index 0000000000..80204a219b --- /dev/null +++ b/crates/red_knot/src/watch.rs @@ -0,0 +1,78 @@ +use anyhow::Context; +use std::path::Path; + +use crate::files::Files; +use crate::program::{FileChange, FileChangeKind}; +use notify::event::{CreateKind, RemoveKind}; +use notify::{recommended_watcher, Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher}; + +pub struct FileWatcher { + watcher: RecommendedWatcher, +} + +pub trait EventHandler: Send + 'static { + fn handle(&self, changes: Vec); +} + +impl EventHandler for F +where + F: Fn(Vec) + Send + 'static, +{ + fn handle(&self, changes: Vec) { + let f = self; + f(changes); + } +} + +impl FileWatcher { + pub fn new(handler: E, files: Files) -> anyhow::Result + where + E: EventHandler, + { + Self::from_handler(Box::new(handler), files) + } + + fn from_handler(handler: Box, files: Files) -> anyhow::Result { + let watcher = recommended_watcher(move |changes: notify::Result| { + match changes { + Ok(event) => { + // TODO verify that this handles all events correctly + let change_kind = match event.kind { + EventKind::Create(CreateKind::File) => FileChangeKind::Created, + EventKind::Modify(_) => FileChangeKind::Modified, + EventKind::Remove(RemoveKind::File) => FileChangeKind::Deleted, + _ => { + return; + } + }; + + let mut changes = Vec::new(); + + for path in event.paths { + if path.is_file() { + let id = files.intern(&path); + changes.push(FileChange::new(id, change_kind)); + } + } + + if !changes.is_empty() { + handler.handle(changes); + } + } + // TODO proper error handling + Err(err) => { + panic!("Error: {err}"); + } + } + }) + .context("Failed to create file watcher.")?; + + Ok(Self { watcher }) + } + + pub fn watch_folder(&mut self, path: &Path) -> anyhow::Result<()> { + self.watcher.watch(path, RecursiveMode::Recursive)?; + + Ok(()) + } +} diff --git a/crates/ruff_cache/Cargo.toml b/crates/ruff_cache/Cargo.toml index 1584dc161c..b08ef03cfa 100644 --- a/crates/ruff_cache/Cargo.toml +++ b/crates/ruff_cache/Cargo.toml @@ -11,9 +11,9 @@ repository = { workspace = true } license = { workspace = true } [dependencies] -itertools = { workspace = true } glob = { workspace = true } globset = { workspace = true } +itertools = { workspace = true } regex = { workspace = true } filetime = { workspace = true } seahash = { workspace = true } diff --git a/crates/ruff_python_ast/src/node.rs b/crates/ruff_python_ast/src/node.rs index 8142ad4687..7d5024e420 100644 --- a/crates/ruff_python_ast/src/node.rs +++ b/crates/ruff_python_ast/src/node.rs @@ -2,17 +2,24 @@ use 
crate::visitor::preorder::PreorderVisitor; use crate::{ self as ast, Alias, ArgOrKeyword, Arguments, Comprehension, Decorator, ExceptHandler, Expr, FStringElement, Keyword, MatchCase, Mod, Parameter, ParameterWithDefault, Parameters, Pattern, - PatternArguments, PatternKeyword, Stmt, TypeParam, TypeParamParamSpec, TypeParamTypeVar, - TypeParamTypeVarTuple, TypeParams, WithItem, + PatternArguments, PatternKeyword, Stmt, StmtAnnAssign, StmtAssert, StmtAssign, StmtAugAssign, + StmtBreak, StmtClassDef, StmtContinue, StmtDelete, StmtExpr, StmtFor, StmtFunctionDef, + StmtGlobal, StmtIf, StmtImport, StmtImportFrom, StmtIpyEscapeCommand, StmtMatch, StmtNonlocal, + StmtPass, StmtRaise, StmtReturn, StmtTry, StmtTypeAlias, StmtWhile, StmtWith, TypeParam, + TypeParamParamSpec, TypeParamTypeVar, TypeParamTypeVarTuple, TypeParams, WithItem, }; use ruff_text_size::{Ranged, TextRange}; use std::ptr::NonNull; pub trait AstNode: Ranged { + type Ref<'a>; + fn cast(kind: AnyNode) -> Option where Self: Sized; - fn cast_ref(kind: AnyNodeRef) -> Option<&Self>; + fn cast_ref(kind: AnyNodeRef<'_>) -> Option>; + + fn can_cast(kind: NodeKind) -> bool; /// Returns the [`AnyNodeRef`] referencing this node. fn as_any_node_ref(&self) -> AnyNodeRef; @@ -122,100 +129,7 @@ pub enum AnyNode { impl AnyNode { pub fn statement(self) -> Option { - match self { - AnyNode::StmtFunctionDef(node) => Some(Stmt::FunctionDef(node)), - AnyNode::StmtClassDef(node) => Some(Stmt::ClassDef(node)), - AnyNode::StmtReturn(node) => Some(Stmt::Return(node)), - AnyNode::StmtDelete(node) => Some(Stmt::Delete(node)), - AnyNode::StmtTypeAlias(node) => Some(Stmt::TypeAlias(node)), - AnyNode::StmtAssign(node) => Some(Stmt::Assign(node)), - AnyNode::StmtAugAssign(node) => Some(Stmt::AugAssign(node)), - AnyNode::StmtAnnAssign(node) => Some(Stmt::AnnAssign(node)), - AnyNode::StmtFor(node) => Some(Stmt::For(node)), - AnyNode::StmtWhile(node) => Some(Stmt::While(node)), - AnyNode::StmtIf(node) => Some(Stmt::If(node)), - AnyNode::StmtWith(node) => Some(Stmt::With(node)), - AnyNode::StmtMatch(node) => Some(Stmt::Match(node)), - AnyNode::StmtRaise(node) => Some(Stmt::Raise(node)), - AnyNode::StmtTry(node) => Some(Stmt::Try(node)), - AnyNode::StmtAssert(node) => Some(Stmt::Assert(node)), - AnyNode::StmtImport(node) => Some(Stmt::Import(node)), - AnyNode::StmtImportFrom(node) => Some(Stmt::ImportFrom(node)), - AnyNode::StmtGlobal(node) => Some(Stmt::Global(node)), - AnyNode::StmtNonlocal(node) => Some(Stmt::Nonlocal(node)), - AnyNode::StmtExpr(node) => Some(Stmt::Expr(node)), - AnyNode::StmtPass(node) => Some(Stmt::Pass(node)), - AnyNode::StmtBreak(node) => Some(Stmt::Break(node)), - AnyNode::StmtContinue(node) => Some(Stmt::Continue(node)), - AnyNode::StmtIpyEscapeCommand(node) => Some(Stmt::IpyEscapeCommand(node)), - - AnyNode::ModModule(_) - | AnyNode::ModExpression(_) - | AnyNode::ExprBoolOp(_) - | AnyNode::ExprNamed(_) - | AnyNode::ExprBinOp(_) - | AnyNode::ExprUnaryOp(_) - | AnyNode::ExprLambda(_) - | AnyNode::ExprIf(_) - | AnyNode::ExprDict(_) - | AnyNode::ExprSet(_) - | AnyNode::ExprListComp(_) - | AnyNode::ExprSetComp(_) - | AnyNode::ExprDictComp(_) - | AnyNode::ExprGenerator(_) - | AnyNode::ExprAwait(_) - | AnyNode::ExprYield(_) - | AnyNode::ExprYieldFrom(_) - | AnyNode::ExprCompare(_) - | AnyNode::ExprCall(_) - | AnyNode::FStringExpressionElement(_) - | AnyNode::FStringLiteralElement(_) - | AnyNode::FStringFormatSpec(_) - | AnyNode::ExprFString(_) - | AnyNode::ExprStringLiteral(_) - | AnyNode::ExprBytesLiteral(_) - | AnyNode::ExprNumberLiteral(_) - 
| AnyNode::ExprBooleanLiteral(_) - | AnyNode::ExprNoneLiteral(_) - | AnyNode::ExprEllipsisLiteral(_) - | AnyNode::ExprAttribute(_) - | AnyNode::ExprSubscript(_) - | AnyNode::ExprStarred(_) - | AnyNode::ExprName(_) - | AnyNode::ExprList(_) - | AnyNode::ExprTuple(_) - | AnyNode::ExprSlice(_) - | AnyNode::ExprIpyEscapeCommand(_) - | AnyNode::ExceptHandlerExceptHandler(_) - | AnyNode::PatternMatchValue(_) - | AnyNode::PatternMatchSingleton(_) - | AnyNode::PatternMatchSequence(_) - | AnyNode::PatternMatchMapping(_) - | AnyNode::PatternMatchClass(_) - | AnyNode::PatternMatchStar(_) - | AnyNode::PatternMatchAs(_) - | AnyNode::PatternMatchOr(_) - | AnyNode::PatternArguments(_) - | AnyNode::PatternKeyword(_) - | AnyNode::Comprehension(_) - | AnyNode::Arguments(_) - | AnyNode::Parameters(_) - | AnyNode::Parameter(_) - | AnyNode::ParameterWithDefault(_) - | AnyNode::Keyword(_) - | AnyNode::Alias(_) - | AnyNode::WithItem(_) - | AnyNode::MatchCase(_) - | AnyNode::Decorator(_) - | AnyNode::TypeParams(_) - | AnyNode::TypeParamTypeVar(_) - | AnyNode::TypeParamTypeVarTuple(_) - | AnyNode::TypeParamParamSpec(_) - | AnyNode::FString(_) - | AnyNode::StringLiteral(_) - | AnyNode::BytesLiteral(_) - | AnyNode::ElifElseClause(_) => None, - } + Stmt::cast(self) } pub fn expression(self) -> Option { @@ -729,6 +643,8 @@ impl AnyNode { } impl AstNode for ast::ModModule { + type Ref<'a> = &'a Self; + fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -748,6 +664,10 @@ impl AstNode for ast::ModModule { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ModModule) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -766,6 +686,8 @@ impl AstNode for ast::ModModule { } impl AstNode for ast::ModExpression { + type Ref<'a> = &'a Self; + fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -785,6 +707,10 @@ impl AstNode for ast::ModExpression { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ModExpression) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -802,6 +728,8 @@ impl AstNode for ast::ModExpression { } } impl AstNode for ast::StmtFunctionDef { + type Ref<'a> = &'a Self; + fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -821,6 +749,10 @@ impl AstNode for ast::StmtFunctionDef { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::StmtFunctionDef) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -860,6 +792,8 @@ impl AstNode for ast::StmtFunctionDef { } } impl AstNode for ast::StmtClassDef { + type Ref<'a> = &'a Self; + fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -879,6 +813,10 @@ impl AstNode for ast::StmtClassDef { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::StmtClassDef) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -915,6 +853,7 @@ impl AstNode for ast::StmtClassDef { } } impl AstNode for ast::StmtReturn { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -934,6 +873,10 @@ impl AstNode for ast::StmtReturn { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::StmtReturn) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -953,6 +896,7 @@ impl AstNode for ast::StmtReturn { } } impl AstNode for ast::StmtDelete { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -972,6 +916,10 @@ impl AstNode for ast::StmtDelete { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::StmtDelete) + 
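// Every impl now answers `can_cast` from the NodeKind alone, without constructing
// or borrowing a node; a rough caller-side sketch (assuming the usual
// `AnyNodeRef::kind()` accessor):
//
//     if ast::StmtDelete::can_cast(any_ref.kind()) {
//         let delete = ast::StmtDelete::cast_ref(any_ref).unwrap();
//     }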
} + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -991,6 +939,7 @@ impl AstNode for ast::StmtDelete { } } impl AstNode for ast::StmtTypeAlias { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -1010,6 +959,10 @@ impl AstNode for ast::StmtTypeAlias { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::StmtTypeAlias) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -1037,6 +990,7 @@ impl AstNode for ast::StmtTypeAlias { } } impl AstNode for ast::StmtAssign { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -1056,6 +1010,10 @@ impl AstNode for ast::StmtAssign { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::StmtAssign) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -1082,6 +1040,7 @@ impl AstNode for ast::StmtAssign { } } impl AstNode for ast::StmtAugAssign { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -1101,6 +1060,10 @@ impl AstNode for ast::StmtAugAssign { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::StmtAugAssign) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -1126,6 +1089,7 @@ impl AstNode for ast::StmtAugAssign { } } impl AstNode for ast::StmtAnnAssign { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -1145,6 +1109,10 @@ impl AstNode for ast::StmtAnnAssign { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::StmtAnnAssign) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -1173,6 +1141,7 @@ impl AstNode for ast::StmtAnnAssign { } } impl AstNode for ast::StmtFor { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -1192,6 +1161,10 @@ impl AstNode for ast::StmtFor { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::StmtFor) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -1219,6 +1192,7 @@ impl AstNode for ast::StmtFor { } } impl AstNode for ast::StmtWhile { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -1238,6 +1212,10 @@ impl AstNode for ast::StmtWhile { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::StmtWhile) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -1263,6 +1241,7 @@ impl AstNode for ast::StmtWhile { } } impl AstNode for ast::StmtIf { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -1282,6 +1261,10 @@ impl AstNode for ast::StmtIf { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::StmtIf) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -1309,6 +1292,7 @@ impl AstNode for ast::StmtIf { } } impl AstNode for ast::ElifElseClause { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -1328,6 +1312,10 @@ impl AstNode for ast::ElifElseClause { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ElifElseClause) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -1352,6 +1340,7 @@ impl AstNode for ast::ElifElseClause { } } impl AstNode for ast::StmtWith { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -1371,6 +1360,10 @@ impl AstNode for ast::StmtWith { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::StmtWith) + } + fn as_any_node_ref(&self) -> AnyNodeRef 
{ AnyNodeRef::from(self) } @@ -1397,6 +1390,7 @@ impl AstNode for ast::StmtWith { } } impl AstNode for ast::StmtMatch { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -1416,6 +1410,10 @@ impl AstNode for ast::StmtMatch { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::StmtMatch) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -1441,6 +1439,7 @@ impl AstNode for ast::StmtMatch { } } impl AstNode for ast::StmtRaise { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -1460,6 +1459,10 @@ impl AstNode for ast::StmtRaise { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::StmtRaise) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -1487,6 +1490,7 @@ impl AstNode for ast::StmtRaise { } } impl AstNode for ast::StmtTry { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -1506,6 +1510,10 @@ impl AstNode for ast::StmtTry { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::StmtTry) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -1536,6 +1544,7 @@ impl AstNode for ast::StmtTry { } } impl AstNode for ast::StmtAssert { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -1555,6 +1564,10 @@ impl AstNode for ast::StmtAssert { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::StmtAssert) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -1579,6 +1592,7 @@ impl AstNode for ast::StmtAssert { } } impl AstNode for ast::StmtImport { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -1598,6 +1612,10 @@ impl AstNode for ast::StmtImport { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::StmtImport) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -1618,6 +1636,7 @@ impl AstNode for ast::StmtImport { } } impl AstNode for ast::StmtImportFrom { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -1637,6 +1656,10 @@ impl AstNode for ast::StmtImportFrom { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::StmtImportFrom) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -1662,6 +1685,7 @@ impl AstNode for ast::StmtImportFrom { } } impl AstNode for ast::StmtGlobal { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -1681,6 +1705,10 @@ impl AstNode for ast::StmtGlobal { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::StmtGlobal) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -1697,6 +1725,7 @@ impl AstNode for ast::StmtGlobal { } } impl AstNode for ast::StmtNonlocal { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -1716,6 +1745,10 @@ impl AstNode for ast::StmtNonlocal { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::StmtNonlocal) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -1732,6 +1765,7 @@ impl AstNode for ast::StmtNonlocal { } } impl AstNode for ast::StmtExpr { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -1751,6 +1785,10 @@ impl AstNode for ast::StmtExpr { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::StmtExpr) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -1769,6 +1807,7 @@ impl 
AstNode for ast::StmtExpr { } } impl AstNode for ast::StmtPass { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -1788,6 +1827,10 @@ impl AstNode for ast::StmtPass { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::StmtPass) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -1804,6 +1847,7 @@ impl AstNode for ast::StmtPass { } } impl AstNode for ast::StmtBreak { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -1823,6 +1867,10 @@ impl AstNode for ast::StmtBreak { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::StmtBreak) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -1839,6 +1887,7 @@ impl AstNode for ast::StmtBreak { } } impl AstNode for ast::StmtContinue { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -1858,6 +1907,10 @@ impl AstNode for ast::StmtContinue { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::StmtContinue) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -1874,6 +1927,7 @@ impl AstNode for ast::StmtContinue { } } impl AstNode for ast::StmtIpyEscapeCommand { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -1893,6 +1947,10 @@ impl AstNode for ast::StmtIpyEscapeCommand { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::StmtIpyEscapeCommand) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -1909,6 +1967,7 @@ impl AstNode for ast::StmtIpyEscapeCommand { } } impl AstNode for ast::ExprBoolOp { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -1928,6 +1987,10 @@ impl AstNode for ast::ExprBoolOp { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ExprBoolOp) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -1960,6 +2023,7 @@ impl AstNode for ast::ExprBoolOp { } } impl AstNode for ast::ExprNamed { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -1979,6 +2043,10 @@ impl AstNode for ast::ExprNamed { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ExprNamed) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -2001,6 +2069,7 @@ impl AstNode for ast::ExprNamed { } } impl AstNode for ast::ExprBinOp { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -2020,6 +2089,10 @@ impl AstNode for ast::ExprBinOp { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ExprBinOp) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -2044,6 +2117,7 @@ impl AstNode for ast::ExprBinOp { } } impl AstNode for ast::ExprUnaryOp { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -2063,6 +2137,10 @@ impl AstNode for ast::ExprUnaryOp { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ExprUnaryOp) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -2086,6 +2164,8 @@ impl AstNode for ast::ExprUnaryOp { } } impl AstNode for ast::ExprLambda { + type Ref<'a> = &'a Self; + fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -2105,6 +2185,10 @@ impl AstNode for ast::ExprLambda { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ExprLambda) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -2130,6 +2214,7 @@ impl AstNode for 
ast::ExprLambda { } } impl AstNode for ast::ExprIf { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -2149,6 +2234,10 @@ impl AstNode for ast::ExprIf { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ExprIf) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -2175,6 +2264,7 @@ impl AstNode for ast::ExprIf { } } impl AstNode for ast::ExprDict { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -2194,6 +2284,10 @@ impl AstNode for ast::ExprDict { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ExprDict) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -2221,6 +2315,7 @@ impl AstNode for ast::ExprDict { } } impl AstNode for ast::ExprSet { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -2240,6 +2335,10 @@ impl AstNode for ast::ExprSet { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ExprSet) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -2260,6 +2359,7 @@ impl AstNode for ast::ExprSet { } } impl AstNode for ast::ExprListComp { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -2279,6 +2379,10 @@ impl AstNode for ast::ExprListComp { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ExprListComp) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -2304,6 +2408,7 @@ impl AstNode for ast::ExprListComp { } } impl AstNode for ast::ExprSetComp { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -2323,6 +2428,10 @@ impl AstNode for ast::ExprSetComp { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ExprSetComp) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -2348,6 +2457,7 @@ impl AstNode for ast::ExprSetComp { } } impl AstNode for ast::ExprDictComp { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -2367,6 +2477,10 @@ impl AstNode for ast::ExprDictComp { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ExprDictComp) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -2395,6 +2509,7 @@ impl AstNode for ast::ExprDictComp { } } impl AstNode for ast::ExprGenerator { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -2414,6 +2529,10 @@ impl AstNode for ast::ExprGenerator { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ExprGenerator) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -2439,6 +2558,7 @@ impl AstNode for ast::ExprGenerator { } } impl AstNode for ast::ExprAwait { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -2458,6 +2578,10 @@ impl AstNode for ast::ExprAwait { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ExprAwait) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -2475,6 +2599,7 @@ impl AstNode for ast::ExprAwait { } } impl AstNode for ast::ExprYield { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -2494,6 +2619,10 @@ impl AstNode for ast::ExprYield { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ExprYield) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -2513,6 +2642,7 @@ impl AstNode for ast::ExprYield { } } impl AstNode for ast::ExprYieldFrom { + 
type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -2532,6 +2662,10 @@ impl AstNode for ast::ExprYieldFrom { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ExprYieldFrom) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -2549,6 +2683,7 @@ impl AstNode for ast::ExprYieldFrom { } } impl AstNode for ast::ExprCompare { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -2568,6 +2703,10 @@ impl AstNode for ast::ExprCompare { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ExprCompare) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -2596,6 +2735,7 @@ impl AstNode for ast::ExprCompare { } } impl AstNode for ast::ExprCall { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -2615,6 +2755,10 @@ impl AstNode for ast::ExprCall { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ExprCall) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -2637,6 +2781,7 @@ impl AstNode for ast::ExprCall { } } impl AstNode for ast::FStringFormatSpec { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -2656,6 +2801,10 @@ impl AstNode for ast::FStringFormatSpec { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::FStringFormatSpec) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -2674,6 +2823,7 @@ impl AstNode for ast::FStringFormatSpec { } } impl AstNode for ast::FStringExpressionElement { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -2693,6 +2843,10 @@ impl AstNode for ast::FStringExpressionElement { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::FStringExpressionElement) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -2720,6 +2874,7 @@ impl AstNode for ast::FStringExpressionElement { } } impl AstNode for ast::FStringLiteralElement { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -2739,6 +2894,10 @@ impl AstNode for ast::FStringLiteralElement { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::FStringLiteralElement) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -2754,6 +2913,7 @@ impl AstNode for ast::FStringLiteralElement { } } impl AstNode for ast::ExprFString { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -2773,6 +2933,10 @@ impl AstNode for ast::ExprFString { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ExprFString) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -2800,6 +2964,7 @@ impl AstNode for ast::ExprFString { } } impl AstNode for ast::ExprStringLiteral { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -2819,6 +2984,10 @@ impl AstNode for ast::ExprStringLiteral { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ExprStringLiteral) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -2839,6 +3008,7 @@ impl AstNode for ast::ExprStringLiteral { } } impl AstNode for ast::ExprBytesLiteral { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -2858,6 +3028,10 @@ impl AstNode for ast::ExprBytesLiteral { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ExprBytesLiteral) + } + fn as_any_node_ref(&self) -> 
AnyNodeRef { AnyNodeRef::from(self) } @@ -2878,6 +3052,7 @@ impl AstNode for ast::ExprBytesLiteral { } } impl AstNode for ast::ExprNumberLiteral { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -2897,6 +3072,10 @@ impl AstNode for ast::ExprNumberLiteral { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ExprNumberLiteral) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -2912,6 +3091,7 @@ impl AstNode for ast::ExprNumberLiteral { } } impl AstNode for ast::ExprBooleanLiteral { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -2931,6 +3111,10 @@ impl AstNode for ast::ExprBooleanLiteral { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ExprBooleanLiteral) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -2946,6 +3130,7 @@ impl AstNode for ast::ExprBooleanLiteral { } } impl AstNode for ast::ExprNoneLiteral { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -2965,6 +3150,10 @@ impl AstNode for ast::ExprNoneLiteral { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ExprNoneLiteral) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -2980,6 +3169,7 @@ impl AstNode for ast::ExprNoneLiteral { } } impl AstNode for ast::ExprEllipsisLiteral { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -2999,6 +3189,10 @@ impl AstNode for ast::ExprEllipsisLiteral { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ExprEllipsisLiteral) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -3014,6 +3208,7 @@ impl AstNode for ast::ExprEllipsisLiteral { } } impl AstNode for ast::ExprAttribute { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -3033,6 +3228,10 @@ impl AstNode for ast::ExprAttribute { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ExprAttribute) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -3056,6 +3255,7 @@ impl AstNode for ast::ExprAttribute { } } impl AstNode for ast::ExprSubscript { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -3075,6 +3275,10 @@ impl AstNode for ast::ExprSubscript { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ExprSubscript) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -3098,6 +3302,7 @@ impl AstNode for ast::ExprSubscript { } } impl AstNode for ast::ExprStarred { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -3117,6 +3322,10 @@ impl AstNode for ast::ExprStarred { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ExprStarred) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -3139,6 +3348,7 @@ impl AstNode for ast::ExprStarred { } } impl AstNode for ast::ExprName { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -3158,6 +3368,10 @@ impl AstNode for ast::ExprName { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ExprName) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -3179,6 +3393,7 @@ impl AstNode for ast::ExprName { } } impl AstNode for ast::ExprList { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -3198,6 +3413,10 @@ impl AstNode for ast::ExprList { } } + fn can_cast(kind: NodeKind) 
-> bool { + matches!(kind, NodeKind::ExprList) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -3222,6 +3441,7 @@ impl AstNode for ast::ExprList { } } impl AstNode for ast::ExprTuple { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -3241,6 +3461,10 @@ impl AstNode for ast::ExprTuple { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ExprTuple) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -3266,6 +3490,7 @@ impl AstNode for ast::ExprTuple { } } impl AstNode for ast::ExprSlice { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -3285,6 +3510,10 @@ impl AstNode for ast::ExprSlice { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ExprSlice) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -3315,6 +3544,7 @@ impl AstNode for ast::ExprSlice { } } impl AstNode for ast::ExprIpyEscapeCommand { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -3334,6 +3564,10 @@ impl AstNode for ast::ExprIpyEscapeCommand { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ExprIpyEscapeCommand) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -3355,6 +3589,7 @@ impl AstNode for ast::ExprIpyEscapeCommand { } } impl AstNode for ast::ExceptHandlerExceptHandler { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -3374,6 +3609,10 @@ impl AstNode for ast::ExceptHandlerExceptHandler { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ExceptHandlerExceptHandler) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -3399,6 +3638,7 @@ impl AstNode for ast::ExceptHandlerExceptHandler { } } impl AstNode for ast::PatternMatchValue { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -3418,6 +3658,10 @@ impl AstNode for ast::PatternMatchValue { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::PatternMatchValue) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -3435,6 +3679,7 @@ impl AstNode for ast::PatternMatchValue { } } impl AstNode for ast::PatternMatchSingleton { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -3454,6 +3699,10 @@ impl AstNode for ast::PatternMatchSingleton { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::PatternMatchSingleton) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -3471,6 +3720,7 @@ impl AstNode for ast::PatternMatchSingleton { } } impl AstNode for ast::PatternMatchSequence { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -3490,6 +3740,10 @@ impl AstNode for ast::PatternMatchSequence { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::PatternMatchSequence) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -3509,6 +3763,7 @@ impl AstNode for ast::PatternMatchSequence { } } impl AstNode for ast::PatternMatchMapping { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -3528,6 +3783,10 @@ impl AstNode for ast::PatternMatchMapping { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::PatternMatchMapping) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -3553,6 +3812,7 @@ impl AstNode for ast::PatternMatchMapping { } } impl AstNode for 
ast::PatternMatchClass { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -3572,6 +3832,10 @@ impl AstNode for ast::PatternMatchClass { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::PatternMatchClass) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -3594,6 +3858,7 @@ impl AstNode for ast::PatternMatchClass { } } impl AstNode for ast::PatternMatchStar { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -3613,6 +3878,10 @@ impl AstNode for ast::PatternMatchStar { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::PatternMatchStar) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -3630,6 +3899,7 @@ impl AstNode for ast::PatternMatchStar { } } impl AstNode for ast::PatternMatchAs { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -3649,6 +3919,10 @@ impl AstNode for ast::PatternMatchAs { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::PatternMatchAs) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -3672,6 +3946,7 @@ impl AstNode for ast::PatternMatchAs { } } impl AstNode for ast::PatternMatchOr { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -3691,6 +3966,10 @@ impl AstNode for ast::PatternMatchOr { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::PatternMatchOr) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -3710,6 +3989,7 @@ impl AstNode for ast::PatternMatchOr { } } impl AstNode for PatternArguments { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -3729,6 +4009,10 @@ impl AstNode for PatternArguments { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::PatternArguments) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -3757,6 +4041,7 @@ impl AstNode for PatternArguments { } } impl AstNode for PatternKeyword { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -3776,6 +4061,10 @@ impl AstNode for PatternKeyword { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::PatternKeyword) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -3799,6 +4088,7 @@ impl AstNode for PatternKeyword { } impl AstNode for Comprehension { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -3818,6 +4108,10 @@ impl AstNode for Comprehension { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::Comprehension) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -3846,6 +4140,7 @@ impl AstNode for Comprehension { } } impl AstNode for Arguments { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -3865,6 +4160,10 @@ impl AstNode for Arguments { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::Arguments) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -3886,6 +4185,7 @@ impl AstNode for Arguments { } } impl AstNode for Parameters { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -3905,6 +4205,10 @@ impl AstNode for Parameters { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::Parameters) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -3943,6 +4247,7 @@ impl AstNode for Parameters { } } impl AstNode 
for Parameter { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -3962,6 +4267,10 @@ impl AstNode for Parameter { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::Parameter) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -3986,6 +4295,7 @@ impl AstNode for Parameter { } } impl AstNode for ParameterWithDefault { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -4005,6 +4315,10 @@ impl AstNode for ParameterWithDefault { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::ParameterWithDefault) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -4029,6 +4343,7 @@ impl AstNode for ParameterWithDefault { } } impl AstNode for Keyword { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -4048,6 +4363,10 @@ impl AstNode for Keyword { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::Keyword) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -4070,6 +4389,7 @@ impl AstNode for Keyword { } } impl AstNode for Alias { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -4089,6 +4409,10 @@ impl AstNode for Alias { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::Alias) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -4110,6 +4434,7 @@ impl AstNode for Alias { } } impl AstNode for WithItem { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -4129,6 +4454,10 @@ impl AstNode for WithItem { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::WithItem) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -4155,6 +4484,7 @@ impl AstNode for WithItem { } } impl AstNode for MatchCase { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -4174,6 +4504,10 @@ impl AstNode for MatchCase { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::MatchCase) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -4202,6 +4536,7 @@ impl AstNode for MatchCase { } impl AstNode for Decorator { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -4221,6 +4556,10 @@ impl AstNode for Decorator { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::Decorator) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -4242,6 +4581,7 @@ impl AstNode for Decorator { } } impl AstNode for ast::TypeParams { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -4261,6 +4601,10 @@ impl AstNode for ast::TypeParams { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::TypeParams) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -4284,6 +4628,7 @@ impl AstNode for ast::TypeParams { } } impl AstNode for ast::TypeParamTypeVar { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -4303,6 +4648,10 @@ impl AstNode for ast::TypeParamTypeVar { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::TypeParamTypeVar) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -4331,6 +4680,7 @@ impl AstNode for ast::TypeParamTypeVar { } } impl AstNode for ast::TypeParamTypeVarTuple { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -4350,6 +4700,10 @@ impl AstNode 
for ast::TypeParamTypeVarTuple { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::TypeParamTypeVarTuple) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -4374,6 +4728,7 @@ impl AstNode for ast::TypeParamTypeVarTuple { } } impl AstNode for ast::TypeParamParamSpec { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -4393,6 +4748,10 @@ impl AstNode for ast::TypeParamParamSpec { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::TypeParamParamSpec) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -4417,6 +4776,7 @@ impl AstNode for ast::TypeParamParamSpec { } } impl AstNode for ast::FString { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -4436,6 +4796,10 @@ impl AstNode for ast::FString { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::FString) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -4460,6 +4824,7 @@ impl AstNode for ast::FString { } } impl AstNode for ast::StringLiteral { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -4479,6 +4844,10 @@ impl AstNode for ast::StringLiteral { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::StringLiteral) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -4494,6 +4863,7 @@ impl AstNode for ast::StringLiteral { } } impl AstNode for ast::BytesLiteral { + type Ref<'a> = &'a Self; fn cast(kind: AnyNode) -> Option where Self: Sized, @@ -4513,6 +4883,10 @@ impl AstNode for ast::BytesLiteral { } } + fn can_cast(kind: NodeKind) -> bool { + matches!(kind, NodeKind::BytesLiteral) + } + fn as_any_node_ref(&self) -> AnyNodeRef { AnyNodeRef::from(self) } @@ -4528,6 +4902,458 @@ impl AstNode for ast::BytesLiteral { } } +impl AstNode for Stmt { + type Ref<'a> = StatementRef<'a>; + + fn cast(kind: AnyNode) -> Option { + match kind { + AnyNode::StmtFunctionDef(node) => Some(Stmt::FunctionDef(node)), + AnyNode::StmtClassDef(node) => Some(Stmt::ClassDef(node)), + AnyNode::StmtReturn(node) => Some(Stmt::Return(node)), + AnyNode::StmtDelete(node) => Some(Stmt::Delete(node)), + AnyNode::StmtTypeAlias(node) => Some(Stmt::TypeAlias(node)), + AnyNode::StmtAssign(node) => Some(Stmt::Assign(node)), + AnyNode::StmtAugAssign(node) => Some(Stmt::AugAssign(node)), + AnyNode::StmtAnnAssign(node) => Some(Stmt::AnnAssign(node)), + AnyNode::StmtFor(node) => Some(Stmt::For(node)), + AnyNode::StmtWhile(node) => Some(Stmt::While(node)), + AnyNode::StmtIf(node) => Some(Stmt::If(node)), + AnyNode::StmtWith(node) => Some(Stmt::With(node)), + AnyNode::StmtMatch(node) => Some(Stmt::Match(node)), + AnyNode::StmtRaise(node) => Some(Stmt::Raise(node)), + AnyNode::StmtTry(node) => Some(Stmt::Try(node)), + AnyNode::StmtAssert(node) => Some(Stmt::Assert(node)), + AnyNode::StmtImport(node) => Some(Stmt::Import(node)), + AnyNode::StmtImportFrom(node) => Some(Stmt::ImportFrom(node)), + AnyNode::StmtGlobal(node) => Some(Stmt::Global(node)), + AnyNode::StmtNonlocal(node) => Some(Stmt::Nonlocal(node)), + AnyNode::StmtExpr(node) => Some(Stmt::Expr(node)), + AnyNode::StmtPass(node) => Some(Stmt::Pass(node)), + AnyNode::StmtBreak(node) => Some(Stmt::Break(node)), + AnyNode::StmtContinue(node) => Some(Stmt::Continue(node)), + AnyNode::StmtIpyEscapeCommand(node) => Some(Stmt::IpyEscapeCommand(node)), + + AnyNode::ModModule(_) + | AnyNode::ModExpression(_) + | AnyNode::ExprBoolOp(_) + | AnyNode::ExprNamed(_) + | 
AnyNode::ExprBinOp(_) + | AnyNode::ExprUnaryOp(_) + | AnyNode::ExprLambda(_) + | AnyNode::ExprIf(_) + | AnyNode::ExprDict(_) + | AnyNode::ExprSet(_) + | AnyNode::ExprListComp(_) + | AnyNode::ExprSetComp(_) + | AnyNode::ExprDictComp(_) + | AnyNode::ExprGenerator(_) + | AnyNode::ExprAwait(_) + | AnyNode::ExprYield(_) + | AnyNode::ExprYieldFrom(_) + | AnyNode::ExprCompare(_) + | AnyNode::ExprCall(_) + | AnyNode::FStringExpressionElement(_) + | AnyNode::FStringLiteralElement(_) + | AnyNode::FStringFormatSpec(_) + | AnyNode::ExprFString(_) + | AnyNode::ExprStringLiteral(_) + | AnyNode::ExprBytesLiteral(_) + | AnyNode::ExprNumberLiteral(_) + | AnyNode::ExprBooleanLiteral(_) + | AnyNode::ExprNoneLiteral(_) + | AnyNode::ExprEllipsisLiteral(_) + | AnyNode::ExprAttribute(_) + | AnyNode::ExprSubscript(_) + | AnyNode::ExprStarred(_) + | AnyNode::ExprName(_) + | AnyNode::ExprList(_) + | AnyNode::ExprTuple(_) + | AnyNode::ExprSlice(_) + | AnyNode::ExprIpyEscapeCommand(_) + | AnyNode::ExceptHandlerExceptHandler(_) + | AnyNode::PatternMatchValue(_) + | AnyNode::PatternMatchSingleton(_) + | AnyNode::PatternMatchSequence(_) + | AnyNode::PatternMatchMapping(_) + | AnyNode::PatternMatchClass(_) + | AnyNode::PatternMatchStar(_) + | AnyNode::PatternMatchAs(_) + | AnyNode::PatternMatchOr(_) + | AnyNode::PatternArguments(_) + | AnyNode::PatternKeyword(_) + | AnyNode::Comprehension(_) + | AnyNode::Arguments(_) + | AnyNode::Parameters(_) + | AnyNode::Parameter(_) + | AnyNode::ParameterWithDefault(_) + | AnyNode::Keyword(_) + | AnyNode::Alias(_) + | AnyNode::WithItem(_) + | AnyNode::MatchCase(_) + | AnyNode::Decorator(_) + | AnyNode::TypeParams(_) + | AnyNode::TypeParamTypeVar(_) + | AnyNode::TypeParamTypeVarTuple(_) + | AnyNode::TypeParamParamSpec(_) + | AnyNode::FString(_) + | AnyNode::StringLiteral(_) + | AnyNode::BytesLiteral(_) + | AnyNode::ElifElseClause(_) => None, + } + } + + fn cast_ref(kind: AnyNodeRef) -> Option> { + match kind { + AnyNodeRef::StmtFunctionDef(statement) => Some(StatementRef::FunctionDef(statement)), + AnyNodeRef::StmtClassDef(statement) => Some(StatementRef::ClassDef(statement)), + AnyNodeRef::StmtReturn(statement) => Some(StatementRef::Return(statement)), + AnyNodeRef::StmtDelete(statement) => Some(StatementRef::Delete(statement)), + AnyNodeRef::StmtTypeAlias(statement) => Some(StatementRef::TypeAlias(statement)), + AnyNodeRef::StmtAssign(statement) => Some(StatementRef::Assign(statement)), + AnyNodeRef::StmtAugAssign(statement) => Some(StatementRef::AugAssign(statement)), + AnyNodeRef::StmtAnnAssign(statement) => Some(StatementRef::AnnAssign(statement)), + AnyNodeRef::StmtFor(statement) => Some(StatementRef::For(statement)), + AnyNodeRef::StmtWhile(statement) => Some(StatementRef::While(statement)), + AnyNodeRef::StmtIf(statement) => Some(StatementRef::If(statement)), + AnyNodeRef::StmtWith(statement) => Some(StatementRef::With(statement)), + AnyNodeRef::StmtMatch(statement) => Some(StatementRef::Match(statement)), + AnyNodeRef::StmtRaise(statement) => Some(StatementRef::Raise(statement)), + AnyNodeRef::StmtTry(statement) => Some(StatementRef::Try(statement)), + AnyNodeRef::StmtAssert(statement) => Some(StatementRef::Assert(statement)), + AnyNodeRef::StmtImport(statement) => Some(StatementRef::Import(statement)), + AnyNodeRef::StmtImportFrom(statement) => Some(StatementRef::ImportFrom(statement)), + AnyNodeRef::StmtGlobal(statement) => Some(StatementRef::Global(statement)), + AnyNodeRef::StmtNonlocal(statement) => Some(StatementRef::Nonlocal(statement)), + 
AnyNodeRef::StmtExpr(statement) => Some(StatementRef::Expr(statement)), + AnyNodeRef::StmtPass(statement) => Some(StatementRef::Pass(statement)), + AnyNodeRef::StmtBreak(statement) => Some(StatementRef::Break(statement)), + AnyNodeRef::StmtContinue(statement) => Some(StatementRef::Continue(statement)), + AnyNodeRef::StmtIpyEscapeCommand(statement) => { + Some(StatementRef::IpyEscapeCommand(statement)) + } + AnyNodeRef::ModModule(_) + | AnyNodeRef::ModExpression(_) + | AnyNodeRef::ExprBoolOp(_) + | AnyNodeRef::ExprNamed(_) + | AnyNodeRef::ExprBinOp(_) + | AnyNodeRef::ExprUnaryOp(_) + | AnyNodeRef::ExprLambda(_) + | AnyNodeRef::ExprIf(_) + | AnyNodeRef::ExprDict(_) + | AnyNodeRef::ExprSet(_) + | AnyNodeRef::ExprListComp(_) + | AnyNodeRef::ExprSetComp(_) + | AnyNodeRef::ExprDictComp(_) + | AnyNodeRef::ExprGenerator(_) + | AnyNodeRef::ExprAwait(_) + | AnyNodeRef::ExprYield(_) + | AnyNodeRef::ExprYieldFrom(_) + | AnyNodeRef::ExprCompare(_) + | AnyNodeRef::ExprCall(_) + | AnyNodeRef::FStringExpressionElement(_) + | AnyNodeRef::FStringLiteralElement(_) + | AnyNodeRef::FStringFormatSpec(_) + | AnyNodeRef::ExprFString(_) + | AnyNodeRef::ExprStringLiteral(_) + | AnyNodeRef::ExprBytesLiteral(_) + | AnyNodeRef::ExprNumberLiteral(_) + | AnyNodeRef::ExprBooleanLiteral(_) + | AnyNodeRef::ExprNoneLiteral(_) + | AnyNodeRef::ExprEllipsisLiteral(_) + | AnyNodeRef::ExprAttribute(_) + | AnyNodeRef::ExprSubscript(_) + | AnyNodeRef::ExprStarred(_) + | AnyNodeRef::ExprName(_) + | AnyNodeRef::ExprList(_) + | AnyNodeRef::ExprTuple(_) + | AnyNodeRef::ExprSlice(_) + | AnyNodeRef::ExprIpyEscapeCommand(_) + | AnyNodeRef::ExceptHandlerExceptHandler(_) + | AnyNodeRef::PatternMatchValue(_) + | AnyNodeRef::PatternMatchSingleton(_) + | AnyNodeRef::PatternMatchSequence(_) + | AnyNodeRef::PatternMatchMapping(_) + | AnyNodeRef::PatternMatchClass(_) + | AnyNodeRef::PatternMatchStar(_) + | AnyNodeRef::PatternMatchAs(_) + | AnyNodeRef::PatternMatchOr(_) + | AnyNodeRef::PatternArguments(_) + | AnyNodeRef::PatternKeyword(_) + | AnyNodeRef::Comprehension(_) + | AnyNodeRef::Arguments(_) + | AnyNodeRef::Parameters(_) + | AnyNodeRef::Parameter(_) + | AnyNodeRef::ParameterWithDefault(_) + | AnyNodeRef::Keyword(_) + | AnyNodeRef::Alias(_) + | AnyNodeRef::WithItem(_) + | AnyNodeRef::MatchCase(_) + | AnyNodeRef::Decorator(_) + | AnyNodeRef::TypeParams(_) + | AnyNodeRef::TypeParamTypeVar(_) + | AnyNodeRef::TypeParamTypeVarTuple(_) + | AnyNodeRef::TypeParamParamSpec(_) + | AnyNodeRef::FString(_) + | AnyNodeRef::StringLiteral(_) + | AnyNodeRef::BytesLiteral(_) + | AnyNodeRef::ElifElseClause(_) => None, + } + } + + fn can_cast(kind: NodeKind) -> bool { + match kind { + NodeKind::StmtClassDef + | NodeKind::StmtReturn + | NodeKind::StmtDelete + | NodeKind::StmtTypeAlias + | NodeKind::StmtAssign + | NodeKind::StmtAugAssign + | NodeKind::StmtAnnAssign + | NodeKind::StmtFor + | NodeKind::StmtWhile + | NodeKind::StmtIf + | NodeKind::StmtWith + | NodeKind::StmtMatch + | NodeKind::StmtRaise + | NodeKind::StmtTry + | NodeKind::StmtAssert + | NodeKind::StmtImport + | NodeKind::StmtImportFrom + | NodeKind::StmtGlobal + | NodeKind::StmtNonlocal + | NodeKind::StmtIpyEscapeCommand + | NodeKind::StmtExpr + | NodeKind::StmtPass + | NodeKind::StmtBreak + | NodeKind::StmtContinue => true, + NodeKind::ExprBoolOp + | NodeKind::ModModule + | NodeKind::ModInteractive + | NodeKind::ModExpression + | NodeKind::ModFunctionType + | NodeKind::StmtFunctionDef + | NodeKind::ExprNamed + | NodeKind::ExprBinOp + | NodeKind::ExprUnaryOp + | NodeKind::ExprLambda + | 
NodeKind::ExprIf + | NodeKind::ExprDict + | NodeKind::ExprSet + | NodeKind::ExprListComp + | NodeKind::ExprSetComp + | NodeKind::ExprDictComp + | NodeKind::ExprGenerator + | NodeKind::ExprAwait + | NodeKind::ExprYield + | NodeKind::ExprYieldFrom + | NodeKind::ExprCompare + | NodeKind::ExprCall + | NodeKind::FStringExpressionElement + | NodeKind::FStringLiteralElement + | NodeKind::FStringFormatSpec + | NodeKind::ExprFString + | NodeKind::ExprStringLiteral + | NodeKind::ExprBytesLiteral + | NodeKind::ExprNumberLiteral + | NodeKind::ExprBooleanLiteral + | NodeKind::ExprNoneLiteral + | NodeKind::ExprEllipsisLiteral + | NodeKind::ExprAttribute + | NodeKind::ExprSubscript + | NodeKind::ExprStarred + | NodeKind::ExprName + | NodeKind::ExprList + | NodeKind::ExprTuple + | NodeKind::ExprSlice + | NodeKind::ExprIpyEscapeCommand + | NodeKind::ExceptHandlerExceptHandler + | NodeKind::PatternMatchValue + | NodeKind::PatternMatchSingleton + | NodeKind::PatternMatchSequence + | NodeKind::PatternMatchMapping + | NodeKind::PatternMatchClass + | NodeKind::PatternMatchStar + | NodeKind::PatternMatchAs + | NodeKind::PatternMatchOr + | NodeKind::PatternArguments + | NodeKind::PatternKeyword + | NodeKind::TypeIgnoreTypeIgnore + | NodeKind::Comprehension + | NodeKind::Arguments + | NodeKind::Parameters + | NodeKind::Parameter + | NodeKind::ParameterWithDefault + | NodeKind::Keyword + | NodeKind::Alias + | NodeKind::WithItem + | NodeKind::MatchCase + | NodeKind::Decorator + | NodeKind::ElifElseClause + | NodeKind::TypeParams + | NodeKind::TypeParamTypeVar + | NodeKind::TypeParamTypeVarTuple + | NodeKind::TypeParamParamSpec + | NodeKind::FString + | NodeKind::StringLiteral + | NodeKind::BytesLiteral => false, + } + } + + fn as_any_node_ref(&self) -> AnyNodeRef { + match self { + Stmt::FunctionDef(stmt) => stmt.as_any_node_ref(), + Stmt::ClassDef(stmt) => stmt.as_any_node_ref(), + Stmt::Return(stmt) => stmt.as_any_node_ref(), + Stmt::Delete(stmt) => stmt.as_any_node_ref(), + Stmt::Assign(stmt) => stmt.as_any_node_ref(), + Stmt::AugAssign(stmt) => stmt.as_any_node_ref(), + Stmt::AnnAssign(stmt) => stmt.as_any_node_ref(), + Stmt::TypeAlias(stmt) => stmt.as_any_node_ref(), + Stmt::For(stmt) => stmt.as_any_node_ref(), + Stmt::While(stmt) => stmt.as_any_node_ref(), + Stmt::If(stmt) => stmt.as_any_node_ref(), + Stmt::With(stmt) => stmt.as_any_node_ref(), + Stmt::Match(stmt) => stmt.as_any_node_ref(), + Stmt::Raise(stmt) => stmt.as_any_node_ref(), + Stmt::Try(stmt) => stmt.as_any_node_ref(), + Stmt::Assert(stmt) => stmt.as_any_node_ref(), + Stmt::Import(stmt) => stmt.as_any_node_ref(), + Stmt::ImportFrom(stmt) => stmt.as_any_node_ref(), + Stmt::Global(stmt) => stmt.as_any_node_ref(), + Stmt::Nonlocal(stmt) => stmt.as_any_node_ref(), + Stmt::Expr(stmt) => stmt.as_any_node_ref(), + Stmt::Pass(stmt) => stmt.as_any_node_ref(), + Stmt::Break(stmt) => stmt.as_any_node_ref(), + Stmt::Continue(stmt) => stmt.as_any_node_ref(), + Stmt::IpyEscapeCommand(stmt) => stmt.as_any_node_ref(), + } + } + + fn into_any_node(self) -> AnyNode { + match self { + Stmt::FunctionDef(stmt) => stmt.into_any_node(), + Stmt::ClassDef(stmt) => stmt.into_any_node(), + Stmt::Return(stmt) => stmt.into_any_node(), + Stmt::Delete(stmt) => stmt.into_any_node(), + Stmt::Assign(stmt) => stmt.into_any_node(), + Stmt::AugAssign(stmt) => stmt.into_any_node(), + Stmt::AnnAssign(stmt) => stmt.into_any_node(), + Stmt::TypeAlias(stmt) => stmt.into_any_node(), + Stmt::For(stmt) => stmt.into_any_node(), + Stmt::While(stmt) => stmt.into_any_node(), + Stmt::If(stmt) => 
stmt.into_any_node(), + Stmt::With(stmt) => stmt.into_any_node(), + Stmt::Match(stmt) => stmt.into_any_node(), + Stmt::Raise(stmt) => stmt.into_any_node(), + Stmt::Try(stmt) => stmt.into_any_node(), + Stmt::Assert(stmt) => stmt.into_any_node(), + Stmt::Import(stmt) => stmt.into_any_node(), + Stmt::ImportFrom(stmt) => stmt.into_any_node(), + Stmt::Global(stmt) => stmt.into_any_node(), + Stmt::Nonlocal(stmt) => stmt.into_any_node(), + Stmt::Expr(stmt) => stmt.into_any_node(), + Stmt::Pass(stmt) => stmt.into_any_node(), + Stmt::Break(stmt) => stmt.into_any_node(), + Stmt::Continue(stmt) => stmt.into_any_node(), + Stmt::IpyEscapeCommand(stmt) => stmt.into_any_node(), + } + } + + fn visit_preorder<'a, V>(&'a self, visitor: &mut V) + where + V: PreorderVisitor<'a> + ?Sized, + { + match self { + Stmt::FunctionDef(stmt) => stmt.visit_preorder(visitor), + Stmt::ClassDef(stmt) => stmt.visit_preorder(visitor), + Stmt::Return(stmt) => stmt.visit_preorder(visitor), + Stmt::Delete(stmt) => stmt.visit_preorder(visitor), + Stmt::Assign(stmt) => stmt.visit_preorder(visitor), + Stmt::AugAssign(stmt) => stmt.visit_preorder(visitor), + Stmt::AnnAssign(stmt) => stmt.visit_preorder(visitor), + Stmt::TypeAlias(stmt) => stmt.visit_preorder(visitor), + Stmt::For(stmt) => stmt.visit_preorder(visitor), + Stmt::While(stmt) => stmt.visit_preorder(visitor), + Stmt::If(stmt) => stmt.visit_preorder(visitor), + Stmt::With(stmt) => stmt.visit_preorder(visitor), + Stmt::Match(stmt) => stmt.visit_preorder(visitor), + Stmt::Raise(stmt) => stmt.visit_preorder(visitor), + Stmt::Try(stmt) => stmt.visit_preorder(visitor), + Stmt::Assert(stmt) => stmt.visit_preorder(visitor), + Stmt::Import(stmt) => stmt.visit_preorder(visitor), + Stmt::ImportFrom(stmt) => stmt.visit_preorder(visitor), + Stmt::Global(stmt) => stmt.visit_preorder(visitor), + Stmt::Nonlocal(stmt) => stmt.visit_preorder(visitor), + Stmt::Expr(stmt) => stmt.visit_preorder(visitor), + Stmt::Pass(stmt) => stmt.visit_preorder(visitor), + Stmt::Break(stmt) => stmt.visit_preorder(visitor), + Stmt::Continue(stmt) => stmt.visit_preorder(visitor), + Stmt::IpyEscapeCommand(stmt) => stmt.visit_preorder(visitor), + } + } +} + +impl AstNode for TypeParam { + type Ref<'a> = TypeParamRef<'a>; + + fn cast(kind: AnyNode) -> Option + where + Self: Sized, + { + match kind { + AnyNode::TypeParamTypeVar(node) => Some(TypeParam::TypeVar(node)), + AnyNode::TypeParamTypeVarTuple(node) => Some(TypeParam::TypeVarTuple(node)), + AnyNode::TypeParamParamSpec(node) => Some(TypeParam::ParamSpec(node)), + _ => None, + } + } + + fn cast_ref(kind: AnyNodeRef) -> Option> { + match kind { + AnyNodeRef::TypeParamTypeVar(node) => Some(TypeParamRef::TypeVar(node)), + AnyNodeRef::TypeParamTypeVarTuple(node) => Some(TypeParamRef::TypeVarTuple(node)), + AnyNodeRef::TypeParamParamSpec(node) => Some(TypeParamRef::ParamSpec(node)), + _ => None, + } + } + + fn can_cast(kind: NodeKind) -> bool { + matches!( + kind, + NodeKind::TypeParamTypeVar + | NodeKind::TypeParamTypeVarTuple + | NodeKind::TypeParamParamSpec + ) + } + + fn as_any_node_ref(&self) -> AnyNodeRef { + match self { + TypeParam::TypeVar(node) => node.as_any_node_ref(), + TypeParam::TypeVarTuple(node) => node.as_any_node_ref(), + TypeParam::ParamSpec(node) => node.as_any_node_ref(), + } + } + + fn into_any_node(self) -> AnyNode { + match self { + TypeParam::TypeVar(node) => node.into_any_node(), + TypeParam::TypeVarTuple(node) => node.into_any_node(), + TypeParam::ParamSpec(node) => node.into_any_node(), + } + } + + fn visit_preorder<'a, V>(&'a self, 
visitor: &mut V)
+    where
+        V: PreorderVisitor<'a> + ?Sized,
+    {
+        match self {
+            TypeParam::TypeVar(node) => node.visit_preorder(visitor),
+            TypeParam::TypeVarTuple(node) => node.visit_preorder(visitor),
+            TypeParam::ParamSpec(node) => node.visit_preorder(visitor),
+        }
+    }
+}
+
 impl From<Stmt> for AnyNode {
     fn from(stmt: Stmt) -> Self {
         match stmt {
@@ -7192,3 +8018,232 @@ pub enum NodeKind {
     StringLiteral,
     BytesLiteral,
 }
+
+// FIXME: The `StatementRef` here allows us to implement `AstNode` for `Stmt`, which otherwise wouldn't be possible
+// because of the `cast_ref` method that needs to return a `&Stmt` for a specific statement node.
+// Implementing `AstNode` for `Stmt` is desired so that `AstId.upcast` works, where the Id then represents
+// any `Stmt` instead of a specific statement.
+// The existing solution "works" in the sense that `upcast` etc. can be implemented. However, `StatementRef`
+// doesn't implement `AstNode` itself and thus can't be used as `AstNodeKey` or passed to query the `ast_id`
+// (because that requires that the node implements `HasAstId`, which extends `AstNode`).
+// I don't know what a solution to this would look like, but this isn't the first time this problem has come up.
+#[derive(Clone, Copy, Debug, PartialEq)]
+pub enum StatementRef<'a> {
+    FunctionDef(&'a StmtFunctionDef),
+    ClassDef(&'a StmtClassDef),
+    Return(&'a StmtReturn),
+    Delete(&'a StmtDelete),
+    Assign(&'a StmtAssign),
+    AugAssign(&'a StmtAugAssign),
+    AnnAssign(&'a StmtAnnAssign),
+    TypeAlias(&'a StmtTypeAlias),
+    For(&'a StmtFor),
+    While(&'a StmtWhile),
+    If(&'a StmtIf),
+    With(&'a StmtWith),
+    Match(&'a StmtMatch),
+    Raise(&'a StmtRaise),
+    Try(&'a StmtTry),
+    Assert(&'a StmtAssert),
+    Import(&'a StmtImport),
+    ImportFrom(&'a StmtImportFrom),
+    Global(&'a StmtGlobal),
+    Nonlocal(&'a StmtNonlocal),
+    Expr(&'a StmtExpr),
+    Pass(&'a StmtPass),
+    Break(&'a StmtBreak),
+    Continue(&'a StmtContinue),
+    IpyEscapeCommand(&'a StmtIpyEscapeCommand),
+}
+
+impl<'a> From<&'a StmtFunctionDef> for StatementRef<'a> {
+    fn from(value: &'a StmtFunctionDef) -> Self {
+        Self::FunctionDef(value)
+    }
+}
+impl<'a> From<&'a StmtClassDef> for StatementRef<'a> {
+    fn from(value: &'a StmtClassDef) -> Self {
+        Self::ClassDef(value)
+    }
+}
+impl<'a> From<&'a StmtReturn> for StatementRef<'a> {
+    fn from(value: &'a StmtReturn) -> Self {
+        Self::Return(value)
+    }
+}
+impl<'a> From<&'a StmtDelete> for StatementRef<'a> {
+    fn from(value: &'a StmtDelete) -> Self {
+        Self::Delete(value)
+    }
+}
+impl<'a> From<&'a StmtAssign> for StatementRef<'a> {
+    fn from(value: &'a StmtAssign) -> Self {
+        Self::Assign(value)
+    }
+}
+impl<'a> From<&'a StmtAugAssign> for StatementRef<'a> {
+    fn from(value: &'a StmtAugAssign) -> Self {
+        Self::AugAssign(value)
+    }
+}
+impl<'a> From<&'a StmtAnnAssign> for StatementRef<'a> {
+    fn from(value: &'a StmtAnnAssign) -> Self {
+        Self::AnnAssign(value)
+    }
+}
+impl<'a> From<&'a StmtTypeAlias> for StatementRef<'a> {
+    fn from(value: &'a StmtTypeAlias) -> Self {
+        Self::TypeAlias(value)
+    }
+}
+impl<'a> From<&'a StmtFor> for StatementRef<'a> {
+    fn from(value: &'a StmtFor) -> Self {
+        Self::For(value)
+    }
+}
+impl<'a> From<&'a StmtWhile> for StatementRef<'a> {
+    fn from(value: &'a StmtWhile) -> Self {
+        Self::While(value)
+    }
+}
+impl<'a> From<&'a StmtIf> for StatementRef<'a> {
+    fn from(value: &'a StmtIf) -> Self {
+        Self::If(value)
+    }
+}
+impl<'a> From<&'a StmtWith> for StatementRef<'a> {
+    fn from(value: &'a StmtWith) -> Self {
+        Self::With(value)
+    }
+}
+impl<'a> From<&'a StmtMatch> for StatementRef<'a> {
+
fn from(value: &'a StmtMatch) -> Self { + Self::Match(value) + } +} +impl<'a> From<&'a StmtRaise> for StatementRef<'a> { + fn from(value: &'a StmtRaise) -> Self { + Self::Raise(value) + } +} +impl<'a> From<&'a StmtTry> for StatementRef<'a> { + fn from(value: &'a StmtTry) -> Self { + Self::Try(value) + } +} +impl<'a> From<&'a StmtAssert> for StatementRef<'a> { + fn from(value: &'a StmtAssert) -> Self { + Self::Assert(value) + } +} +impl<'a> From<&'a StmtImport> for StatementRef<'a> { + fn from(value: &'a StmtImport) -> Self { + Self::Import(value) + } +} +impl<'a> From<&'a StmtImportFrom> for StatementRef<'a> { + fn from(value: &'a StmtImportFrom) -> Self { + Self::ImportFrom(value) + } +} +impl<'a> From<&'a StmtGlobal> for StatementRef<'a> { + fn from(value: &'a StmtGlobal) -> Self { + Self::Global(value) + } +} +impl<'a> From<&'a StmtNonlocal> for StatementRef<'a> { + fn from(value: &'a StmtNonlocal) -> Self { + Self::Nonlocal(value) + } +} +impl<'a> From<&'a StmtExpr> for StatementRef<'a> { + fn from(value: &'a StmtExpr) -> Self { + Self::Expr(value) + } +} +impl<'a> From<&'a StmtPass> for StatementRef<'a> { + fn from(value: &'a StmtPass) -> Self { + Self::Pass(value) + } +} +impl<'a> From<&'a StmtBreak> for StatementRef<'a> { + fn from(value: &'a StmtBreak) -> Self { + Self::Break(value) + } +} +impl<'a> From<&'a StmtContinue> for StatementRef<'a> { + fn from(value: &'a StmtContinue) -> Self { + Self::Continue(value) + } +} +impl<'a> From<&'a StmtIpyEscapeCommand> for StatementRef<'a> { + fn from(value: &'a StmtIpyEscapeCommand) -> Self { + Self::IpyEscapeCommand(value) + } +} + +impl<'a> From<&'a Stmt> for StatementRef<'a> { + fn from(value: &'a Stmt) -> Self { + match value { + Stmt::FunctionDef(statement) => Self::FunctionDef(statement), + Stmt::ClassDef(statement) => Self::ClassDef(statement), + Stmt::Return(statement) => Self::Return(statement), + Stmt::Delete(statement) => Self::Delete(statement), + Stmt::Assign(statement) => Self::Assign(statement), + Stmt::AugAssign(statement) => Self::AugAssign(statement), + Stmt::AnnAssign(statement) => Self::AnnAssign(statement), + Stmt::TypeAlias(statement) => Self::TypeAlias(statement), + Stmt::For(statement) => Self::For(statement), + Stmt::While(statement) => Self::While(statement), + Stmt::If(statement) => Self::If(statement), + Stmt::With(statement) => Self::With(statement), + Stmt::Match(statement) => Self::Match(statement), + Stmt::Raise(statement) => Self::Raise(statement), + Stmt::Try(statement) => Self::Try(statement), + Stmt::Assert(statement) => Self::Assert(statement), + Stmt::Import(statement) => Self::Import(statement), + Stmt::ImportFrom(statement) => Self::ImportFrom(statement), + Stmt::Global(statement) => Self::Global(statement), + Stmt::Nonlocal(statement) => Self::Nonlocal(statement), + Stmt::Expr(statement) => Self::Expr(statement), + Stmt::Pass(statement) => Self::Pass(statement), + Stmt::Break(statement) => Self::Break(statement), + Stmt::Continue(statement) => Self::Continue(statement), + Stmt::IpyEscapeCommand(statement) => Self::IpyEscapeCommand(statement), + } + } +} + +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum TypeParamRef<'a> { + TypeVar(&'a TypeParamTypeVar), + ParamSpec(&'a TypeParamParamSpec), + TypeVarTuple(&'a TypeParamTypeVarTuple), +} + +impl<'a> From<&'a TypeParamTypeVar> for TypeParamRef<'a> { + fn from(value: &'a TypeParamTypeVar) -> Self { + Self::TypeVar(value) + } +} + +impl<'a> From<&'a TypeParamParamSpec> for TypeParamRef<'a> { + fn from(value: &'a TypeParamParamSpec) -> Self { + 
Self::ParamSpec(value)
+    }
+}
+
+impl<'a> From<&'a TypeParamTypeVarTuple> for TypeParamRef<'a> {
+    fn from(value: &'a TypeParamTypeVarTuple) -> Self {
+        Self::TypeVarTuple(value)
+    }
+}
+
+impl<'a> From<&'a TypeParam> for TypeParamRef<'a> {
+    fn from(value: &'a TypeParam) -> Self {
+        match value {
+            TypeParam::TypeVar(value) => Self::TypeVar(value),
+            TypeParam::ParamSpec(value) => Self::ParamSpec(value),
+            TypeParam::TypeVarTuple(value) => Self::TypeVarTuple(value),
+        }
+    }
+}
diff --git a/crates/ruff_python_ast/src/nodes.rs b/crates/ruff_python_ast/src/nodes.rs
index cf2eac2ad8..ff54724a25 100644
--- a/crates/ruff_python_ast/src/nodes.rs
+++ b/crates/ruff_python_ast/src/nodes.rs
@@ -1,11 +1,10 @@
 #![allow(clippy::derive_partial_eq_without_eq)]
 
-use std::cell::OnceCell;
-
 use std::fmt;
 use std::fmt::Debug;
 use std::ops::Deref;
 use std::slice::{Iter, IterMut};
+use std::sync::OnceLock;
+
 use bitflags::bitflags;
 use itertools::Itertools;
@@ -1420,7 +1419,7 @@ impl StringLiteralValue {
         Self {
             inner: StringLiteralValueInner::Concatenated(ConcatenatedStringLiteral {
                 strings,
-                value: OnceCell::new(),
+                value: OnceLock::new(),
             }),
         }
     }
@@ -1782,7 +1781,7 @@ struct ConcatenatedStringLiteral {
     /// Each string literal that makes up the concatenated string.
     strings: Vec<StringLiteral>,
     /// The concatenated string value.
-    value: OnceCell<Box<str>>,
+    value: OnceLock<Box<str>>,
 }
 
 impl ConcatenatedStringLiteral {
@@ -4168,7 +4167,7 @@ mod tests {
         assert_eq!(std::mem::size_of::(), 40);
         assert_eq!(std::mem::size_of::(), 32);
         assert_eq!(std::mem::size_of::(), 24);
-        assert_eq!(std::mem::size_of::(), 48);
+        assert_eq!(std::mem::size_of::(), 56);
         assert_eq!(std::mem::size_of::(), 32);
         assert_eq!(std::mem::size_of::(), 40);
         assert_eq!(std::mem::size_of::(), 24);
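
Editor's note, not part of the patch: a minimal sketch of how the `AstNode` implementation for `Stmt` added above might be exercised. It assumes `AnyNodeRef::kind()` returning a `NodeKind`, as used elsewhere in `node.rs`; the function name `as_statement` is hypothetical.

// Illustrative only: go from an untyped `AnyNodeRef` to a borrowed `StatementRef`.
fn as_statement<'a>(node: AnyNodeRef<'a>) -> Option<StatementRef<'a>> {
    // `can_cast` answers "is this kind a statement?" from the `NodeKind` alone;
    // `cast_ref` then performs the actual downcast to the borrowed statement enum.
    if Stmt::can_cast(node.kind()) {
        Stmt::cast_ref(node)
    } else {
        None
    }
}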