From bc03d376e802ac68930f3f1b3f9ddd1cdd27ccf3 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Tue, 30 Apr 2024 09:13:26 +0200 Subject: [PATCH] [red-knot] Add "cheap" `program.snapshot` (#11172) --- Cargo.lock | 26 +- Cargo.toml | 2 +- crates/red_knot/Cargo.toml | 2 +- crates/red_knot/src/cache.rs | 11 +- crates/red_knot/src/cancellation.rs | 42 +--- crates/red_knot/src/db.rs | 197 +++++++++++---- crates/red_knot/src/db/jars.rs | 37 +++ crates/red_knot/src/db/query.rs | 20 ++ crates/red_knot/src/db/runtime.rs | 41 +++ crates/red_knot/src/db/storage.rs | 117 +++++++++ crates/red_knot/src/lib.rs | 2 +- crates/red_knot/src/lint.rs | 50 ++-- crates/red_knot/src/main.rs | 147 +++++------ crates/red_knot/src/module.rs | 234 ++++++++++-------- crates/red_knot/src/parse.rs | 10 +- crates/red_knot/src/program/check.rs | 145 +++++------ crates/red_knot/src/program/mod.rs | 122 +++++---- crates/red_knot/src/source.rs | 8 +- crates/red_knot/src/symbols.rs | 10 +- crates/red_knot/src/types/infer.rs | 108 ++++---- crates/ruff_server/Cargo.toml | 2 +- crates/ruff_server/src/server/client.rs | 2 +- crates/ruff_server/src/server/schedule.rs | 2 +- .../src/server/schedule/thread/pool.rs | 4 +- 24 files changed, 833 insertions(+), 508 deletions(-) create mode 100644 crates/red_knot/src/db/jars.rs create mode 100644 crates/red_knot/src/db/query.rs create mode 100644 crates/red_knot/src/db/runtime.rs create mode 100644 crates/red_knot/src/db/storage.rs diff --git a/Cargo.lock b/Cargo.lock index f9181f7c5d..421ae66f65 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -501,6 +501,19 @@ dependencies = [ "itertools 0.10.5", ] +[[package]] +name = "crossbeam" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1137cd7e7fc0fb5d3c5a8678be38ec56e819125d8d7907411fe24ccb943faca8" +dependencies = [ + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-epoch", + "crossbeam-queue", + "crossbeam-utils", +] + [[package]] name = "crossbeam-channel" version = "0.5.12" @@ -529,6 +542,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "crossbeam-queue" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df0346b5d5e76ac2fe4e327c5fd1118d6be7c51dfb18f9b7922923f287471e35" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.19" @@ -1804,7 +1826,7 @@ version = "0.1.0" dependencies = [ "anyhow", "bitflags 2.5.0", - "crossbeam-channel", + "crossbeam", "ctrlc", "dashmap", "hashbrown 0.14.5", @@ -2341,7 +2363,7 @@ name = "ruff_server" version = "0.2.2" dependencies = [ "anyhow", - "crossbeam-channel", + "crossbeam", "insta", "jod-thread", "libc", diff --git a/Cargo.toml b/Cargo.toml index 95681c8b13..77b49be430 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,7 @@ console_error_panic_hook = { version = "0.1.7" } console_log = { version = "1.0.0" } countme = { version = "3.0.1" } criterion = { version = "0.5.1", default-features = false } -crossbeam-channel = { version = "0.5.12" } +crossbeam = { version = "0.8.4" } dashmap = { version = "5.5.3" } dirs = { version = "5.0.0" } drop_bomb = { version = "0.1.5" } diff --git a/crates/red_knot/Cargo.toml b/crates/red_knot/Cargo.toml index 7907c8340a..382d7d3062 100644 --- a/crates/red_knot/Cargo.toml +++ b/crates/red_knot/Cargo.toml @@ -22,7 +22,7 @@ ruff_notebook = { path = "../ruff_notebook" } anyhow = { workspace = true } bitflags = { workspace = true } ctrlc = "3.4.4" -crossbeam-channel = { workspace = true } +crossbeam = { workspace = true 
} dashmap = { workspace = true } hashbrown = { workspace = true } indexmap = { workspace = true } diff --git a/crates/red_knot/src/cache.rs b/crates/red_knot/src/cache.rs index 8f63e8acfb..719a1449ed 100644 --- a/crates/red_knot/src/cache.rs +++ b/crates/red_knot/src/cache.rs @@ -2,6 +2,7 @@ use std::fmt::Formatter; use std::hash::Hash; use std::sync::atomic::{AtomicUsize, Ordering}; +use crate::db::QueryResult; use dashmap::mapref::entry::Entry; use crate::FxDashMap; @@ -27,11 +28,11 @@ where } } - pub fn get(&self, key: &K, compute: F) -> V + pub fn get(&self, key: &K, compute: F) -> QueryResult where - F: FnOnce(&K) -> V, + F: FnOnce(&K) -> QueryResult, { - match self.map.entry(key.clone()) { + Ok(match self.map.entry(key.clone()) { Entry::Occupied(cached) => { self.statistics.hit(); @@ -40,11 +41,11 @@ where Entry::Vacant(vacant) => { self.statistics.miss(); - let value = compute(key); + let value = compute(key)?; vacant.insert(value.clone()); value } - } + }) } pub fn set(&mut self, key: K, value: V) { diff --git a/crates/red_knot/src/cancellation.rs b/crates/red_knot/src/cancellation.rs index 0620d86ab5..6f91bc8e2b 100644 --- a/crates/red_knot/src/cancellation.rs +++ b/crates/red_knot/src/cancellation.rs @@ -1,35 +1,25 @@ -use std::sync::{Arc, Condvar, Mutex}; +use std::sync::atomic::AtomicBool; +use std::sync::Arc; -#[derive(Debug, Default)] +#[derive(Debug, Clone, Default)] pub struct CancellationTokenSource { - signal: Arc<(Mutex, Condvar)>, + signal: Arc, } impl CancellationTokenSource { pub fn new() -> Self { Self { - signal: Arc::new((Mutex::new(false), Condvar::default())), + signal: Arc::new(AtomicBool::new(false)), } } #[tracing::instrument(level = "trace", skip_all)] pub fn cancel(&self) { - let (cancelled, condvar) = &*self.signal; - - let mut cancelled = cancelled.lock().unwrap(); - - if *cancelled { - return; - } - - *cancelled = true; - condvar.notify_all(); + self.signal.store(true, std::sync::atomic::Ordering::SeqCst); } pub fn is_cancelled(&self) -> bool { - let (cancelled, _) = &*self.signal; - - *cancelled.lock().unwrap() + self.signal.load(std::sync::atomic::Ordering::SeqCst) } pub fn token(&self) -> CancellationToken { @@ -41,26 +31,12 @@ impl CancellationTokenSource { #[derive(Clone, Debug)] pub struct CancellationToken { - signal: Arc<(Mutex, Condvar)>, + signal: Arc, } impl CancellationToken { /// Returns `true` if cancellation has been requested. pub fn is_cancelled(&self) -> bool { - let (cancelled, _) = &*self.signal; - - *cancelled.lock().unwrap() - } - - pub fn wait(&self) { - let (bool, condvar) = &*self.signal; - - let lock = condvar - .wait_while(bool.lock().unwrap(), |bool| !*bool) - .unwrap(); - - debug_assert!(*lock); - - drop(lock); + self.signal.load(std::sync::atomic::Ordering::SeqCst) } } diff --git a/crates/red_knot/src/db.rs b/crates/red_knot/src/db.rs index 21685a5734..1ccf6f979e 100644 --- a/crates/red_knot/src/db.rs +++ b/crates/red_knot/src/db.rs @@ -1,3 +1,8 @@ +mod jars; +mod query; +mod runtime; +mod storage; + use std::path::Path; use std::sync::Arc; @@ -9,32 +14,115 @@ use crate::source::{Source, SourceStorage}; use crate::symbols::{SymbolId, SymbolTable, SymbolTablesStorage}; use crate::types::{Type, TypeStore}; -pub trait SourceDb { +pub use jars::{HasJar, HasJars}; +pub use query::{QueryError, QueryResult}; +pub use runtime::DbRuntime; +pub use storage::JarsStorage; + +pub trait Database { + /// Returns a reference to the runtime of the current worker. 
+    fn runtime(&self) -> &DbRuntime;
+
+    /// Returns a mutable reference to the runtime. Only one worker can hold a mutable reference to the runtime at a time.
+    fn runtime_mut(&mut self) -> &mut DbRuntime;
+
+    /// Returns `Ok` if the queries have not been cancelled and `Err(QueryError::Cancelled)` otherwise.
+    fn cancelled(&self) -> QueryResult<()> {
+        self.runtime().cancelled()
+    }
+
+    /// Returns `true` if the queries have been cancelled.
+    fn is_cancelled(&self) -> bool {
+        self.runtime().is_cancelled()
+    }
+}
+
+/// Database that supports running queries from multiple threads.
+pub trait ParallelDatabase: Database + Send {
+    /// Creates a snapshot of the database state that can be used to query the database in another thread.
+    ///
+    /// The snapshot is a read-only view of the database, but query results are shared between threads.
+    /// All queries are automatically cancelled when applying any mutations (calling [`HasJars::jars_mut`])
+    /// to the database (not to the snapshot, because snapshots are read-only).
+    ///
+    /// ## Creating a snapshot
+    ///
+    /// Creating a snapshot of the database's jars is cheap, but creating a snapshot of
+    /// other state stored on the database might require deep-cloning data. That's why you should
+    /// avoid creating snapshots in a hot function (e.g. don't create a snapshot for each file; instead,
+    /// create a snapshot when scheduling the check of an entire program).
+    ///
+    /// ## Salsa compatibility
+    /// Salsa prohibits creating a snapshot while running a local query (it's fine if other workers run a query) [[source](https://github.com/salsa-rs/salsa/issues/80)].
+    /// We should avoid creating snapshots while running a query because we might want to adopt Salsa in the future (if we can figure out persistent caching).
+    /// Unfortunately, the infrastructure doesn't provide an automated way of knowing when a query is run, which is
+    /// why we have to "enforce" this constraint manually.
+    fn snapshot(&self) -> Snapshot<Self>;
+}
+
+/// Read-only snapshot of a database.
+///
+/// ## Deadlocks
+/// A snapshot should always be dropped as soon as it is no longer necessary to run queries.
+/// Storing the snapshot without running a query or periodically checking if cancellation was requested
+/// can lead to deadlocks because mutating the [`Database`] requires cancelling all pending queries
+/// and waiting for all [`Snapshot`]s to be dropped.
+#[derive(Debug)]
+pub struct Snapshot<DB: ?Sized>
+where
+    DB: ParallelDatabase,
+{
+    db: DB,
+}
+
+impl<DB> Snapshot<DB>
+where
+    DB: ParallelDatabase,
+{
+    pub fn new(db: DB) -> Self {
+        Snapshot { db }
+    }
+}
+
+impl<DB: ?Sized> std::ops::Deref for Snapshot<DB>
+where
+    DB: ParallelDatabase,
+{
+    type Target = DB;
+
+    fn deref(&self) -> &DB {
+        &self.db
+    }
+}
+
+// Red knot specific databases code.
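For context, this is the usage pattern the trait is designed for. A minimal sketch assuming the `Program` database defined later in this patch; the `lint_in_background` function itself is invented for illustration:

```rust
use red_knot::db::{ParallelDatabase, QueryError, SourceDb};
use red_knot::files::FileId;
use red_knot::program::Program;

// Illustrative only: run a query on a worker thread while the main thread
// stays free to mutate the database.
fn lint_in_background(program: &Program, file_id: FileId) {
    // Cheap: bumps the jars' ref-count and clones the cancellation token.
    let snapshot = program.snapshot();

    std::thread::spawn(move || {
        match snapshot.lint_syntax(file_id) {
            Ok(_diagnostics) => { /* report the diagnostics */ }
            // The main thread called `jars_mut` (e.g. after a file change);
            // dropping the snapshot lets the writer make progress.
            Err(QueryError::Cancelled) => {}
        }
    });
}
```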
+
+pub trait SourceDb: Database {
     // queries
     fn file_id(&self, path: &std::path::Path) -> FileId;
 
     fn file_path(&self, file_id: FileId) -> Arc<Path>;
 
-    fn source(&self, file_id: FileId) -> Source;
+    fn source(&self, file_id: FileId) -> QueryResult<Source>;
 
-    fn parse(&self, file_id: FileId) -> Parsed;
+    fn parse(&self, file_id: FileId) -> QueryResult<Parsed>;
 
-    fn lint_syntax(&self, file_id: FileId) -> Diagnostics;
+    fn lint_syntax(&self, file_id: FileId) -> QueryResult<Diagnostics>;
 }
 
 pub trait SemanticDb: SourceDb {
     // queries
-    fn resolve_module(&self, name: ModuleName) -> Option<Module>;
+    fn resolve_module(&self, name: ModuleName) -> QueryResult<Option<Module>>;
 
-    fn file_to_module(&self, file_id: FileId) -> Option<Module>;
+    fn file_to_module(&self, file_id: FileId) -> QueryResult<Option<Module>>;
 
-    fn path_to_module(&self, path: &Path) -> Option<Module>;
+    fn path_to_module(&self, path: &Path) -> QueryResult<Option<Module>>;
 
-    fn symbol_table(&self, file_id: FileId) -> Arc<SymbolTable>;
+    fn symbol_table(&self, file_id: FileId) -> QueryResult<Arc<SymbolTable>>;
 
-    fn infer_symbol_type(&self, file_id: FileId, symbol_id: SymbolId) -> Type;
+    fn infer_symbol_type(&self, file_id: FileId, symbol_id: SymbolId) -> QueryResult<Type>;
 
-    fn lint_semantic(&self, file_id: FileId) -> Diagnostics;
+    fn lint_semantic(&self, file_id: FileId) -> QueryResult<Diagnostics>;
 
     // mutations
 
@@ -60,32 +148,15 @@ pub struct SemanticJar {
     pub lint_semantic: LintSemanticStorage,
 }
 
-/// Gives access to a specific jar in the database.
-///
-/// Nope, the terminology isn't borrowed from Java but from Salsa,
-/// which is an analogy to storing the salsa in different jars.
-///
-/// The basic idea is that each crate can define its own jar and the jars can be combined to a single
-/// database in the top level crate. Each crate also defines its own `Database` trait. The combination of
-/// `Database` trait and the jar allows to write queries in isolation without having to know how they get composed at the upper levels.
-///
-/// Salsa further defines a `HasIngredient` trait which slices the jar to a specific storage (e.g. a specific cache).
-/// We don't need this just jet because we write our queries by hand. We may want a similar trait if we decide
-/// to use a macro to generate the queries.
-pub trait HasJar<T> {
-    /// Gives a read-only reference to the jar.
-    fn jar(&self) -> &T;
-
-    /// Gives a mutable reference to the jar.
- fn jar_mut(&mut self) -> &mut T; -} - #[cfg(test)] pub(crate) mod tests { use std::path::Path; use std::sync::Arc; - use crate::db::{HasJar, SourceDb, SourceJar}; + use crate::db::{ + Database, DbRuntime, HasJar, HasJars, JarsStorage, ParallelDatabase, QueryResult, Snapshot, + SourceDb, SourceJar, + }; use crate::files::{FileId, Files}; use crate::lint::{lint_semantic, lint_syntax, Diagnostics}; use crate::module::{ @@ -104,27 +175,26 @@ pub(crate) mod tests { #[derive(Debug, Default)] pub(crate) struct TestDb { files: Files, - source: SourceJar, - semantic: SemanticJar, + jars: JarsStorage, } impl HasJar for TestDb { - fn jar(&self) -> &SourceJar { - &self.source + fn jar(&self) -> QueryResult<&SourceJar> { + Ok(&self.jars()?.0) } fn jar_mut(&mut self) -> &mut SourceJar { - &mut self.source + &mut self.jars_mut().0 } } impl HasJar for TestDb { - fn jar(&self) -> &SemanticJar { - &self.semantic + fn jar(&self) -> QueryResult<&SemanticJar> { + Ok(&self.jars()?.1) } fn jar_mut(&mut self) -> &mut SemanticJar { - &mut self.semantic + &mut self.jars_mut().1 } } @@ -137,41 +207,41 @@ pub(crate) mod tests { self.files.path(file_id) } - fn source(&self, file_id: FileId) -> Source { + fn source(&self, file_id: FileId) -> QueryResult { source_text(self, file_id) } - fn parse(&self, file_id: FileId) -> Parsed { + fn parse(&self, file_id: FileId) -> QueryResult { parse(self, file_id) } - fn lint_syntax(&self, file_id: FileId) -> Diagnostics { + fn lint_syntax(&self, file_id: FileId) -> QueryResult { lint_syntax(self, file_id) } } impl SemanticDb for TestDb { - fn resolve_module(&self, name: ModuleName) -> Option { + fn resolve_module(&self, name: ModuleName) -> QueryResult> { resolve_module(self, name) } - fn file_to_module(&self, file_id: FileId) -> Option { + fn file_to_module(&self, file_id: FileId) -> QueryResult> { file_to_module(self, file_id) } - fn path_to_module(&self, path: &Path) -> Option { + fn path_to_module(&self, path: &Path) -> QueryResult> { path_to_module(self, path) } - fn symbol_table(&self, file_id: FileId) -> Arc { + fn symbol_table(&self, file_id: FileId) -> QueryResult> { symbol_table(self, file_id) } - fn infer_symbol_type(&self, file_id: FileId, symbol_id: SymbolId) -> Type { + fn infer_symbol_type(&self, file_id: FileId, symbol_id: SymbolId) -> QueryResult { infer_symbol_type(self, file_id, symbol_id) } - fn lint_semantic(&self, file_id: FileId) -> Diagnostics { + fn lint_semantic(&self, file_id: FileId) -> QueryResult { lint_semantic(self, file_id) } @@ -183,4 +253,35 @@ pub(crate) mod tests { set_module_search_paths(self, paths); } } + + impl HasJars for TestDb { + type Jars = (SourceJar, SemanticJar); + + fn jars(&self) -> QueryResult<&Self::Jars> { + self.jars.jars() + } + + fn jars_mut(&mut self) -> &mut Self::Jars { + self.jars.jars_mut() + } + } + + impl Database for TestDb { + fn runtime(&self) -> &DbRuntime { + self.jars.runtime() + } + + fn runtime_mut(&mut self) -> &mut DbRuntime { + self.jars.runtime_mut() + } + } + + impl ParallelDatabase for TestDb { + fn snapshot(&self) -> Snapshot { + Snapshot::new(Self { + files: self.files.clone(), + jars: self.jars.snapshot(), + }) + } + } } diff --git a/crates/red_knot/src/db/jars.rs b/crates/red_knot/src/db/jars.rs new file mode 100644 index 0000000000..f67b7cd651 --- /dev/null +++ b/crates/red_knot/src/db/jars.rs @@ -0,0 +1,37 @@ +use crate::db::query::QueryResult; + +/// Gives access to a specific jar in the database. 
+///
+/// Nope, the terminology isn't borrowed from Java but from Salsa,
+/// which is an analogy to storing the salsa in different jars.
+///
+/// The basic idea is that each crate can define its own jar and the jars can be combined into a single
+/// database in the top level crate. Each crate also defines its own `Database` trait. The combination of
+/// the `Database` trait and the jar makes it possible to write queries in isolation without having to know how they get composed at the upper levels.
+///
+/// Salsa further defines a `HasIngredient` trait which slices the jar to a specific storage (e.g. a specific cache).
+/// We don't need this just yet because we write our queries by hand. We may want a similar trait if we decide
+/// to use a macro to generate the queries.
+pub trait HasJar<T> {
+    /// Gives a read-only reference to the jar.
+    fn jar(&self) -> QueryResult<&T>;
+
+    /// Gives a mutable reference to the jar.
+    fn jar_mut(&mut self) -> &mut T;
+}
+
+/// Gives access to the jars in a database.
+pub trait HasJars {
+    /// A type storing the jars.
+    ///
+    /// Most commonly, this is a tuple where each jar is a tuple element.
+    type Jars: Default;
+
+    /// Gives access to the underlying jars, but first checks whether the queries have been cancelled.
+    ///
+    /// Returns `Err(QueryError::Cancelled)` if the queries have been cancelled.
+    fn jars(&self) -> QueryResult<&Self::Jars>;
+
+    /// Gives mutable access to the underlying jars.
+    fn jars_mut(&mut self) -> &mut Self::Jars;
+}
diff --git a/crates/red_knot/src/db/query.rs b/crates/red_knot/src/db/query.rs
new file mode 100644
index 0000000000..d020decd6e
--- /dev/null
+++ b/crates/red_knot/src/db/query.rs
@@ -0,0 +1,20 @@
+use std::fmt::{Display, Formatter};
+
+/// Reason why a db query operation failed.
+#[derive(Debug, Clone, Copy)]
+pub enum QueryError {
+    /// The query was cancelled, either because the DB was mutated or because the host requested cancellation (e.g. on a file change or when pressing CTRL+C).
+    Cancelled,
+}
+
+impl Display for QueryError {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        match self {
+            QueryError::Cancelled => f.write_str("query was cancelled"),
+        }
+    }
+}
+
+impl std::error::Error for QueryError {}
+
+pub type QueryResult<T> = Result<T, QueryError>;
diff --git a/crates/red_knot/src/db/runtime.rs b/crates/red_knot/src/db/runtime.rs
new file mode 100644
index 0000000000..c8530eb168
--- /dev/null
+++ b/crates/red_knot/src/db/runtime.rs
@@ -0,0 +1,41 @@
+use crate::cancellation::CancellationTokenSource;
+use crate::db::{QueryError, QueryResult};
+
+/// Holds the jar-agnostic state of the database.
+#[derive(Debug, Default)]
+pub struct DbRuntime {
+    /// The cancellation token source used to signal other workers that the queries should be aborted and
+    /// exit at the next possible point.
+    cancellation_token: CancellationTokenSource,
+}
+
+impl DbRuntime {
+    pub(super) fn snapshot(&self) -> Self {
+        Self {
+            cancellation_token: self.cancellation_token.clone(),
+        }
+    }
+
+    /// Cancels the pending queries of other workers. The current worker cannot have any pending
+    /// queries because we're holding a mutable reference to the runtime.
+    pub(super) fn cancel_other_workers(&mut self) {
+        self.cancellation_token.cancel();
+        // Set a new cancellation token so that we're in a non-cancelled state again when running the next
+        // query.
+        self.cancellation_token = CancellationTokenSource::default();
+    }
+
+    /// Returns `Ok` if the queries have not been cancelled and `Err(QueryError::Cancelled)` otherwise.
+    pub(super) fn cancelled(&self) -> QueryResult<()> {
+        if self.cancellation_token.is_cancelled() {
+            Err(QueryError::Cancelled)
+        } else {
+            Ok(())
+        }
+    }
+
+    /// Returns `true` if the queries have been cancelled.
+    pub(super) fn is_cancelled(&self) -> bool {
+        self.cancellation_token.is_cancelled()
+    }
+}
diff --git a/crates/red_knot/src/db/storage.rs b/crates/red_knot/src/db/storage.rs
new file mode 100644
index 0000000000..afb57e3230
--- /dev/null
+++ b/crates/red_knot/src/db/storage.rs
@@ -0,0 +1,117 @@
+use std::fmt::Formatter;
+use std::sync::Arc;
+
+use crossbeam::sync::WaitGroup;
+
+use crate::db::query::QueryResult;
+use crate::db::runtime::DbRuntime;
+use crate::db::{HasJars, ParallelDatabase};
+
+/// Stores the jars of a database and the state for each worker.
+///
+/// Today, all state is shared across all workers, but it may be desired to store data per worker in the future.
+pub struct JarsStorage<T>
+where
+    T: HasJars + Sized,
+{
+    // It's important that `jars_wait_group` is declared after `jars` to ensure that `jars` is dropped first.
+    // See https://doc.rust-lang.org/reference/destructors.html
+    /// Stores the jars of the database.
+    jars: Arc<T::Jars>,
+
+    /// Used to count the references to `jars`. Allows implementing `jars_mut` without having to clone `jars`.
+    jars_wait_group: WaitGroup,
+
+    /// The jar-agnostic state.
+    runtime: DbRuntime,
+}
+
+impl<Db> JarsStorage<Db>
+where
+    Db: HasJars,
+{
+    pub(super) fn new() -> Self {
+        Self {
+            jars: Arc::new(Db::Jars::default()),
+            jars_wait_group: WaitGroup::default(),
+            runtime: DbRuntime::default(),
+        }
+    }
+
+    /// Creates a snapshot of the jars.
+    ///
+    /// Creating the snapshot is cheap because it doesn't clone the jars; it only increments a ref counter.
+    #[must_use]
+    pub fn snapshot(&self) -> JarsStorage<Db>
+    where
+        Db: ParallelDatabase,
+    {
+        Self {
+            jars: self.jars.clone(),
+            jars_wait_group: self.jars_wait_group.clone(),
+            runtime: self.runtime.snapshot(),
+        }
+    }
+
+    pub(crate) fn jars(&self) -> QueryResult<&Db::Jars> {
+        self.runtime.cancelled()?;
+        Ok(&self.jars)
+    }
+
+    /// Returns a mutable reference to the jars without cloning their content.
+    ///
+    /// The method cancels any pending queries of other workers and waits for them to complete so that
+    /// this instance is the only instance holding a reference to the jars.
+    pub(crate) fn jars_mut(&mut self) -> &mut Db::Jars {
+        // We have a mutable ref here, so no more workers can be spawned between calling this function and taking the mut ref below.
+        self.cancel_other_workers();
+
+        // Now all other references to `self.jars` should have been released. We can now safely return a mutable reference
+        // to the Arc's content.
+        let jars =
+            Arc::get_mut(&mut self.jars).expect("All references to jars should have been released");
+
+        jars
+    }
+
+    pub(crate) fn runtime(&self) -> &DbRuntime {
+        &self.runtime
+    }
+
+    pub(crate) fn runtime_mut(&mut self) -> &mut DbRuntime {
+        // Note: This method may need to use a similar trick to `jars_mut` if `DbRuntime` ever stores data that is shared between workers.
+        &mut self.runtime
+    }
+
+    #[tracing::instrument(level = "trace", skip(self))]
+    fn cancel_other_workers(&mut self) {
+        self.runtime.cancel_other_workers();
+
+        // Wait for all other workers to complete.
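+        // `WaitGroup::wait` consumes the group and blocks until every clone
+        // handed out to a snapshot has been dropped. `mem::take` leaves a fresh
+        // group behind for snapshots created after this mutation.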
+ let existing_wait = std::mem::take(&mut self.jars_wait_group); + existing_wait.wait(); + } +} + +impl Default for JarsStorage +where + Db: HasJars, +{ + fn default() -> Self { + Self::new() + } +} + +impl std::fmt::Debug for JarsStorage +where + T: HasJars, + ::Jars: std::fmt::Debug, +{ + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SharedStorage") + .field("jars", &self.jars) + .field("jars_wait_group", &self.jars_wait_group) + .field("runtime", &self.runtime) + .finish() + } +} diff --git a/crates/red_knot/src/lib.rs b/crates/red_knot/src/lib.rs index 574ab81141..107073a3b9 100644 --- a/crates/red_knot/src/lib.rs +++ b/crates/red_knot/src/lib.rs @@ -27,7 +27,7 @@ pub(crate) type FxDashMap = dashmap::DashMap = dashmap::DashSet>; pub(crate) type FxIndexSet = indexmap::set::IndexSet>; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Workspace { /// TODO this should be a resolved path. We should probably use a newtype wrapper that guarantees that /// PATH is a UTF-8 path and is normalized. diff --git a/crates/red_knot/src/lint.rs b/crates/red_knot/src/lint.rs index 2de07f9531..0a000d577d 100644 --- a/crates/red_knot/src/lint.rs +++ b/crates/red_knot/src/lint.rs @@ -1,12 +1,15 @@ use std::cell::RefCell; use std::ops::{Deref, DerefMut}; use std::sync::Arc; +use std::time::Duration; use ruff_python_ast::visitor::Visitor; use ruff_python_ast::{ModModule, StringLiteral}; use crate::cache::KeyValueCache; -use crate::db::{HasJar, SemanticDb, SemanticJar, SourceDb, SourceJar}; +use crate::db::{ + HasJar, ParallelDatabase, QueryResult, SemanticDb, SemanticJar, SourceDb, SourceJar, +}; use crate::files::FileId; use crate::parse::Parsed; use crate::source::Source; @@ -14,19 +17,28 @@ use crate::symbols::{Definition, SymbolId, SymbolTable}; use crate::types::Type; #[tracing::instrument(level = "debug", skip(db))] -pub(crate) fn lint_syntax(db: &Db, file_id: FileId) -> Diagnostics +pub(crate) fn lint_syntax(db: &Db, file_id: FileId) -> QueryResult where - Db: SourceDb + HasJar, + Db: SourceDb + HasJar + ParallelDatabase, { - let storage = &db.jar().lint_syntax; + let storage = &db.jar()?.lint_syntax; + + #[allow(clippy::print_stdout)] + if std::env::var("RED_KNOT_SLOW_LINT").is_ok() { + for i in 0..10 { + db.cancelled()?; + println!("RED_KNOT_SLOW_LINT is set, sleeping for {i}/10 seconds"); + std::thread::sleep(Duration::from_secs(1)); + } + } storage.get(&file_id, |file_id| { let mut diagnostics = Vec::new(); - let source = db.source(*file_id); + let source = db.source(*file_id)?; lint_lines(source.text(), &mut diagnostics); - let parsed = db.parse(*file_id); + let parsed = db.parse(*file_id)?; if parsed.errors().is_empty() { let ast = parsed.ast(); @@ -41,7 +53,7 @@ where diagnostics.extend(parsed.errors().iter().map(std::string::ToString::to_string)); } - Diagnostics::from(diagnostics) + Ok(Diagnostics::from(diagnostics)) }) } @@ -63,16 +75,16 @@ fn lint_lines(source: &str, diagnostics: &mut Vec) { } #[tracing::instrument(level = "debug", skip(db))] -pub(crate) fn lint_semantic(db: &Db, file_id: FileId) -> Diagnostics +pub(crate) fn lint_semantic(db: &Db, file_id: FileId) -> QueryResult where Db: SemanticDb + HasJar, { - let storage = &db.jar().lint_semantic; + let storage = &db.jar()?.lint_semantic; storage.get(&file_id, |file_id| { - let source = db.source(*file_id); - let parsed = db.parse(*file_id); - let symbols = db.symbol_table(*file_id); + let source = db.source(*file_id)?; + let parsed = db.parse(*file_id)?; + let symbols = db.symbol_table(*file_id)?; 
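+        // Each `?` above propagates `QueryError::Cancelled`, so a database
+        // mutation on another thread aborts this lint run early.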
let context = SemanticLintContext { file_id: *file_id, @@ -83,25 +95,25 @@ where diagnostics: RefCell::new(Vec::new()), }; - lint_unresolved_imports(&context); + lint_unresolved_imports(&context)?; - Diagnostics::from(context.diagnostics.take()) + Ok(Diagnostics::from(context.diagnostics.take())) }) } -fn lint_unresolved_imports(context: &SemanticLintContext) { +fn lint_unresolved_imports(context: &SemanticLintContext) -> QueryResult<()> { // TODO: Consider iterating over the dependencies (imports) only instead of all definitions. for (symbol, definition) in context.symbols().all_definitions() { match definition { Definition::Import(import) => { - let ty = context.eval_symbol(symbol); + let ty = context.infer_symbol_type(symbol)?; if ty.is_unknown() { context.push_diagnostic(format!("Unresolved module {}", import.module)); } } Definition::ImportFrom(import) => { - let ty = context.eval_symbol(symbol); + let ty = context.infer_symbol_type(symbol)?; if ty.is_unknown() { let module_name = import.module().map(Deref::deref).unwrap_or_default(); @@ -126,6 +138,8 @@ fn lint_unresolved_imports(context: &SemanticLintContext) { _ => {} } } + + Ok(()) } pub struct SemanticLintContext<'a> { @@ -154,7 +168,7 @@ impl<'a> SemanticLintContext<'a> { &self.symbols } - pub fn eval_symbol(&self, symbol_id: SymbolId) -> Type { + pub fn infer_symbol_type(&self, symbol_id: SymbolId) -> QueryResult { self.db.infer_symbol_type(self.file_id, symbol_id) } diff --git a/crates/red_knot/src/main.rs b/crates/red_knot/src/main.rs index 9b1c0c0d03..36a3e4ea7a 100644 --- a/crates/red_knot/src/main.rs +++ b/crates/red_knot/src/main.rs @@ -4,6 +4,7 @@ use std::collections::hash_map::Entry; use std::path::Path; use std::sync::Mutex; +use crossbeam::channel as crossbeam_channel; use rustc_hash::FxHashMap; use tracing::subscriber::Interest; use tracing::{Level, Metadata}; @@ -12,11 +13,10 @@ use tracing_subscriber::layer::{Context, Filter, SubscriberExt}; use tracing_subscriber::{Layer, Registry}; use tracing_tree::time::Uptime; -use red_knot::cancellation::CancellationTokenSource; -use red_knot::db::{HasJar, SourceDb, SourceJar}; +use red_knot::db::{HasJar, ParallelDatabase, QueryError, SemanticDb, SourceDb, SourceJar}; use red_knot::files::FileId; use red_knot::module::{ModuleSearchPath, ModuleSearchPathKind}; -use red_knot::program::check::{CheckError, RayonCheckScheduler}; +use red_knot::program::check::RayonCheckScheduler; use red_knot::program::{FileChange, FileChangeKind, Program}; use red_knot::watch::FileWatcher; use red_knot::Workspace; @@ -51,7 +51,8 @@ fn main() -> anyhow::Result<()> { workspace.root().to_path_buf(), ModuleSearchPathKind::FirstParty, ); - let mut program = Program::new(workspace, vec![workspace_search_path]); + let mut program = Program::new(workspace); + program.set_module_search_paths(vec![workspace_search_path]); let entry_id = program.file_id(entry_point); program.workspace_mut().open_file(entry_id); @@ -82,7 +83,7 @@ fn main() -> anyhow::Result<()> { main_loop.run(&mut program); - let source_jar: &SourceJar = program.jar(); + let source_jar: &SourceJar = program.jar().unwrap(); dbg!(source_jar.parsed.statistics()); dbg!(source_jar.sources.statistics()); @@ -101,10 +102,9 @@ impl MainLoop { let (main_loop_sender, main_loop_receiver) = crossbeam_channel::bounded(1); let mut orchestrator = Orchestrator { - pending_analysis: None, receiver: orchestrator_receiver, sender: main_loop_sender.clone(), - aggregated_changes: AggregatedChanges::default(), + revision: 0, }; std::thread::spawn(move || { @@ 
-137,34 +137,32 @@ impl MainLoop { tracing::trace!("Main Loop: Tick"); match message { - MainLoopMessage::CheckProgram => { - // Remove mutability from program. - let program = &*program; - let run_cancellation_token_source = CancellationTokenSource::new(); - let run_cancellation_token = run_cancellation_token_source.token(); - let sender = &self.orchestrator_sender; + MainLoopMessage::CheckProgram { revision } => { + let program = program.snapshot(); + let sender = self.orchestrator_sender.clone(); - sender - .send(OrchestratorMessage::CheckProgramStarted { - cancellation_token: run_cancellation_token_source, - }) - .unwrap(); + // Spawn a new task that checks the program. This needs to be done in a separate thread + // to prevent blocking the main loop here. + rayon::spawn(move || { + rayon::in_place_scope(|scope| { + let scheduler = RayonCheckScheduler::new(&program, scope); - rayon::in_place_scope(|scope| { - let scheduler = RayonCheckScheduler::new(program, scope); - - let result = program.check(&scheduler, run_cancellation_token); - match result { - Ok(result) => sender - .send(OrchestratorMessage::CheckProgramCompleted(result)) - .unwrap(), - Err(CheckError::Cancelled) => sender - .send(OrchestratorMessage::CheckProgramCancelled) - .unwrap(), - } + match program.check(&scheduler) { + Ok(result) => { + sender + .send(OrchestratorMessage::CheckProgramCompleted { + diagnostics: result, + revision, + }) + .unwrap(); + } + Err(QueryError::Cancelled) => {} + } + }); }); } MainLoopMessage::ApplyChanges(changes) => { + // Automatically cancels any pending queries and waits for them to complete. program.apply_changes(changes.iter()); } MainLoopMessage::CheckCompleted(diagnostics) => { @@ -211,13 +209,11 @@ impl MainLoopCancellationToken { } struct Orchestrator { - aggregated_changes: AggregatedChanges, - pending_analysis: Option, - /// Sends messages to the main loop. sender: crossbeam_channel::Sender, /// Receives messages from the main loop. receiver: crossbeam_channel::Receiver, + revision: usize, } impl Orchestrator { @@ -225,51 +221,33 @@ impl Orchestrator { while let Ok(message) = self.receiver.recv() { match message { OrchestratorMessage::Run => { - self.pending_analysis = None; - self.sender.send(MainLoopMessage::CheckProgram).unwrap(); - } - - OrchestratorMessage::CheckProgramStarted { cancellation_token } => { - debug_assert!(self.pending_analysis.is_none()); - - self.pending_analysis = Some(PendingAnalysisState { cancellation_token }); - } - - OrchestratorMessage::CheckProgramCompleted(diagnostics) => { - self.pending_analysis - .take() - .expect("Expected a pending analysis."); - self.sender - .send(MainLoopMessage::CheckCompleted(diagnostics)) + .send(MainLoopMessage::CheckProgram { + revision: self.revision, + }) .unwrap(); } - OrchestratorMessage::CheckProgramCancelled => { - self.pending_analysis - .take() - .expect("Expected a pending analysis."); - - self.debounce_changes(); + OrchestratorMessage::CheckProgramCompleted { + diagnostics, + revision, + } => { + // Only take the diagnostics if they are for the latest revision. + if self.revision == revision { + self.sender + .send(MainLoopMessage::CheckCompleted(diagnostics)) + .unwrap(); + } else { + tracing::debug!("Discarding diagnostics for outdated revision {revision} (current: {}).", self.revision); + } } OrchestratorMessage::FileChanges(changes) => { // Request cancellation, but wait until all analysis tasks have completed to // avoid stale messages in the next main loop. 
- let pending = if let Some(pending_state) = self.pending_analysis.as_ref() { - pending_state.cancellation_token.cancel(); - true - } else { - false - }; - self.aggregated_changes.extend(changes); - - // If there are no pending analysis tasks, apply the file changes. Otherwise - // keep running until all file checks have completed. - if !pending { - self.debounce_changes(); - } + self.revision += 1; + self.debounce_changes(changes); } OrchestratorMessage::Shutdown => { return self.shutdown(); @@ -278,8 +256,9 @@ impl Orchestrator { } } - fn debounce_changes(&mut self) { - debug_assert!(self.pending_analysis.is_none()); + fn debounce_changes(&self, changes: Vec) { + let mut aggregated_changes = AggregatedChanges::default(); + aggregated_changes.extend(changes); loop { // Consume possibly incoming file change messages before running a new analysis, but don't wait for more than 100ms. @@ -290,10 +269,12 @@ impl Orchestrator { return self.shutdown(); } Ok(OrchestratorMessage::FileChanges(file_changes)) => { - self.aggregated_changes.extend(file_changes); + aggregated_changes.extend(file_changes); } - Ok(OrchestratorMessage::CheckProgramStarted {..}| OrchestratorMessage::CheckProgramCompleted(_) | OrchestratorMessage::CheckProgramCancelled) => unreachable!("No program check should be running while debouncing changes."), + Ok(OrchestratorMessage::CheckProgramCompleted { .. })=> { + // disregard any outdated completion message. + } Ok(OrchestratorMessage::Run) => unreachable!("The orchestrator is already running."), Err(_) => { @@ -302,10 +283,10 @@ impl Orchestrator { } } }, - default(std::time::Duration::from_millis(100)) => { - // No more file changes after 100 ms, send the changes and schedule a new analysis - self.sender.send(MainLoopMessage::ApplyChanges(std::mem::take(&mut self.aggregated_changes))).unwrap(); - self.sender.send(MainLoopMessage::CheckProgram).unwrap(); + default(std::time::Duration::from_millis(10)) => { + // No more file changes after 10 ms, send the changes and schedule a new analysis + self.sender.send(MainLoopMessage::ApplyChanges(aggregated_changes)).unwrap(); + self.sender.send(MainLoopMessage::CheckProgram { revision: self.revision}).unwrap(); return; } } @@ -318,15 +299,10 @@ impl Orchestrator { } } -#[derive(Debug)] -struct PendingAnalysisState { - cancellation_token: CancellationTokenSource, -} - /// Message sent from the orchestrator to the main loop. 
#[derive(Debug)] enum MainLoopMessage { - CheckProgram, + CheckProgram { revision: usize }, CheckCompleted(Vec), ApplyChanges(AggregatedChanges), Exit, @@ -337,11 +313,10 @@ enum OrchestratorMessage { Run, Shutdown, - CheckProgramStarted { - cancellation_token: CancellationTokenSource, + CheckProgramCompleted { + diagnostics: Vec, + revision: usize, }, - CheckProgramCompleted(Vec), - CheckProgramCancelled, FileChanges(Vec), } diff --git a/crates/red_knot/src/module.rs b/crates/red_knot/src/module.rs index 8e6f367599..cf0d9236b9 100644 --- a/crates/red_knot/src/module.rs +++ b/crates/red_knot/src/module.rs @@ -7,7 +7,7 @@ use std::sync::Arc; use dashmap::mapref::entry::Entry; use smol_str::SmolStr; -use crate::db::{HasJar, SemanticDb, SemanticJar}; +use crate::db::{HasJar, QueryResult, SemanticDb, SemanticJar}; use crate::files::FileId; use crate::symbols::Dependency; use crate::FxDashMap; @@ -17,44 +17,48 @@ use crate::FxDashMap; pub struct Module(u32); impl Module { - pub fn name(&self, db: &Db) -> ModuleName + pub fn name(&self, db: &Db) -> QueryResult where Db: HasJar, { - let modules = &db.jar().module_resolver; + let modules = &db.jar()?.module_resolver; - modules.modules.get(self).unwrap().name.clone() + Ok(modules.modules.get(self).unwrap().name.clone()) } - pub fn path(&self, db: &Db) -> ModulePath + pub fn path(&self, db: &Db) -> QueryResult where Db: HasJar, { - let modules = &db.jar().module_resolver; + let modules = &db.jar()?.module_resolver; - modules.modules.get(self).unwrap().path.clone() + Ok(modules.modules.get(self).unwrap().path.clone()) } - pub fn kind(&self, db: &Db) -> ModuleKind + pub fn kind(&self, db: &Db) -> QueryResult where Db: HasJar, { - let modules = &db.jar().module_resolver; + let modules = &db.jar()?.module_resolver; - modules.modules.get(self).unwrap().kind + Ok(modules.modules.get(self).unwrap().kind) } - pub fn resolve_dependency(&self, db: &Db, dependency: &Dependency) -> Option + pub fn resolve_dependency( + &self, + db: &Db, + dependency: &Dependency, + ) -> QueryResult> where Db: HasJar, { let (level, module) = match dependency { - Dependency::Module(module) => return Some(module.clone()), + Dependency::Module(module) => return Ok(Some(module.clone())), Dependency::Relative { level, module } => (*level, module.as_deref()), }; - let name = self.name(db); - let kind = self.kind(db); + let name = self.name(db)?; + let kind = self.kind(db)?; let mut components = name.components().peekable(); @@ -67,7 +71,9 @@ impl Module { // Skip over the relative parts. for _ in start..level.get() { - components.next_back()?; + if components.next_back().is_none() { + return Ok(None); + } } let mut name = String::new(); @@ -80,11 +86,11 @@ impl Module { name.push_str(part); } - if name.is_empty() { + Ok(if name.is_empty() { None } else { Some(ModuleName(SmolStr::new(name))) - } + }) } } @@ -238,20 +244,25 @@ pub struct ModuleData { /// TODO: This would not work with Salsa because `ModuleName` isn't an ingredient and, therefore, cannot be used as part of a query. /// For this to work with salsa, it would be necessary to intern all `ModuleName`s. 
#[tracing::instrument(level = "debug", skip(db))] -pub fn resolve_module(db: &Db, name: ModuleName) -> Option +pub fn resolve_module(db: &Db, name: ModuleName) -> QueryResult> where Db: SemanticDb + HasJar, { let jar = db.jar(); - let modules = &jar.module_resolver; + let modules = &jar?.module_resolver; let entry = modules.by_name.entry(name.clone()); match entry { - Entry::Occupied(entry) => Some(*entry.get()), + Entry::Occupied(entry) => Ok(Some(*entry.get())), Entry::Vacant(entry) => { - let (root_path, absolute_path, kind) = resolve_name(&name, &modules.search_paths)?; - let normalized = absolute_path.canonicalize().ok()?; + let Some((root_path, absolute_path, kind)) = resolve_name(&name, &modules.search_paths) + else { + return Ok(None); + }; + let Ok(normalized) = absolute_path.canonicalize() else { + return Ok(None); + }; let file_id = db.file_id(&normalized); let path = ModulePath::new(root_path.clone(), file_id); @@ -277,7 +288,7 @@ where entry.insert_entry(id); - Some(id) + Ok(Some(id)) } } } @@ -286,7 +297,7 @@ where /// /// Returns `None` if the file is not a module in `sys.path`. #[tracing::instrument(level = "debug", skip(db))] -pub fn file_to_module(db: &Db, file: FileId) -> Option +pub fn file_to_module(db: &Db, file: FileId) -> QueryResult> where Db: SemanticDb + HasJar, { @@ -298,34 +309,42 @@ where /// /// Returns `None` if the path is not a module in `sys.path`. #[tracing::instrument(level = "debug", skip(db))] -pub fn path_to_module(db: &Db, path: &Path) -> Option +pub fn path_to_module(db: &Db, path: &Path) -> QueryResult> where Db: SemanticDb + HasJar, { - let jar = db.jar(); + let jar = db.jar()?; let modules = &jar.module_resolver; debug_assert!(path.is_absolute()); if let Some(existing) = modules.by_path.get(path) { - return Some(*existing); + return Ok(Some(*existing)); } - let (root_path, relative_path) = modules.search_paths.iter().find_map(|root| { + let Some((root_path, relative_path)) = modules.search_paths.iter().find_map(|root| { let relative_path = path.strip_prefix(root.path()).ok()?; Some((root.clone(), relative_path)) - })?; + }) else { + return Ok(None); + }; - let module_name = ModuleName::from_relative_path(relative_path)?; + let Some(module_name) = ModuleName::from_relative_path(relative_path) else { + return Ok(None); + }; // Resolve the module name to see if Python would resolve the name to the same path. // If it doesn't, then that means that multiple modules have the same in different // root paths, but that the module corresponding to the past path is in a lower priority path, // in which case we ignore it. - let module_id = resolve_module(db, module_name)?; - let module_path = module_id.path(db); + let Some(module_id) = resolve_module(db, module_name)? else { + return Ok(None); + }; + let module_path = module_id.path(db)?; if module_path.root() == &root_path { - let normalized = path.canonicalize().ok()?; + let Ok(normalized) = path.canonicalize() else { + return Ok(None); + }; let interned_normalized = db.file_id(&normalized); if interned_normalized != module_path.file() { @@ -336,15 +355,15 @@ where // ``` // The module name of `src/foo.py` is `foo`, but the module loaded by Python is `src/foo/__init__.py`. // That means we need to ignore `src/foo.py` even though it resolves to the same module name. - return None; + return Ok(None); } // Path has been inserted by `resolved` - Some(module_id) + Ok(Some(module_id)) } else { // This path is for a module with the same name but in a module search path with a lower priority. // Ignore it. 
- None + Ok(None) } } @@ -378,7 +397,7 @@ where // TODO This needs tests // Note: Intentionally by-pass caching here. Module should not be in the cache yet. - let module = path_to_module(db, path)?; + let module = path_to_module(db, path).ok()??; // The code below is to handle the addition of `__init__.py` files. // When an `__init__.py` file is added, we need to remove all modules that are part of the same package. @@ -392,7 +411,7 @@ where return Some((module, Vec::new())); } - let Some(parent_name) = module.name(db).parent() else { + let Some(parent_name) = module.name(db).ok()?.parent() else { return Some((module, Vec::new())); }; @@ -691,7 +710,7 @@ mod tests { } #[test] - fn first_party_module() -> std::io::Result<()> { + fn first_party_module() -> anyhow::Result<()> { let TestCase { db, src, @@ -702,22 +721,22 @@ mod tests { let foo_path = src.path().join("foo.py"); std::fs::write(&foo_path, "print('Hello, world!')")?; - let foo_module = db.resolve_module(ModuleName::new("foo")).unwrap(); + let foo_module = db.resolve_module(ModuleName::new("foo"))?.unwrap(); - assert_eq!(Some(foo_module), db.resolve_module(ModuleName::new("foo"))); + assert_eq!(Some(foo_module), db.resolve_module(ModuleName::new("foo"))?); - assert_eq!(ModuleName::new("foo"), foo_module.name(&db)); - assert_eq!(&src, foo_module.path(&db).root()); - assert_eq!(ModuleKind::Module, foo_module.kind(&db)); - assert_eq!(&foo_path, &*db.file_path(foo_module.path(&db).file())); + assert_eq!(ModuleName::new("foo"), foo_module.name(&db)?); + assert_eq!(&src, foo_module.path(&db)?.root()); + assert_eq!(ModuleKind::Module, foo_module.kind(&db)?); + assert_eq!(&foo_path, &*db.file_path(foo_module.path(&db)?.file())); - assert_eq!(Some(foo_module), db.path_to_module(&foo_path)); + assert_eq!(Some(foo_module), db.path_to_module(&foo_path)?); Ok(()) } #[test] - fn resolve_package() -> std::io::Result<()> { + fn resolve_package() -> anyhow::Result<()> { let TestCase { src, db, @@ -730,22 +749,22 @@ mod tests { std::fs::create_dir(&foo_dir)?; std::fs::write(&foo_path, "print('Hello, world!')")?; - let foo_module = db.resolve_module(ModuleName::new("foo")).unwrap(); + let foo_module = db.resolve_module(ModuleName::new("foo"))?.unwrap(); - assert_eq!(ModuleName::new("foo"), foo_module.name(&db)); - assert_eq!(&src, foo_module.path(&db).root()); - assert_eq!(&foo_path, &*db.file_path(foo_module.path(&db).file())); + assert_eq!(ModuleName::new("foo"), foo_module.name(&db)?); + assert_eq!(&src, foo_module.path(&db)?.root()); + assert_eq!(&foo_path, &*db.file_path(foo_module.path(&db)?.file())); - assert_eq!(Some(foo_module), db.path_to_module(&foo_path)); + assert_eq!(Some(foo_module), db.path_to_module(&foo_path)?); // Resolving by directory doesn't resolve to the init file. 
- assert_eq!(None, db.path_to_module(&foo_dir)); + assert_eq!(None, db.path_to_module(&foo_dir)?); Ok(()) } #[test] - fn package_priority_over_module() -> std::io::Result<()> { + fn package_priority_over_module() -> anyhow::Result<()> { let TestCase { db, temp_dir: _temp_dir, @@ -761,20 +780,20 @@ mod tests { let foo_py = src.path().join("foo.py"); std::fs::write(&foo_py, "print('Hello, world!')")?; - let foo_module = db.resolve_module(ModuleName::new("foo")).unwrap(); + let foo_module = db.resolve_module(ModuleName::new("foo"))?.unwrap(); - assert_eq!(&src, foo_module.path(&db).root()); - assert_eq!(&foo_init, &*db.file_path(foo_module.path(&db).file())); - assert_eq!(ModuleKind::Package, foo_module.kind(&db)); + assert_eq!(&src, foo_module.path(&db)?.root()); + assert_eq!(&foo_init, &*db.file_path(foo_module.path(&db)?.file())); + assert_eq!(ModuleKind::Package, foo_module.kind(&db)?); - assert_eq!(Some(foo_module), db.path_to_module(&foo_init)); - assert_eq!(None, db.path_to_module(&foo_py)); + assert_eq!(Some(foo_module), db.path_to_module(&foo_init)?); + assert_eq!(None, db.path_to_module(&foo_py)?); Ok(()) } #[test] - fn typing_stub_over_module() -> std::io::Result<()> { + fn typing_stub_over_module() -> anyhow::Result<()> { let TestCase { db, src, @@ -787,19 +806,19 @@ mod tests { std::fs::write(&foo_stub, "x: int")?; std::fs::write(&foo_py, "print('Hello, world!')")?; - let foo = db.resolve_module(ModuleName::new("foo")).unwrap(); + let foo = db.resolve_module(ModuleName::new("foo"))?.unwrap(); - assert_eq!(&src, foo.path(&db).root()); - assert_eq!(&foo_stub, &*db.file_path(foo.path(&db).file())); + assert_eq!(&src, foo.path(&db)?.root()); + assert_eq!(&foo_stub, &*db.file_path(foo.path(&db)?.file())); - assert_eq!(Some(foo), db.path_to_module(&foo_stub)); - assert_eq!(None, db.path_to_module(&foo_py)); + assert_eq!(Some(foo), db.path_to_module(&foo_stub)?); + assert_eq!(None, db.path_to_module(&foo_py)?); Ok(()) } #[test] - fn sub_packages() -> std::io::Result<()> { + fn sub_packages() -> anyhow::Result<()> { let TestCase { db, src, @@ -816,18 +835,18 @@ mod tests { std::fs::write(bar.join("__init__.py"), "")?; std::fs::write(&baz, "print('Hello, world!')")?; - let baz_module = db.resolve_module(ModuleName::new("foo.bar.baz")).unwrap(); + let baz_module = db.resolve_module(ModuleName::new("foo.bar.baz"))?.unwrap(); - assert_eq!(&src, baz_module.path(&db).root()); - assert_eq!(&baz, &*db.file_path(baz_module.path(&db).file())); + assert_eq!(&src, baz_module.path(&db)?.root()); + assert_eq!(&baz, &*db.file_path(baz_module.path(&db)?.file())); - assert_eq!(Some(baz_module), db.path_to_module(&baz)); + assert_eq!(Some(baz_module), db.path_to_module(&baz)?); Ok(()) } #[test] - fn namespace_package() -> std::io::Result<()> { + fn namespace_package() -> anyhow::Result<()> { let TestCase { db, temp_dir: _, @@ -863,21 +882,21 @@ mod tests { std::fs::write(&two, "print('Hello, world!')")?; let one_module = db - .resolve_module(ModuleName::new("parent.child.one")) + .resolve_module(ModuleName::new("parent.child.one"))? .unwrap(); - assert_eq!(Some(one_module), db.path_to_module(&one)); + assert_eq!(Some(one_module), db.path_to_module(&one)?); let two_module = db - .resolve_module(ModuleName::new("parent.child.two")) + .resolve_module(ModuleName::new("parent.child.two"))? 
.unwrap(); - assert_eq!(Some(two_module), db.path_to_module(&two)); + assert_eq!(Some(two_module), db.path_to_module(&two)?); Ok(()) } #[test] - fn regular_package_in_namespace_package() -> std::io::Result<()> { + fn regular_package_in_namespace_package() -> anyhow::Result<()> { let TestCase { db, temp_dir: _, @@ -914,17 +933,20 @@ mod tests { std::fs::write(two, "print('Hello, world!')")?; let one_module = db - .resolve_module(ModuleName::new("parent.child.one")) + .resolve_module(ModuleName::new("parent.child.one"))? .unwrap(); - assert_eq!(Some(one_module), db.path_to_module(&one)); + assert_eq!(Some(one_module), db.path_to_module(&one)?); - assert_eq!(None, db.resolve_module(ModuleName::new("parent.child.two"))); + assert_eq!( + None, + db.resolve_module(ModuleName::new("parent.child.two"))? + ); Ok(()) } #[test] - fn module_search_path_priority() -> std::io::Result<()> { + fn module_search_path_priority() -> anyhow::Result<()> { let TestCase { db, src, @@ -938,20 +960,20 @@ mod tests { std::fs::write(&foo_src, "")?; std::fs::write(&foo_site_packages, "")?; - let foo_module = db.resolve_module(ModuleName::new("foo")).unwrap(); + let foo_module = db.resolve_module(ModuleName::new("foo"))?.unwrap(); - assert_eq!(&src, foo_module.path(&db).root()); - assert_eq!(&foo_src, &*db.file_path(foo_module.path(&db).file())); + assert_eq!(&src, foo_module.path(&db)?.root()); + assert_eq!(&foo_src, &*db.file_path(foo_module.path(&db)?.file())); - assert_eq!(Some(foo_module), db.path_to_module(&foo_src)); - assert_eq!(None, db.path_to_module(&foo_site_packages)); + assert_eq!(Some(foo_module), db.path_to_module(&foo_src)?); + assert_eq!(None, db.path_to_module(&foo_site_packages)?); Ok(()) } #[test] #[cfg(target_family = "unix")] - fn symlink() -> std::io::Result<()> { + fn symlink() -> anyhow::Result<()> { let TestCase { db, src, @@ -965,28 +987,28 @@ mod tests { std::fs::write(&foo, "")?; std::os::unix::fs::symlink(&foo, &bar)?; - let foo_module = db.resolve_module(ModuleName::new("foo")).unwrap(); - let bar_module = db.resolve_module(ModuleName::new("bar")).unwrap(); + let foo_module = db.resolve_module(ModuleName::new("foo"))?.unwrap(); + let bar_module = db.resolve_module(ModuleName::new("bar"))?.unwrap(); assert_ne!(foo_module, bar_module); - assert_eq!(&src, foo_module.path(&db).root()); - assert_eq!(&foo, &*db.file_path(foo_module.path(&db).file())); + assert_eq!(&src, foo_module.path(&db)?.root()); + assert_eq!(&foo, &*db.file_path(foo_module.path(&db)?.file())); // Bar has a different name but it should point to the same file. 
- assert_eq!(&src, bar_module.path(&db).root()); - assert_eq!(foo_module.path(&db).file(), bar_module.path(&db).file()); - assert_eq!(&foo, &*db.file_path(bar_module.path(&db).file())); + assert_eq!(&src, bar_module.path(&db)?.root()); + assert_eq!(foo_module.path(&db)?.file(), bar_module.path(&db)?.file()); + assert_eq!(&foo, &*db.file_path(bar_module.path(&db)?.file())); - assert_eq!(Some(foo_module), db.path_to_module(&foo)); - assert_eq!(Some(bar_module), db.path_to_module(&bar)); + assert_eq!(Some(foo_module), db.path_to_module(&foo)?); + assert_eq!(Some(bar_module), db.path_to_module(&bar)?); Ok(()) } #[test] - fn resolve_dependency() -> std::io::Result<()> { + fn resolve_dependency() -> anyhow::Result<()> { let TestCase { src, db, @@ -1002,8 +1024,8 @@ mod tests { std::fs::write(foo_path, "from .bar import test")?; std::fs::write(bar_path, "test = 'Hello world'")?; - let foo_module = db.resolve_module(ModuleName::new("foo")).unwrap(); - let bar_module = db.resolve_module(ModuleName::new("foo.bar")).unwrap(); + let foo_module = db.resolve_module(ModuleName::new("foo"))?.unwrap(); + let bar_module = db.resolve_module(ModuleName::new("foo.bar"))?.unwrap(); // `from . import bar` in `foo/__init__.py` resolves to `foo` assert_eq!( @@ -1014,13 +1036,13 @@ mod tests { level: NonZeroU32::new(1).unwrap(), module: None, } - ) + )? ); // `from baz import bar` in `foo/__init__.py` should resolve to `baz.py` assert_eq!( Some(ModuleName::new("baz")), - foo_module.resolve_dependency(&db, &Dependency::Module(ModuleName::new("baz"))) + foo_module.resolve_dependency(&db, &Dependency::Module(ModuleName::new("baz")))? ); // from .bar import test in `foo/__init__.py` should resolve to `foo/bar.py` @@ -1032,7 +1054,7 @@ mod tests { level: NonZeroU32::new(1).unwrap(), module: Some(ModuleName::new("bar")) } - ) + )? ); // from .. import test in `foo/__init__.py` resolves to `` which is not a module @@ -1044,7 +1066,7 @@ mod tests { level: NonZeroU32::new(2).unwrap(), module: None } - ) + )? ); // `from . import test` in `foo/bar.py` resolves to `foo` @@ -1056,13 +1078,13 @@ mod tests { level: NonZeroU32::new(1).unwrap(), module: None } - ) + )? ); // `from baz import test` in `foo/bar.py` resolves to `baz` assert_eq!( Some(ModuleName::new("baz")), - bar_module.resolve_dependency(&db, &Dependency::Module(ModuleName::new("baz"))) + bar_module.resolve_dependency(&db, &Dependency::Module(ModuleName::new("baz")))? ); // `from .baz import test` in `foo/bar.py` resolves to `foo.baz`. @@ -1074,7 +1096,7 @@ mod tests { level: NonZeroU32::new(1).unwrap(), module: Some(ModuleName::new("baz")) } - ) + )? 
); Ok(()) diff --git a/crates/red_knot/src/parse.rs b/crates/red_knot/src/parse.rs index 20df9f4c8c..e76cd06706 100644 --- a/crates/red_knot/src/parse.rs +++ b/crates/red_knot/src/parse.rs @@ -6,7 +6,7 @@ use ruff_python_parser::{Mode, ParseError}; use ruff_text_size::{Ranged, TextRange}; use crate::cache::KeyValueCache; -use crate::db::{HasJar, SourceDb, SourceJar}; +use crate::db::{HasJar, QueryResult, SourceDb, SourceJar}; use crate::files::FileId; #[derive(Debug, Clone, PartialEq)] @@ -64,16 +64,16 @@ impl Parsed { } #[tracing::instrument(level = "debug", skip(db))] -pub(crate) fn parse(db: &Db, file_id: FileId) -> Parsed +pub(crate) fn parse(db: &Db, file_id: FileId) -> QueryResult where Db: SourceDb + HasJar, { - let parsed = db.jar(); + let parsed = db.jar()?; parsed.parsed.get(&file_id, |file_id| { - let source = db.source(*file_id); + let source = db.source(*file_id)?; - Parsed::from_text(source.text()) + Ok(Parsed::from_text(source.text())) }) } diff --git a/crates/red_knot/src/program/check.rs b/crates/red_knot/src/program/check.rs index 1299760cb5..258f9d0062 100644 --- a/crates/red_knot/src/program/check.rs +++ b/crates/red_knot/src/program/check.rs @@ -1,10 +1,9 @@ use std::num::NonZeroUsize; -use rayon::max_num_threads; +use rayon::{current_num_threads, yield_local}; use rustc_hash::FxHashSet; -use crate::cancellation::CancellationToken; -use crate::db::{SemanticDb, SourceDb}; +use crate::db::{Database, QueryError, QueryResult, SemanticDb, SourceDb}; use crate::files::FileId; use crate::lint::Diagnostics; use crate::program::Program; @@ -13,42 +12,37 @@ use crate::symbols::Dependency; impl Program { /// Checks all open files in the workspace and its dependencies. #[tracing::instrument(level = "debug", skip_all)] - pub fn check( - &self, - scheduler: &dyn CheckScheduler, - cancellation_token: CancellationToken, - ) -> Result, CheckError> { - let check_loop = CheckFilesLoop::new(scheduler, cancellation_token); + pub fn check(&self, scheduler: &dyn CheckScheduler) -> QueryResult> { + self.cancelled()?; + + let check_loop = CheckFilesLoop::new(scheduler); check_loop.run(self.workspace().open_files.iter().copied()) } /// Checks a single file and its dependencies. - #[tracing::instrument(level = "debug", skip(self, scheduler, cancellation_token))] + #[tracing::instrument(level = "debug", skip(self, scheduler))] pub fn check_file( &self, file: FileId, scheduler: &dyn CheckScheduler, - cancellation_token: CancellationToken, - ) -> Result, CheckError> { - let check_loop = CheckFilesLoop::new(scheduler, cancellation_token); + ) -> QueryResult> { + self.cancelled()?; + + let check_loop = CheckFilesLoop::new(scheduler); check_loop.run([file].into_iter()) } #[tracing::instrument(level = "debug", skip(self, context))] - fn do_check_file( - &self, - file: FileId, - context: &CheckContext, - ) -> Result { - context.cancelled_ok()?; + fn do_check_file(&self, file: FileId, context: &CheckContext) -> QueryResult { + self.cancelled()?; - let symbol_table = self.symbol_table(file); + let symbol_table = self.symbol_table(file)?; let dependencies = symbol_table.dependencies(); if !dependencies.is_empty() { - let module = self.file_to_module(file); + let module = self.file_to_module(file)?; // TODO scheduling all dependencies here is wasteful if we don't infer any types on them // but I think that's unlikely, so it is okay? 
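Every query in this patch follows the same cancellation-aware shape: check for cancellation up front, then let `?` propagate cancellation out of nested queries. As a hedged illustration (this `word_count` query and its body are invented for this note, not part of the patch):

```rust
use red_knot::db::{QueryResult, SourceDb};
use red_knot::files::FileId;

// Invented example query: counts the words in a file's source text.
fn word_count<Db>(db: &Db, file_id: FileId) -> QueryResult<usize>
where
    Db: SourceDb,
{
    // Bail out early if this revision was cancelled by a mutation.
    db.cancelled()?;

    // Nested queries propagate `QueryError::Cancelled` via `?`.
    let source = db.source(file_id)?;

    Ok(source.text().split_whitespace().count())
}
```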
@@ -57,18 +51,19 @@ impl Program { for dependency in dependencies { let dependency_name = match dependency { Dependency::Module(name) => Some(name.clone()), - Dependency::Relative { .. } => module - .as_ref() - .and_then(|module| module.resolve_dependency(self, dependency)), + Dependency::Relative { .. } => match &module { + Some(module) => module.resolve_dependency(self, dependency)?, + None => None, + }, }; if let Some(dependency_name) = dependency_name { // TODO We may want to have a different check functions for non-first-party // files because we only need to index them and not check them. // Supporting non-first-party code also requires supporting typing stubs. - if let Some(dependency) = self.resolve_module(dependency_name) { - if dependency.path(self).root().kind().is_first_party() { - context.schedule_check_file(dependency.path(self).file()); + if let Some(dependency) = self.resolve_module(dependency_name)? { + if dependency.path(self)?.root().kind().is_first_party() { + context.schedule_check_file(dependency.path(self)?.file()); } } } @@ -78,8 +73,8 @@ impl Program { let mut diagnostics = Vec::new(); if self.workspace().is_file_open(file) { - diagnostics.extend_from_slice(&self.lint_syntax(file)); - diagnostics.extend_from_slice(&self.lint_semantic(file)); + diagnostics.extend_from_slice(&self.lint_syntax(file)?); + diagnostics.extend_from_slice(&self.lint_semantic(file)?); } Ok(Diagnostics::from(diagnostics)) @@ -128,10 +123,18 @@ where self.scope .spawn(move |_| child_span.in_scope(|| check_file_task.run(program))); + + if current_num_threads() == 1 { + yield_local(); + } } fn max_concurrency(&self) -> Option { - Some(NonZeroUsize::new(max_num_threads()).unwrap_or(NonZeroUsize::MIN)) + if current_num_threads() == 1 { + return None; + } + + Some(NonZeroUsize::new(current_num_threads()).unwrap_or(NonZeroUsize::MIN)) } } @@ -156,11 +159,6 @@ impl CheckScheduler for SameThreadCheckScheduler<'_> { } } -#[derive(Debug, Clone)] -pub enum CheckError { - Cancelled, -} - #[derive(Debug)] pub struct CheckFileTask { file_id: FileId, @@ -176,7 +174,7 @@ impl CheckFileTask { .sender .send(CheckFileMessage::Completed(diagnostics)) .unwrap(), - Err(CheckError::Cancelled) => self + Err(QueryError::Cancelled) => self .context .sender .send(CheckFileMessage::Cancelled) @@ -187,19 +185,12 @@ impl CheckFileTask { #[derive(Clone, Debug)] struct CheckContext { - cancellation_token: CancellationToken, - sender: crossbeam_channel::Sender, + sender: crossbeam::channel::Sender, } impl CheckContext { - fn new( - cancellation_token: CancellationToken, - sender: crossbeam_channel::Sender, - ) -> Self { - Self { - cancellation_token, - sender, - } + fn new(sender: crossbeam::channel::Sender) -> Self { + Self { sender } } /// Queues a new file for checking using the [`CheckScheduler`]. @@ -207,52 +198,36 @@ impl CheckContext { fn schedule_check_file(&self, file_id: FileId) { self.sender.send(CheckFileMessage::Queue(file_id)).unwrap(); } - - /// Returns `true` if the check has been cancelled. 
-    fn is_cancelled(&self) -> bool {
-        self.cancellation_token.is_cancelled()
-    }
-
-    fn cancelled_ok(&self) -> Result<(), CheckError> {
-        if self.is_cancelled() {
-            Err(CheckError::Cancelled)
-        } else {
-            Ok(())
-        }
-    }
 }
 
 struct CheckFilesLoop<'a> {
     scheduler: &'a dyn CheckScheduler,
-    cancellation_token: CancellationToken,
     pending: usize,
     queued_files: FxHashSet<FileId>,
 }
 
 impl<'a> CheckFilesLoop<'a> {
-    fn new(scheduler: &'a dyn CheckScheduler, cancellation_token: CancellationToken) -> Self {
+    fn new(scheduler: &'a dyn CheckScheduler) -> Self {
         Self {
             scheduler,
-            cancellation_token,
-
             queued_files: FxHashSet::default(),
             pending: 0,
         }
     }
 
-    fn run(mut self, files: impl Iterator<Item = FileId>) -> Result<Vec<Diagnostics>, CheckError> {
+    fn run(mut self, files: impl Iterator<Item = FileId>) -> QueryResult<Vec<Diagnostics>> {
         let (sender, receiver) = if let Some(max_concurrency) = self.scheduler.max_concurrency() {
-            crossbeam_channel::bounded(max_concurrency.get())
+            crossbeam::channel::bounded(max_concurrency.get())
         } else {
             // The checks run on the current thread. That means it is necessary to store all messages
             // or we risk deadlocking when the main loop never gets a chance to read the messages.
-            crossbeam_channel::unbounded()
+            crossbeam::channel::unbounded()
         };
 
-        let context = CheckContext::new(self.cancellation_token.clone(), sender.clone());
+        let context = CheckContext::new(sender.clone());
 
         for file in files {
-            self.queue_file(file, context.clone())?;
+            self.queue_file(file, context.clone());
         }
 
         self.run_impl(receiver, &context)
@@ -260,14 +235,11 @@ impl<'a> CheckFilesLoop<'a> {
 
     fn run_impl(
         mut self,
-        receiver: crossbeam_channel::Receiver<CheckFileMessage>,
+        receiver: crossbeam::channel::Receiver<CheckFileMessage>,
         context: &CheckContext,
-    ) -> Result<Vec<Diagnostics>, CheckError> {
-        if self.cancellation_token.is_cancelled() {
-            return Err(CheckError::Cancelled);
-        }
-
+    ) -> QueryResult<Vec<Diagnostics>> {
         let mut result = Vec::default();
+        let mut cancelled = false;
 
         for message in receiver {
             match message {
@@ -281,30 +253,35 @@
                     }
                 }
                 CheckFileMessage::Queue(id) => {
-                    self.queue_file(id, context.clone())?;
+                    if !cancelled {
+                        self.queue_file(id, context.clone());
+                    }
                 }
                 CheckFileMessage::Cancelled => {
-                    return Err(CheckError::Cancelled);
+                    self.pending -= 1;
+                    cancelled = true;
+
+                    if self.pending == 0 {
+                        break;
+                    }
                 }
             }
         }
 
-        Ok(result)
+        if cancelled {
+            Err(QueryError::Cancelled)
+        } else {
+            Ok(result)
+        }
     }
 
-    fn queue_file(&mut self, file_id: FileId, context: CheckContext) -> Result<(), CheckError> {
-        if context.is_cancelled() {
-            return Err(CheckError::Cancelled);
-        }
-
+    fn queue_file(&mut self, file_id: FileId, context: CheckContext) {
        if self.queued_files.insert(file_id) {
             self.pending += 1;
 
             self.scheduler
                 .check_file(CheckFileTask { file_id, context });
         }
-
-        Ok(())
     }
 }
diff --git a/crates/red_knot/src/program/mod.rs b/crates/red_knot/src/program/mod.rs
index 5c9fe1eebe..17fa40d470 100644
--- a/crates/red_knot/src/program/mod.rs
+++ b/crates/red_knot/src/program/mod.rs
@@ -1,45 +1,35 @@
-pub mod check;
-
 use std::path::Path;
 use std::sync::Arc;
 
-use crate::db::{Db, HasJar, SemanticDb, SemanticJar, SourceDb, SourceJar};
-use crate::files::{FileId, Files};
-use crate::lint::{
-    lint_semantic, lint_syntax, Diagnostics, LintSemanticStorage, LintSyntaxStorage,
+use crate::db::{
+    Database, Db, DbRuntime, HasJar, HasJars, JarsStorage, ParallelDatabase, QueryResult,
+    SemanticDb, SemanticJar, Snapshot, SourceDb, SourceJar,
 };
+use crate::files::{FileId, Files};
+use crate::lint::{lint_semantic, lint_syntax, Diagnostics};
 use crate::module::{
     add_module, file_to_module, path_to_module, resolve_module, set_module_search_paths, Module,
-    ModuleData, ModuleName, ModuleResolver, ModuleSearchPath,
+    ModuleData, ModuleName, ModuleSearchPath,
 };
-use crate::parse::{parse, Parsed, ParsedStorage};
-use crate::source::{source_text, Source, SourceStorage};
-use crate::symbols::{symbol_table, SymbolId, SymbolTable, SymbolTablesStorage};
-use crate::types::{infer_symbol_type, Type, TypeStore};
+use crate::parse::{parse, Parsed};
+use crate::source::{source_text, Source};
+use crate::symbols::{symbol_table, SymbolId, SymbolTable};
+use crate::types::{infer_symbol_type, Type};
 use crate::Workspace;
 
+pub mod check;
+
 #[derive(Debug)]
 pub struct Program {
+    jars: JarsStorage<Program>,
     files: Files,
-    source: SourceJar,
-    semantic: SemanticJar,
     workspace: Workspace,
 }
 
 impl Program {
-    pub fn new(workspace: Workspace, module_search_paths: Vec<ModuleSearchPath>) -> Self {
+    pub fn new(workspace: Workspace) -> Self {
         Self {
-            source: SourceJar {
-                sources: SourceStorage::default(),
-                parsed: ParsedStorage::default(),
-                lint_syntax: LintSyntaxStorage::default(),
-            },
-            semantic: SemanticJar {
-                module_resolver: ModuleResolver::new(module_search_paths),
-                symbol_tables: SymbolTablesStorage::default(),
-                type_store: TypeStore::default(),
-                lint_semantic: LintSemanticStorage::default(),
-            },
+            jars: JarsStorage::default(),
             files: Files::default(),
             workspace,
         }
@@ -49,17 +39,19 @@ impl Program {
     where
         I: IntoIterator<Item = FileWatcherChange>,
     {
+        let files = self.files.clone();
+        let (source, semantic) = self.jars_mut();
         for change in changes {
-            self.semantic
-                .module_resolver
-                .remove_module(&self.file_path(change.id));
-            self.semantic.symbol_tables.remove(&change.id);
-            self.source.sources.remove(&change.id);
-            self.source.parsed.remove(&change.id);
-            self.source.lint_syntax.remove(&change.id);
+            let file_path = files.path(change.id);
+
+            semantic.module_resolver.remove_module(&file_path);
+            semantic.symbol_tables.remove(&change.id);
+            source.sources.remove(&change.id);
+            source.parsed.remove(&change.id);
+            source.lint_syntax.remove(&change.id);
             // TODO: remove all dependent modules as well
-            self.semantic.type_store.remove_module(change.id);
-            self.semantic.lint_semantic.remove(&change.id);
+            semantic.type_store.remove_module(change.id);
+            semantic.lint_semantic.remove(&change.id);
         }
     }
 
@@ -85,41 +77,41 @@ impl SourceDb for Program {
         self.files.path(file_id)
     }
 
-    fn source(&self, file_id: FileId) -> Source {
+    fn source(&self, file_id: FileId) -> QueryResult<Source> {
         source_text(self, file_id)
     }
 
-    fn parse(&self, file_id: FileId) -> Parsed {
+    fn parse(&self, file_id: FileId) -> QueryResult<Parsed> {
         parse(self, file_id)
     }
 
-    fn lint_syntax(&self, file_id: FileId) -> Diagnostics {
+    fn lint_syntax(&self, file_id: FileId) -> QueryResult<Diagnostics> {
         lint_syntax(self, file_id)
     }
 }
 
 impl SemanticDb for Program {
-    fn resolve_module(&self, name: ModuleName) -> Option<Module> {
+    fn resolve_module(&self, name: ModuleName) -> QueryResult<Option<Module>> {
         resolve_module(self, name)
     }
 
-    fn file_to_module(&self, file_id: FileId) -> Option<Module> {
+    fn file_to_module(&self, file_id: FileId) -> QueryResult<Option<Module>> {
         file_to_module(self, file_id)
     }
 
-    fn path_to_module(&self, path: &Path) -> Option<Module> {
+    fn path_to_module(&self, path: &Path) -> QueryResult<Option<Module>> {
         path_to_module(self, path)
     }
 
-    fn symbol_table(&self, file_id: FileId) -> Arc<SymbolTable> {
+    fn symbol_table(&self, file_id: FileId) -> QueryResult<Arc<SymbolTable>> {
         symbol_table(self, file_id)
     }
 
-    fn infer_symbol_type(&self, file_id: FileId, symbol_id: SymbolId) -> Type {
+    fn infer_symbol_type(&self, file_id: FileId, symbol_id: SymbolId) -> QueryResult<Type> {
         infer_symbol_type(self, file_id, symbol_id)
     }
 
-    fn lint_semantic(&self, file_id: FileId) -> Diagnostics {
+    fn lint_semantic(&self, file_id: FileId) -> QueryResult<Diagnostics> {
         lint_semantic(self, file_id)
     }
 
@@ -135,23 +127,55 @@ impl SemanticDb for Program {
 
 impl Db for Program {}
 
+impl Database for Program {
+    fn runtime(&self) -> &DbRuntime {
+        self.jars.runtime()
+    }
+
+    fn runtime_mut(&mut self) -> &mut DbRuntime {
+        self.jars.runtime_mut()
+    }
+}
+
+impl ParallelDatabase for Program {
+    fn snapshot(&self) -> Snapshot<Self> {
+        Snapshot::new(Self {
+            jars: self.jars.snapshot(),
+            files: self.files.clone(),
+            workspace: self.workspace.clone(),
+        })
+    }
+}
+
+impl HasJars for Program {
+    type Jars = (SourceJar, SemanticJar);
+
+    fn jars(&self) -> QueryResult<&Self::Jars> {
+        self.jars.jars()
+    }
+
+    fn jars_mut(&mut self) -> &mut Self::Jars {
+        self.jars.jars_mut()
+    }
+}
+
 impl HasJar<SourceJar> for Program {
-    fn jar(&self) -> &SourceJar {
-        &self.source
+    fn jar(&self) -> QueryResult<&SourceJar> {
+        Ok(&self.jars()?.0)
     }
 
     fn jar_mut(&mut self) -> &mut SourceJar {
-        &mut self.source
+        &mut self.jars_mut().0
     }
 }
 
 impl HasJar<SemanticJar> for Program {
-    fn jar(&self) -> &SemanticJar {
-        &self.semantic
+    fn jar(&self) -> QueryResult<&SemanticJar> {
+        Ok(&self.jars()?.1)
     }
 
     fn jar_mut(&mut self) -> &mut SemanticJar {
-        &mut self.semantic
+        &mut self.jars_mut().1
     }
 }
diff --git a/crates/red_knot/src/source.rs b/crates/red_knot/src/source.rs
index 08ad2d8aba..69092d6844 100644
--- a/crates/red_knot/src/source.rs
+++ b/crates/red_knot/src/source.rs
@@ -1,5 +1,5 @@
 use crate::cache::KeyValueCache;
-use crate::db::{HasJar, SourceDb, SourceJar};
+use crate::db::{HasJar, QueryResult, SourceDb, SourceJar};
 use ruff_notebook::Notebook;
 use ruff_python_ast::PySourceType;
 use std::ops::{Deref, DerefMut};
@@ -8,11 +8,11 @@ use std::sync::Arc;
 use crate::files::FileId;
 
 #[tracing::instrument(level = "debug", skip(db))]
-pub(crate) fn source_text<Db>(db: &Db, file_id: FileId) -> Source
+pub(crate) fn source_text<Db>(db: &Db, file_id: FileId) -> QueryResult<Source>
 where
     Db: SourceDb + HasJar<SourceJar>,
 {
-    let sources = &db.jar().sources;
+    let sources = &db.jar()?.sources;
 
     sources.get(&file_id, |file_id| {
         let path = db.file_path(*file_id);
@@ -43,7 +43,7 @@ where
             }
         };
 
-        Source { kind }
+        Ok(Source { kind })
     })
 }
diff --git a/crates/red_knot/src/symbols.rs b/crates/red_knot/src/symbols.rs
index c1eec8de7b..182bb5f79c 100644
--- a/crates/red_knot/src/symbols.rs
+++ b/crates/red_knot/src/symbols.rs
@@ -16,22 +16,22 @@ use ruff_python_ast::visitor::preorder::PreorderVisitor;
 
 use crate::ast_ids::TypedNodeKey;
 use crate::cache::KeyValueCache;
-use crate::db::{HasJar, SemanticDb, SemanticJar};
+use crate::db::{HasJar, QueryResult, SemanticDb, SemanticJar};
 use crate::files::FileId;
 use crate::module::ModuleName;
 use crate::Name;
 
 #[allow(unreachable_pub)]
 #[tracing::instrument(level = "debug", skip(db))]
-pub fn symbol_table<Db>(db: &Db, file_id: FileId) -> Arc<SymbolTable>
+pub fn symbol_table<Db>(db: &Db, file_id: FileId) -> QueryResult<Arc<SymbolTable>>
 where
     Db: SemanticDb + HasJar<SemanticJar>,
 {
-    let jar = db.jar();
+    let jar = db.jar()?;
 
     jar.symbol_tables.get(&file_id, |_| {
-        let parsed = db.parse(file_id);
-        Arc::from(SymbolTable::from_ast(parsed.ast()))
+        let parsed = db.parse(file_id)?;
+        Ok(Arc::from(SymbolTable::from_ast(parsed.ast())))
     })
 }
diff --git a/crates/red_knot/src/types/infer.rs b/crates/red_knot/src/types/infer.rs
index aba308d14a..da327c5b57 100644
--- a/crates/red_knot/src/types/infer.rs
+++ b/crates/red_knot/src/types/infer.rs
@@ -2,7 +2,7 @@
 
 use ruff_python_ast::AstNode;
 
-use crate::db::{HasJar, SemanticDb, SemanticJar};
+use crate::db::{HasJar, QueryResult, SemanticDb, SemanticJar};
 use crate::module::ModuleName;
 use crate::symbols::{Definition, ImportFromDefinition, SymbolId};
 use crate::types::Type;
@@ -11,23 +11,24 @@ use ruff_python_ast as ast;
 
 // FIXME: Figure out proper deadlock-free synchronisation now that this takes `&db` instead of `&mut db`.
 #[tracing::instrument(level = "trace", skip(db))]
-pub fn infer_symbol_type<Db>(db: &Db, file_id: FileId, symbol_id: SymbolId) -> Type
+pub fn infer_symbol_type<Db>(db: &Db, file_id: FileId, symbol_id: SymbolId) -> QueryResult<Type>
 where
     Db: SemanticDb + HasJar<SemanticJar>,
 {
-    let symbols = db.symbol_table(file_id);
+    let symbols = db.symbol_table(file_id)?;
     let defs = symbols.definitions(symbol_id);
 
     if let Some(ty) = db
-        .jar()
+        .jar()?
         .type_store
         .get_cached_symbol_type(file_id, symbol_id)
     {
-        return ty;
+        return Ok(ty);
     }
 
     // TODO handle multiple defs, conditional defs...
     assert_eq!(defs.len(), 1);
 
+    let type_store = &db.jar()?.type_store;
     let ty = match &defs[0] {
         Definition::ImportFrom(ImportFromDefinition {
@@ -38,11 +39,11 @@ where
             // TODO relative imports
             assert!(matches!(level, 0));
             let module_name = ModuleName::new(module.as_ref().expect("TODO relative imports"));
-            if let Some(module) = db.resolve_module(module_name) {
-                let remote_file_id = module.path(db).file();
-                let remote_symbols = db.symbol_table(remote_file_id);
+            if let Some(module) = db.resolve_module(module_name)? {
+                let remote_file_id = module.path(db)?.file();
+                let remote_symbols = db.symbol_table(remote_file_id)?;
                 if let Some(remote_symbol_id) = remote_symbols.root_symbol_id_by_name(name) {
-                    db.infer_symbol_type(remote_file_id, remote_symbol_id)
+                    db.infer_symbol_type(remote_file_id, remote_symbol_id)?
                 } else {
                     Type::Unknown
                 }
@@ -50,71 +51,68 @@ where
                 Type::Unknown
             }
         }
-        Definition::ClassDef(node_key) => db
-            .jar()
-            .type_store
-            .get_cached_node_type(file_id, node_key.erased())
-            .unwrap_or_else(|| {
-                let parsed = db.parse(file_id);
+        Definition::ClassDef(node_key) => {
+            if let Some(ty) = type_store.get_cached_node_type(file_id, node_key.erased()) {
+                ty
+            } else {
+                let parsed = db.parse(file_id)?;
                 let ast = parsed.ast();
                 let node = node_key.resolve_unwrap(ast.as_any_node_ref());
 
-                let bases: Vec<_> = node
-                    .bases()
-                    .iter()
-                    .map(|base_expr| infer_expr_type(db, file_id, base_expr))
-                    .collect();
+                let mut bases = Vec::with_capacity(node.bases().len());
 
-                let store = &db.jar().type_store;
-                let ty = Type::Class(store.add_class(file_id, &node.name.id, bases));
-                store.cache_node_type(file_id, *node_key.erased(), ty);
+                for base in node.bases() {
+                    bases.push(infer_expr_type(db, file_id, base)?);
+                }
+
+                let ty = Type::Class(type_store.add_class(file_id, &node.name.id, bases));
+                type_store.cache_node_type(file_id, *node_key.erased(), ty);
                 ty
-            }),
-        Definition::FunctionDef(node_key) => db
-            .jar()
-            .type_store
-            .get_cached_node_type(file_id, node_key.erased())
-            .unwrap_or_else(|| {
-                let parsed = db.parse(file_id);
+            }
+        }
+        Definition::FunctionDef(node_key) => {
+            if let Some(ty) = type_store.get_cached_node_type(file_id, node_key.erased()) {
+                ty
+            } else {
+                let parsed = db.parse(file_id)?;
                 let ast = parsed.ast();
                 let node = node_key
                     .resolve(ast.as_any_node_ref())
                     .expect("node key should resolve");
 
-                let store = &db.jar().type_store;
-                let ty = store.add_function(file_id, &node.name.id).into();
-                store.cache_node_type(file_id, *node_key.erased(), ty);
+                let ty = type_store.add_function(file_id, &node.name.id).into();
+                type_store.cache_node_type(file_id, *node_key.erased(), ty);
                 ty
-            }),
+            }
+        }
         Definition::Assignment(node_key) => {
-            let parsed = db.parse(file_id);
+            let parsed = db.parse(file_id)?;
             let ast = parsed.ast();
             let node = node_key.resolve_unwrap(ast.as_any_node_ref());
             // TODO handle unpacking assignment correctly
-            infer_expr_type(db, file_id, &node.value)
+            infer_expr_type(db, file_id, &node.value)?
         }
         _ => todo!("other kinds of definitions"),
     };
 
-    db.jar()
-        .type_store
-        .cache_symbol_type(file_id, symbol_id, ty);
+    type_store.cache_symbol_type(file_id, symbol_id, ty);
+
     // TODO record dependencies
-    ty
+    Ok(ty)
 }
 
-fn infer_expr_type<Db>(db: &Db, file_id: FileId, expr: &ast::Expr) -> Type
+fn infer_expr_type<Db>(db: &Db, file_id: FileId, expr: &ast::Expr) -> QueryResult<Type>
 where
     Db: SemanticDb + HasJar<SemanticJar>,
 {
     // TODO cache the resolution of the type on the node
-    let symbols = db.symbol_table(file_id);
+    let symbols = db.symbol_table(file_id)?;
 
     match expr {
         ast::Expr::Name(name) => {
             if let Some(symbol_id) = symbols.root_symbol_id_by_name(&name.id) {
                 db.infer_symbol_type(file_id, symbol_id)
             } else {
-                Type::Unknown
+                Ok(Type::Unknown)
             }
         }
         _ => todo!("full expression type resolution"),
@@ -154,7 +152,7 @@ mod tests {
     }
 
     #[test]
-    fn follow_import_to_class() -> std::io::Result<()> {
+    fn follow_import_to_class() -> anyhow::Result<()> {
         let case = create_test()?;
         let db = &case.db;
 
@@ -163,18 +161,18 @@
         std::fs::write(a_path, "from b import C as D; E = D")?;
         std::fs::write(b_path, "class C: pass")?;
         let a_file = db
-            .resolve_module(ModuleName::new("a"))
+            .resolve_module(ModuleName::new("a"))?
             .expect("module should be found")
-            .path(db)
+            .path(db)?
             .file();
-        let a_syms = db.symbol_table(a_file);
+        let a_syms = db.symbol_table(a_file)?;
         let e_sym = a_syms
             .root_symbol_id_by_name("E")
             .expect("E symbol should be found");
 
-        let ty = db.infer_symbol_type(a_file, e_sym);
+        let ty = db.infer_symbol_type(a_file, e_sym)?;
 
-        let jar = HasJar::<SemanticJar>::jar(db);
+        let jar = HasJar::<SemanticJar>::jar(db)?;
         assert!(matches!(ty, Type::Class(_)));
         assert_eq!(format!("{}", ty.display(&jar.type_store)), "Literal[C]");
 
@@ -182,28 +180,28 @@
     }
 
     #[test]
-    fn resolve_base_class_by_name() -> std::io::Result<()> {
+    fn resolve_base_class_by_name() -> anyhow::Result<()> {
         let case = create_test()?;
         let db = &case.db;
 
         let path = case.src.path().join("mod.py");
         std::fs::write(path, "class Base: pass\nclass Sub(Base): pass")?;
         let file = db
-            .resolve_module(ModuleName::new("mod"))
+            .resolve_module(ModuleName::new("mod"))?
             .expect("module should be found")
-            .path(db)
+            .path(db)?
             .file();
-        let syms = db.symbol_table(file);
+        let syms = db.symbol_table(file)?;
         let sym = syms
             .root_symbol_id_by_name("Sub")
             .expect("Sub symbol should be found");
 
-        let ty = db.infer_symbol_type(file, sym);
+        let ty = db.infer_symbol_type(file, sym)?;
 
         let Type::Class(class_id) = ty else {
             panic!("Sub is not a Class")
         };
-        let jar = HasJar::<SemanticJar>::jar(db);
+        let jar = HasJar::<SemanticJar>::jar(db)?;
         let base_names: Vec<_> = jar
             .type_store
             .get_class(class_id)
diff --git a/crates/ruff_server/Cargo.toml b/crates/ruff_server/Cargo.toml
index 591a37b315..a93430d6eb 100644
--- a/crates/ruff_server/Cargo.toml
+++ b/crates/ruff_server/Cargo.toml
@@ -26,7 +26,7 @@ ruff_text_size = { path = "../ruff_text_size" }
 ruff_workspace = { path = "../ruff_workspace" }
 
 anyhow = { workspace = true }
-crossbeam-channel = { workspace = true }
+crossbeam = { workspace = true }
 jod-thread = { workspace = true }
 libc = { workspace = true }
 lsp-server = { workspace = true }
diff --git a/crates/ruff_server/src/server/client.rs b/crates/ruff_server/src/server/client.rs
index dae8ed269a..d36c50ef66 100644
--- a/crates/ruff_server/src/server/client.rs
+++ b/crates/ruff_server/src/server/client.rs
@@ -6,7 +6,7 @@ use serde_json::Value;
 
 use super::schedule::Task;
 
-pub(crate) type ClientSender = crossbeam_channel::Sender<lsp_server::Message>;
+pub(crate) type ClientSender = crossbeam::channel::Sender<lsp_server::Message>;
 
 type ResponseBuilder<'s> = Box<dyn FnOnce(lsp_server::Response) -> Task<'s>>;
diff --git a/crates/ruff_server/src/server/schedule.rs b/crates/ruff_server/src/server/schedule.rs
index 4ffd819e86..fe8cc5c18c 100644
--- a/crates/ruff_server/src/server/schedule.rs
+++ b/crates/ruff_server/src/server/schedule.rs
@@ -1,6 +1,6 @@
 use std::num::NonZeroUsize;
 
-use crossbeam_channel::Sender;
+use crossbeam::channel::Sender;
 
 use crate::session::Session;
diff --git a/crates/ruff_server/src/server/schedule/thread/pool.rs b/crates/ruff_server/src/server/schedule/thread/pool.rs
index ea07db65d4..7d1f9a418f 100644
--- a/crates/ruff_server/src/server/schedule/thread/pool.rs
+++ b/crates/ruff_server/src/server/schedule/thread/pool.rs
@@ -21,7 +21,7 @@ use std::{
     },
 };
 
-use crossbeam_channel::{Receiver, Sender};
+use crossbeam::channel::{Receiver, Sender};
 
 use super::{Builder, JoinHandle, ThreadPriority};
 
@@ -52,7 +52,7 @@ impl Pool {
         let threads = usize::from(threads);
 
         // Channel buffer capacity is between 2 and 4, depending on the pool size.
-        let (job_sender, job_receiver) = crossbeam_channel::bounded(std::cmp::min(threads * 2, 4));
+        let (job_sender, job_receiver) = crossbeam::channel::bounded(std::cmp::min(threads * 2, 4));
 
         let extant_tasks = Arc::new(AtomicUsize::new(0));
 
         let mut handles = Vec::with_capacity(threads);
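// What makes `program.snapshot()` "cheap" in the `ParallelDatabase` impl
// earlier in this patch: `files` and `workspace` are clone-friendly handles,
// and `jars` is snapshotted rather than deep-copied. A minimal sketch of one
// way such a storage can be built on `Arc`; the patch's real `JarsStorage` in
// `db/storage.rs` is more involved (it also owns the `DbRuntime`), so treat
// the internals below as an assumption for illustration only.

use std::sync::Arc;

#[derive(Debug, Default)]
struct JarsStorage<J: Clone> {
    /// Shared, mostly-immutable state. `Arc::make_mut` copies it only when a
    /// writer needs exclusive access while snapshots are still alive.
    jars: Arc<J>,
}

impl<J: Clone> JarsStorage<J> {
    /// A cheap snapshot: bumps a reference count instead of cloning the jars.
    fn snapshot(&self) -> Self {
        Self {
            jars: Arc::clone(&self.jars),
        }
    }

    /// Exclusive access for the single writer; lazily clones the jars if any
    /// snapshot still holds a reference (copy-on-write).
    fn jars_mut(&mut self) -> &mut J {
        Arc::make_mut(&mut self.jars)
    }
}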