[red-knot] Add "cheap" `program.snapshot` (#11172)

Micha Reiser 2024-04-30 09:13:26 +02:00 committed by GitHub
parent eb6f562419
commit bc03d376e8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 833 additions and 508 deletions

Cargo.lock (generated)

@@ -501,6 +501,19 @@ dependencies = [
  "itertools 0.10.5",
 ]
 
+[[package]]
+name = "crossbeam"
+version = "0.8.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1137cd7e7fc0fb5d3c5a8678be38ec56e819125d8d7907411fe24ccb943faca8"
+dependencies = [
+ "crossbeam-channel",
+ "crossbeam-deque",
+ "crossbeam-epoch",
+ "crossbeam-queue",
+ "crossbeam-utils",
+]
+
 [[package]]
 name = "crossbeam-channel"
 version = "0.5.12"
@@ -529,6 +542,15 @@ dependencies = [
  "crossbeam-utils",
 ]
 
+[[package]]
+name = "crossbeam-queue"
+version = "0.3.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "df0346b5d5e76ac2fe4e327c5fd1118d6be7c51dfb18f9b7922923f287471e35"
+dependencies = [
+ "crossbeam-utils",
+]
+
 [[package]]
 name = "crossbeam-utils"
 version = "0.8.19"
@@ -1804,7 +1826,7 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "bitflags 2.5.0",
- "crossbeam-channel",
+ "crossbeam",
  "ctrlc",
  "dashmap",
  "hashbrown 0.14.5",
@@ -2341,7 +2363,7 @@ name = "ruff_server"
 version = "0.2.2"
 dependencies = [
  "anyhow",
- "crossbeam-channel",
+ "crossbeam",
  "insta",
  "jod-thread",
  "libc",


@ -30,7 +30,7 @@ console_error_panic_hook = { version = "0.1.7" }
console_log = { version = "1.0.0" } console_log = { version = "1.0.0" }
countme = { version = "3.0.1" } countme = { version = "3.0.1" }
criterion = { version = "0.5.1", default-features = false } criterion = { version = "0.5.1", default-features = false }
crossbeam-channel = { version = "0.5.12" } crossbeam = { version = "0.8.4" }
dashmap = { version = "5.5.3" } dashmap = { version = "5.5.3" }
dirs = { version = "5.0.0" } dirs = { version = "5.0.0" }
drop_bomb = { version = "0.1.5" } drop_bomb = { version = "0.1.5" }


@ -22,7 +22,7 @@ ruff_notebook = { path = "../ruff_notebook" }
anyhow = { workspace = true } anyhow = { workspace = true }
bitflags = { workspace = true } bitflags = { workspace = true }
ctrlc = "3.4.4" ctrlc = "3.4.4"
crossbeam-channel = { workspace = true } crossbeam = { workspace = true }
dashmap = { workspace = true } dashmap = { workspace = true }
hashbrown = { workspace = true } hashbrown = { workspace = true }
indexmap = { workspace = true } indexmap = { workspace = true }


@@ -2,6 +2,7 @@ use std::fmt::Formatter;
 use std::hash::Hash;
 use std::sync::atomic::{AtomicUsize, Ordering};
 
+use crate::db::QueryResult;
 use dashmap::mapref::entry::Entry;
 
 use crate::FxDashMap;
@@ -27,11 +28,11 @@ where
         }
     }
 
-    pub fn get<F>(&self, key: &K, compute: F) -> V
+    pub fn get<F>(&self, key: &K, compute: F) -> QueryResult<V>
     where
-        F: FnOnce(&K) -> V,
+        F: FnOnce(&K) -> QueryResult<V>,
     {
-        match self.map.entry(key.clone()) {
+        Ok(match self.map.entry(key.clone()) {
             Entry::Occupied(cached) => {
                 self.statistics.hit();
@@ -40,11 +41,11 @@ where
             Entry::Vacant(vacant) => {
                 self.statistics.miss();
-                let value = compute(key);
+                let value = compute(key)?;
                 vacant.insert(value.clone());
                 value
             }
-        }
+        })
     }
 
     pub fn set(&mut self, key: K, value: V) {
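The reworked `KeyValueCache::get` threads cancellation through the cache: a compute closure that returns `Err(QueryError::Cancelled)` propagates the error to the caller and never populates the cache. A minimal, self-contained sketch of the same pattern, using a plain `HashMap` instead of the crate's concurrent `FxDashMap` (all names here are hypothetical stand-ins):

```rust
use std::collections::HashMap;

// Hypothetical stand-ins for the commit's `QueryError`/`QueryResult`.
#[derive(Debug, Clone, Copy)]
enum QueryError {
    Cancelled,
}
type QueryResult<T> = Result<T, QueryError>;

struct Cache {
    map: HashMap<u32, String>,
}

impl Cache {
    // Mirrors the reworked `KeyValueCache::get`: the compute closure is fallible,
    // and a `Cancelled` error propagates instead of being cached.
    fn get<F>(&mut self, key: u32, compute: F) -> QueryResult<String>
    where
        F: FnOnce(u32) -> QueryResult<String>,
    {
        if let Some(hit) = self.map.get(&key) {
            return Ok(hit.clone());
        }
        let value = compute(key)?; // a cancelled computation never populates the cache
        self.map.insert(key, value.clone());
        Ok(value)
    }
}

fn main() -> QueryResult<()> {
    let mut cache = Cache { map: HashMap::new() };
    let v = cache.get(1, |k| Ok(format!("value-{k}")))?;
    assert_eq!(v, "value-1");
    Ok(())
}
```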


@@ -1,35 +1,25 @@
-use std::sync::{Arc, Condvar, Mutex};
+use std::sync::atomic::AtomicBool;
+use std::sync::Arc;
 
-#[derive(Debug, Default)]
+#[derive(Debug, Clone, Default)]
 pub struct CancellationTokenSource {
-    signal: Arc<(Mutex<bool>, Condvar)>,
+    signal: Arc<AtomicBool>,
 }
 
 impl CancellationTokenSource {
     pub fn new() -> Self {
         Self {
-            signal: Arc::new((Mutex::new(false), Condvar::default())),
+            signal: Arc::new(AtomicBool::new(false)),
         }
     }
 
     #[tracing::instrument(level = "trace", skip_all)]
     pub fn cancel(&self) {
-        let (cancelled, condvar) = &*self.signal;
-        let mut cancelled = cancelled.lock().unwrap();
-
-        if *cancelled {
-            return;
-        }
-
-        *cancelled = true;
-        condvar.notify_all();
+        self.signal.store(true, std::sync::atomic::Ordering::SeqCst);
     }
 
     pub fn is_cancelled(&self) -> bool {
-        let (cancelled, _) = &*self.signal;
-
-        *cancelled.lock().unwrap()
+        self.signal.load(std::sync::atomic::Ordering::SeqCst)
     }
 
     pub fn token(&self) -> CancellationToken {
@@ -41,26 +31,12 @@ impl CancellationTokenSource {
 
 #[derive(Clone, Debug)]
 pub struct CancellationToken {
-    signal: Arc<(Mutex<bool>, Condvar)>,
+    signal: Arc<AtomicBool>,
 }
 
 impl CancellationToken {
     /// Returns `true` if cancellation has been requested.
     pub fn is_cancelled(&self) -> bool {
-        let (cancelled, _) = &*self.signal;
-
-        *cancelled.lock().unwrap()
-    }
-
-    pub fn wait(&self) {
-        let (bool, condvar) = &*self.signal;
-
-        let lock = condvar
-            .wait_while(bool.lock().unwrap(), |bool| !*bool)
-            .unwrap();
-
-        debug_assert!(*lock);
-
-        drop(lock);
+        self.signal.load(std::sync::atomic::Ordering::SeqCst)
     }
 }
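The switch from a `Mutex`/`Condvar` pair to an `Arc<AtomicBool>` makes `cancel()` a single atomic store and drops the blocking `wait()` API entirely: workers now poll `is_cancelled()` at query boundaries instead of parking on a condvar. A standalone sketch of the resulting cooperative pattern (plain `std` types, not this crate's):

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::Duration;

fn main() {
    // Same shape as the reworked source/token pair: one atomic flag, many readers.
    let source = Arc::new(AtomicBool::new(false));
    let token = Arc::clone(&source);

    let worker = thread::spawn(move || {
        let mut processed = 0;
        while !token.load(Ordering::SeqCst) {
            processed += 1; // stand-in for checking one file
            thread::sleep(Duration::from_millis(1));
        }
        processed // exits at the next poll instead of blocking on a condvar
    });

    thread::sleep(Duration::from_millis(20));
    source.store(true, Ordering::SeqCst); // `cancel()` is now just a store
    println!("worker processed {} items before cancellation", worker.join().unwrap());
}
```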


@@ -1,3 +1,8 @@
+mod jars;
+mod query;
+mod runtime;
+mod storage;
+
 use std::path::Path;
 use std::sync::Arc;
@@ -9,32 +14,115 @@ use crate::source::{Source, SourceStorage};
 use crate::symbols::{SymbolId, SymbolTable, SymbolTablesStorage};
 use crate::types::{Type, TypeStore};
 
-pub trait SourceDb {
+pub use jars::{HasJar, HasJars};
+pub use query::{QueryError, QueryResult};
+pub use runtime::DbRuntime;
+pub use storage::JarsStorage;
+
+pub trait Database {
+    /// Returns a reference to the runtime of the current worker.
+    fn runtime(&self) -> &DbRuntime;
+
+    /// Returns a mutable reference to the runtime. Only one worker can hold a mutable reference to the runtime.
+    fn runtime_mut(&mut self) -> &mut DbRuntime;
+
+    /// Returns `Ok` if the queries have not been cancelled and `Err(QueryError::Cancelled)` otherwise.
+    fn cancelled(&self) -> QueryResult<()> {
+        self.runtime().cancelled()
+    }
+
+    /// Returns `true` if the queries have been cancelled.
+    fn is_cancelled(&self) -> bool {
+        self.runtime().is_cancelled()
+    }
+}
+
+/// Database that supports running queries from multiple threads.
+pub trait ParallelDatabase: Database + Send {
+    /// Creates a snapshot of the database state that can be used to query the database in another thread.
+    ///
+    /// The snapshot is a read-only view of the database, but query results are shared between threads.
+    /// All queries are automatically cancelled when a mutation is applied to the database (by calling
+    /// [`HasJars::jars_mut`]); the snapshots themselves are read-only and cannot be mutated.
+    ///
+    /// ## Creating a snapshot
+    ///
+    /// Creating a snapshot of the database's jars is cheap, but creating a snapshot of
+    /// other state stored on the database might require deep-cloning data. That's why you should
+    /// avoid creating snapshots in a hot function (e.g. don't create a snapshot for each file; instead,
+    /// create a snapshot when scheduling the check of an entire program).
+    ///
+    /// ## Salsa compatibility
+    /// Salsa prohibits creating a snapshot while running a local query (it's fine if other workers run a query) [[source](https://github.com/salsa-rs/salsa/issues/80)].
+    /// We should avoid creating snapshots while running a query because we might want to adopt Salsa in the future (if we can figure out persistent caching).
+    /// Unfortunately, the infrastructure doesn't provide an automated way of knowing when a query is run; that's
+    /// why we have to "enforce" this constraint manually.
+    fn snapshot(&self) -> Snapshot<Self>;
+}
+
+/// Read-only snapshot of a database.
+///
+/// ## Deadlocks
+/// A snapshot should always be dropped as soon as it is no longer necessary to run queries.
+/// Storing a snapshot without running a query, or without periodically checking if cancellation was requested,
+/// can lead to deadlocks, because mutating the [`Database`] requires cancelling all pending queries
+/// and waiting for all [`Snapshot`]s to be dropped.
+#[derive(Debug)]
+pub struct Snapshot<DB: ?Sized>
+where
+    DB: ParallelDatabase,
+{
+    db: DB,
+}
+
+impl<DB> Snapshot<DB>
+where
+    DB: ParallelDatabase,
+{
+    pub fn new(db: DB) -> Self {
+        Snapshot { db }
+    }
+}
+
+impl<DB> std::ops::Deref for Snapshot<DB>
+where
+    DB: ParallelDatabase,
+{
+    type Target = DB;
+
+    fn deref(&self) -> &DB {
+        &self.db
+    }
+}
+
+// Red Knot specific database code.
+
+pub trait SourceDb: Database {
     // queries
     fn file_id(&self, path: &std::path::Path) -> FileId;
 
     fn file_path(&self, file_id: FileId) -> Arc<std::path::Path>;
 
-    fn source(&self, file_id: FileId) -> Source;
+    fn source(&self, file_id: FileId) -> QueryResult<Source>;
 
-    fn parse(&self, file_id: FileId) -> Parsed;
+    fn parse(&self, file_id: FileId) -> QueryResult<Parsed>;
 
-    fn lint_syntax(&self, file_id: FileId) -> Diagnostics;
+    fn lint_syntax(&self, file_id: FileId) -> QueryResult<Diagnostics>;
 }
 
 pub trait SemanticDb: SourceDb {
     // queries
-    fn resolve_module(&self, name: ModuleName) -> Option<Module>;
+    fn resolve_module(&self, name: ModuleName) -> QueryResult<Option<Module>>;
 
-    fn file_to_module(&self, file_id: FileId) -> Option<Module>;
+    fn file_to_module(&self, file_id: FileId) -> QueryResult<Option<Module>>;
 
-    fn path_to_module(&self, path: &Path) -> Option<Module>;
+    fn path_to_module(&self, path: &Path) -> QueryResult<Option<Module>>;
 
-    fn symbol_table(&self, file_id: FileId) -> Arc<SymbolTable>;
+    fn symbol_table(&self, file_id: FileId) -> QueryResult<Arc<SymbolTable>>;
 
-    fn infer_symbol_type(&self, file_id: FileId, symbol_id: SymbolId) -> Type;
+    fn infer_symbol_type(&self, file_id: FileId, symbol_id: SymbolId) -> QueryResult<Type>;
 
-    fn lint_semantic(&self, file_id: FileId) -> Diagnostics;
+    fn lint_semantic(&self, file_id: FileId) -> QueryResult<Diagnostics>;
 
     // mutations
@@ -60,32 +148,15 @@ pub struct SemanticJar {
     pub lint_semantic: LintSemanticStorage,
 }
 
-/// Gives access to a specific jar in the database.
-///
-/// Nope, the terminology isn't borrowed from Java but from Salsa <https://salsa-rs.github.io/salsa/>,
-/// which is an analogy to storing the salsa in different jars.
-///
-/// The basic idea is that each crate can define its own jar and the jars can be combined to a single
-/// database in the top level crate. Each crate also defines its own `Database` trait. The combination of
-/// `Database` trait and the jar allows to write queries in isolation without having to know how they get composed at the upper levels.
-///
-/// Salsa further defines a `HasIngredient` trait which slices the jar to a specific storage (e.g. a specific cache).
-/// We don't need this just jet because we write our queries by hand. We may want a similar trait if we decide
-/// to use a macro to generate the queries.
-pub trait HasJar<T> {
-    /// Gives a read-only reference to the jar.
-    fn jar(&self) -> &T;
-
-    /// Gives a mutable reference to the jar.
-    fn jar_mut(&mut self) -> &mut T;
-}
-
 #[cfg(test)]
 pub(crate) mod tests {
     use std::path::Path;
     use std::sync::Arc;
 
-    use crate::db::{HasJar, SourceDb, SourceJar};
+    use crate::db::{
+        Database, DbRuntime, HasJar, HasJars, JarsStorage, ParallelDatabase, QueryResult, Snapshot,
+        SourceDb, SourceJar,
+    };
     use crate::files::{FileId, Files};
     use crate::lint::{lint_semantic, lint_syntax, Diagnostics};
     use crate::module::{
@@ -104,27 +175,26 @@ pub(crate) mod tests {
     #[derive(Debug, Default)]
     pub(crate) struct TestDb {
         files: Files,
-        source: SourceJar,
-        semantic: SemanticJar,
+        jars: JarsStorage<Self>,
     }
 
     impl HasJar<SourceJar> for TestDb {
-        fn jar(&self) -> &SourceJar {
-            &self.source
+        fn jar(&self) -> QueryResult<&SourceJar> {
+            Ok(&self.jars()?.0)
         }
 
         fn jar_mut(&mut self) -> &mut SourceJar {
-            &mut self.source
+            &mut self.jars_mut().0
         }
     }
 
     impl HasJar<SemanticJar> for TestDb {
-        fn jar(&self) -> &SemanticJar {
-            &self.semantic
+        fn jar(&self) -> QueryResult<&SemanticJar> {
+            Ok(&self.jars()?.1)
         }
 
        fn jar_mut(&mut self) -> &mut SemanticJar {
-            &mut self.semantic
+            &mut self.jars_mut().1
        }
    }
@@ -137,41 +207,41 @@ pub(crate) mod tests {
             self.files.path(file_id)
         }
 
-        fn source(&self, file_id: FileId) -> Source {
+        fn source(&self, file_id: FileId) -> QueryResult<Source> {
             source_text(self, file_id)
         }
 
-        fn parse(&self, file_id: FileId) -> Parsed {
+        fn parse(&self, file_id: FileId) -> QueryResult<Parsed> {
             parse(self, file_id)
         }
 
-        fn lint_syntax(&self, file_id: FileId) -> Diagnostics {
+        fn lint_syntax(&self, file_id: FileId) -> QueryResult<Diagnostics> {
             lint_syntax(self, file_id)
         }
     }
 
     impl SemanticDb for TestDb {
-        fn resolve_module(&self, name: ModuleName) -> Option<Module> {
+        fn resolve_module(&self, name: ModuleName) -> QueryResult<Option<Module>> {
             resolve_module(self, name)
         }
 
-        fn file_to_module(&self, file_id: FileId) -> Option<Module> {
+        fn file_to_module(&self, file_id: FileId) -> QueryResult<Option<Module>> {
             file_to_module(self, file_id)
         }
 
-        fn path_to_module(&self, path: &Path) -> Option<Module> {
+        fn path_to_module(&self, path: &Path) -> QueryResult<Option<Module>> {
             path_to_module(self, path)
         }
 
-        fn symbol_table(&self, file_id: FileId) -> Arc<SymbolTable> {
+        fn symbol_table(&self, file_id: FileId) -> QueryResult<Arc<SymbolTable>> {
             symbol_table(self, file_id)
         }
 
-        fn infer_symbol_type(&self, file_id: FileId, symbol_id: SymbolId) -> Type {
+        fn infer_symbol_type(&self, file_id: FileId, symbol_id: SymbolId) -> QueryResult<Type> {
             infer_symbol_type(self, file_id, symbol_id)
         }
 
-        fn lint_semantic(&self, file_id: FileId) -> Diagnostics {
+        fn lint_semantic(&self, file_id: FileId) -> QueryResult<Diagnostics> {
             lint_semantic(self, file_id)
         }
 
@@ -183,4 +253,35 @@ pub(crate) mod tests {
             set_module_search_paths(self, paths);
         }
     }
+
+    impl HasJars for TestDb {
+        type Jars = (SourceJar, SemanticJar);
+
+        fn jars(&self) -> QueryResult<&Self::Jars> {
+            self.jars.jars()
+        }
+
+        fn jars_mut(&mut self) -> &mut Self::Jars {
+            self.jars.jars_mut()
+        }
+    }
+
+    impl Database for TestDb {
+        fn runtime(&self) -> &DbRuntime {
+            self.jars.runtime()
+        }
+
+        fn runtime_mut(&mut self) -> &mut DbRuntime {
+            self.jars.runtime_mut()
+        }
+    }
+
+    impl ParallelDatabase for TestDb {
+        fn snapshot(&self) -> Snapshot<Self> {
+            Snapshot::new(Self {
+                files: self.files.clone(),
+                jars: self.jars.snapshot(),
+            })
+        }
+    }
 }
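The `Snapshot` type only implements `Deref`, so a snapshot can read everything but can never reach a `&mut self` method such as `jars_mut`. A toy model of that design, assuming nothing beyond `std` (all names here are hypothetical):

```rust
use std::ops::Deref;
use std::sync::Arc;

// A toy database: the "jars" live behind an Arc so snapshots are cheap.
#[derive(Debug)]
struct Db {
    jars: Arc<Vec<String>>,
}

// Mirrors the commit's `Snapshot<DB>`: owns a DB value but only hands out `&DB`,
// so a snapshot can never call mutating methods like `jars_mut`.
struct Snapshot {
    db: Db,
}

impl Deref for Snapshot {
    type Target = Db;
    fn deref(&self) -> &Db {
        &self.db
    }
}

impl Db {
    fn snapshot(&self) -> Snapshot {
        // Cheap: only the Arc refcount is bumped, the jars are not cloned.
        Snapshot { db: Db { jars: Arc::clone(&self.jars) } }
    }
}

fn main() {
    let db = Db { jars: Arc::new(vec!["source".into(), "semantic".into()]) };
    let snap = db.snapshot();
    // Reads work through Deref; there is no way to mutate through the snapshot.
    println!("{} / {}", db.jars.len(), snap.jars.len());
}
```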


@@ -0,0 +1,37 @@
+use crate::db::query::QueryResult;
+
+/// Gives access to a specific jar in the database.
+///
+/// Nope, the terminology isn't borrowed from Java but from Salsa <https://salsa-rs.github.io/salsa/>,
+/// which is an analogy to storing the salsa in different jars.
+///
+/// The basic idea is that each crate can define its own jar and the jars can be combined to a single
+/// database in the top-level crate. Each crate also defines its own `Database` trait. The combination of
+/// `Database` trait and jar allows writing queries in isolation without having to know how they get composed at the upper levels.
+///
+/// Salsa further defines a `HasIngredient` trait which slices the jar to a specific storage (e.g. a specific cache).
+/// We don't need this just yet because we write our queries by hand. We may want a similar trait if we decide
+/// to use a macro to generate the queries.
+pub trait HasJar<T> {
+    /// Gives a read-only reference to the jar.
+    fn jar(&self) -> QueryResult<&T>;
+
+    /// Gives a mutable reference to the jar.
+    fn jar_mut(&mut self) -> &mut T;
+}
+
+/// Gives access to the jars in a database.
+pub trait HasJars {
+    /// A type storing the jars.
+    ///
+    /// Most commonly, this is a tuple where each jar is a tuple element.
+    type Jars: Default;
+
+    /// Gives access to the underlying jars, but tests if the queries have been cancelled.
+    ///
+    /// Returns `Err(QueryError::Cancelled)` if the queries have been cancelled.
+    fn jars(&self) -> QueryResult<&Self::Jars>;
+
+    /// Gives mutable access to the underlying jars.
+    fn jars_mut(&mut self) -> &mut Self::Jars;
+}
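To make the jar composition concrete, here is a minimal sketch of a two-jar database where `HasJar` slices a tuple, leaving out the `QueryResult` cancellation plumbing (all types hypothetical):

```rust
// Each crate's jar is one tuple element; `HasJar<T>` slices the tuple.
struct SourceJar {
    sources: Vec<String>,
}
struct SemanticJar {
    types: Vec<String>,
}

// Simplified: the real trait returns `QueryResult<&T>` to observe cancellation.
trait HasJar<T> {
    fn jar(&self) -> &T;
    fn jar_mut(&mut self) -> &mut T;
}

struct Db {
    jars: (SourceJar, SemanticJar),
}

impl HasJar<SourceJar> for Db {
    fn jar(&self) -> &SourceJar { &self.jars.0 }
    fn jar_mut(&mut self) -> &mut SourceJar { &mut self.jars.0 }
}

impl HasJar<SemanticJar> for Db {
    fn jar(&self) -> &SemanticJar { &self.jars.1 }
    fn jar_mut(&mut self) -> &mut SemanticJar { &mut self.jars.1 }
}

// A query written in isolation only asks for the jar it needs,
// without knowing how the jars are composed at the top level.
fn count_sources<D: HasJar<SourceJar>>(db: &D) -> usize {
    db.jar().sources.len()
}

fn main() {
    let db = Db {
        jars: (SourceJar { sources: vec!["a.py".into()] }, SemanticJar { types: vec![] }),
    };
    assert_eq!(count_sources(&db), 1);
}
```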


@@ -0,0 +1,20 @@
+use std::fmt::{Display, Formatter};
+
+/// Reason why a db query operation failed.
+#[derive(Debug, Clone, Copy)]
+pub enum QueryError {
+    /// The query was cancelled, either because the DB was mutated or because the host requested
+    /// cancellation (e.g. on a file change or when pressing CTRL+C).
+    Cancelled,
+}
+
+impl Display for QueryError {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        match self {
+            QueryError::Cancelled => f.write_str("query was cancelled"),
+        }
+    }
+}
+
+impl std::error::Error for QueryError {}
+
+pub type QueryResult<T> = Result<T, QueryError>;
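Because `QueryError` is an ordinary error value, cancellation composes with `?` like any other `Result`. A small sketch of how an outer query unwinds when an inner one is cancelled, and how the host treats `Cancelled` as "discard", not "fail" (hypothetical stand-ins for the crate's queries):

```rust
#[derive(Debug, Clone, Copy)]
enum QueryError {
    Cancelled,
}
type QueryResult<T> = Result<T, QueryError>;

// Queries compose with `?`; cancellation unwinds the whole query stack as a value.
fn parse(cancelled: bool) -> QueryResult<&'static str> {
    if cancelled { Err(QueryError::Cancelled) } else { Ok("ast") }
}

fn lint(cancelled: bool) -> QueryResult<Vec<String>> {
    let ast = parse(cancelled)?; // inner cancellation propagates outward
    Ok(vec![format!("checked {ast}")])
}

fn main() {
    // The host treats `Cancelled` as "drop the stale result", not as a hard error.
    match lint(true) {
        Ok(diagnostics) => println!("{diagnostics:?}"),
        Err(QueryError::Cancelled) => println!("query cancelled, dropping stale result"),
    }
}
```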


@@ -0,0 +1,41 @@
+use crate::cancellation::CancellationTokenSource;
+use crate::db::{QueryError, QueryResult};
+
+/// Holds the jar-agnostic state of the database.
+#[derive(Debug, Default)]
+pub struct DbRuntime {
+    /// The cancellation token source used to signal other workers that the queries should be aborted and
+    /// exit at the next possible point.
+    cancellation_token: CancellationTokenSource,
+}
+
+impl DbRuntime {
+    pub(super) fn snapshot(&self) -> Self {
+        Self {
+            cancellation_token: self.cancellation_token.clone(),
+        }
+    }
+
+    /// Cancels the pending queries of other workers. The current worker cannot have any pending
+    /// queries because we're holding a mutable reference to the runtime.
+    pub(super) fn cancel_other_workers(&mut self) {
+        self.cancellation_token.cancel();
+        // Set a new cancellation token so that we're in a non-cancelled state again when running the next
+        // query.
+        self.cancellation_token = CancellationTokenSource::default();
+    }
+
+    /// Returns `Ok` if the queries have not been cancelled and `Err(QueryError::Cancelled)` otherwise.
+    pub(super) fn cancelled(&self) -> QueryResult<()> {
+        if self.cancellation_token.is_cancelled() {
+            Err(QueryError::Cancelled)
+        } else {
+            Ok(())
+        }
+    }
+
+    /// Returns `true` if the queries have been cancelled.
+    pub(super) fn is_cancelled(&self) -> bool {
+        self.cancellation_token.is_cancelled()
+    }
+}
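The token swap in `cancel_other_workers` is the subtle part: the old flag stays `true` forever, so in-flight snapshots keep observing cancellation, while a fresh flag makes the next queries start un-cancelled. A sketch of just that swap, using a bare `Arc<AtomicBool>` in place of `CancellationTokenSource`:

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

// Hypothetical miniature of `CancellationTokenSource`.
#[derive(Default)]
struct TokenSource {
    flag: Arc<AtomicBool>,
}

impl TokenSource {
    fn cancel(&self) {
        self.flag.store(true, Ordering::SeqCst);
    }
    fn token(&self) -> Arc<AtomicBool> {
        Arc::clone(&self.flag)
    }
}

fn main() {
    let mut source = TokenSource::default();
    let old_token = source.token();

    source.cancel();                 // pending workers observe `true` and bail out
    source = TokenSource::default(); // fresh source: a new, un-cancelled flag

    let new_token = source.token();
    assert!(old_token.load(Ordering::SeqCst));  // old queries stay cancelled
    assert!(!new_token.load(Ordering::SeqCst)); // new queries run normally
}
```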


@@ -0,0 +1,117 @@
+use std::fmt::Formatter;
+use std::sync::Arc;
+
+use crossbeam::sync::WaitGroup;
+
+use crate::db::query::QueryResult;
+use crate::db::runtime::DbRuntime;
+use crate::db::{HasJars, ParallelDatabase};
+
+/// Stores the jars of a database and the state for each worker.
+///
+/// Today, all state is shared across all workers, but it may be desirable to store data per worker in the future.
+pub struct JarsStorage<T>
+where
+    T: HasJars + Sized,
+{
+    // It's important that `jars_wait_group` is declared after `jars` to ensure that `jars` is dropped first.
+    // See https://doc.rust-lang.org/reference/destructors.html
+    /// Stores the jars of the database.
+    jars: Arc<T::Jars>,
+
+    /// Used to count the references to `jars`. Allows implementing `jars_mut` without having to clone `jars`.
+    jars_wait_group: WaitGroup,
+
+    /// The jar-agnostic state.
+    runtime: DbRuntime,
+}
+
+impl<Db> JarsStorage<Db>
+where
+    Db: HasJars,
+{
+    pub(super) fn new() -> Self {
+        Self {
+            jars: Arc::new(Db::Jars::default()),
+            jars_wait_group: WaitGroup::default(),
+            runtime: DbRuntime::default(),
+        }
+    }
+
+    /// Creates a snapshot of the jars.
+    ///
+    /// Creating the snapshot is cheap because it doesn't clone the jars; it only increments a ref counter.
+    #[must_use]
+    pub fn snapshot(&self) -> JarsStorage<Db>
+    where
+        Db: ParallelDatabase,
+    {
+        Self {
+            jars: self.jars.clone(),
+            jars_wait_group: self.jars_wait_group.clone(),
+            runtime: self.runtime.snapshot(),
+        }
+    }
+
+    pub(crate) fn jars(&self) -> QueryResult<&Db::Jars> {
+        self.runtime.cancelled()?;
+        Ok(&self.jars)
+    }
+
+    /// Returns a mutable reference to the jars without cloning their content.
+    ///
+    /// The method cancels any pending queries of other workers and waits for them to complete so that
+    /// this instance is the only instance holding a reference to the jars.
+    pub(crate) fn jars_mut(&mut self) -> &mut Db::Jars {
+        // We have a mutable ref here, so no more workers can be spawned between calling this function and taking the mut ref below.
+        self.cancel_other_workers();
+
+        // Now all other references to `self.jars` have been released. We can safely return a mutable
+        // reference to the Arc's content.
+        Arc::get_mut(&mut self.jars).expect("All references to jars should have been released")
+    }
+
+    pub(crate) fn runtime(&self) -> &DbRuntime {
+        &self.runtime
+    }
+
+    pub(crate) fn runtime_mut(&mut self) -> &mut DbRuntime {
+        // Note: This method may need to use a similar trick to `jars_mut` if `DbRuntime` ever stores data that is shared between workers.
+        &mut self.runtime
+    }
+
+    #[tracing::instrument(level = "trace", skip(self))]
+    fn cancel_other_workers(&mut self) {
+        self.runtime.cancel_other_workers();
+
+        // Wait for all other workers to complete.
+        let existing_wait = std::mem::take(&mut self.jars_wait_group);
+        existing_wait.wait();
+    }
+}
+
+impl<Db> Default for JarsStorage<Db>
+where
+    Db: HasJars,
+{
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl<T> std::fmt::Debug for JarsStorage<T>
+where
+    T: HasJars,
+    <T as HasJars>::Jars: std::fmt::Debug,
+{
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("JarsStorage")
+            .field("jars", &self.jars)
+            .field("jars_wait_group", &self.jars_wait_group)
+            .field("runtime", &self.runtime)
+            .finish()
+    }
+}
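The correctness of `jars_mut` rests on pairing the jars' `Arc` with a `WaitGroup` clone in every snapshot: once `wait()` returns, every snapshot (and with it, every `Arc` clone) has been dropped, so `Arc::get_mut` must succeed. A runnable sketch of that invariant, assuming the `crossbeam` dependency added by this commit:

```rust
use std::sync::Arc;
use std::thread;

use crossbeam::sync::WaitGroup;

fn main() {
    // The storage's invariant in miniature: each snapshot clones the Arc *and*
    // the WaitGroup; `wait()` returns only after every clone has been dropped,
    // at which point `Arc::get_mut` is guaranteed to succeed.
    let mut jars = Arc::new(vec![1, 2, 3]);
    let wait_group = WaitGroup::new();

    for _ in 0..4 {
        let snapshot = Arc::clone(&jars);
        let wg = wait_group.clone();
        thread::spawn(move || {
            let _sum: i32 = snapshot.iter().sum(); // read-only work on the snapshot
            drop(snapshot); // release the jars before signalling completion
            drop(wg);
        });
    }

    wait_group.wait(); // like `cancel_other_workers`: block until all workers finished
    let jars_mut = Arc::get_mut(&mut jars).expect("all snapshots have been dropped");
    jars_mut.push(4);
    println!("{jars:?}");
}
```

Dropping the `Arc` strictly before the `WaitGroup` clone in each worker is what makes the `expect` safe, which is also why the field declaration order in `JarsStorage` matters.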


@ -27,7 +27,7 @@ pub(crate) type FxDashMap<K, V> = dashmap::DashMap<K, V, BuildHasherDefault<FxHa
pub(crate) type FxDashSet<V> = dashmap::DashSet<V, BuildHasherDefault<FxHasher>>; pub(crate) type FxDashSet<V> = dashmap::DashSet<V, BuildHasherDefault<FxHasher>>;
pub(crate) type FxIndexSet<V> = indexmap::set::IndexSet<V, BuildHasherDefault<FxHasher>>; pub(crate) type FxIndexSet<V> = indexmap::set::IndexSet<V, BuildHasherDefault<FxHasher>>;
#[derive(Debug)] #[derive(Debug, Clone)]
pub struct Workspace { pub struct Workspace {
/// TODO this should be a resolved path. We should probably use a newtype wrapper that guarantees that /// TODO this should be a resolved path. We should probably use a newtype wrapper that guarantees that
/// PATH is a UTF-8 path and is normalized. /// PATH is a UTF-8 path and is normalized.


@@ -1,12 +1,15 @@
 use std::cell::RefCell;
 use std::ops::{Deref, DerefMut};
 use std::sync::Arc;
+use std::time::Duration;
 
 use ruff_python_ast::visitor::Visitor;
 use ruff_python_ast::{ModModule, StringLiteral};
 
 use crate::cache::KeyValueCache;
-use crate::db::{HasJar, SemanticDb, SemanticJar, SourceDb, SourceJar};
+use crate::db::{
+    HasJar, ParallelDatabase, QueryResult, SemanticDb, SemanticJar, SourceDb, SourceJar,
+};
 use crate::files::FileId;
 use crate::parse::Parsed;
 use crate::source::Source;
@@ -14,19 +17,28 @@ use crate::symbols::{Definition, SymbolId, SymbolTable};
 use crate::types::Type;
 
 #[tracing::instrument(level = "debug", skip(db))]
-pub(crate) fn lint_syntax<Db>(db: &Db, file_id: FileId) -> Diagnostics
+pub(crate) fn lint_syntax<Db>(db: &Db, file_id: FileId) -> QueryResult<Diagnostics>
 where
-    Db: SourceDb + HasJar<SourceJar>,
+    Db: SourceDb + HasJar<SourceJar> + ParallelDatabase,
 {
-    let storage = &db.jar().lint_syntax;
+    let storage = &db.jar()?.lint_syntax;
+
+    #[allow(clippy::print_stdout)]
+    if std::env::var("RED_KNOT_SLOW_LINT").is_ok() {
+        for i in 0..10 {
+            db.cancelled()?;
+            println!("RED_KNOT_SLOW_LINT is set, sleeping for {i}/10 seconds");
+            std::thread::sleep(Duration::from_secs(1));
+        }
+    }
 
     storage.get(&file_id, |file_id| {
         let mut diagnostics = Vec::new();
 
-        let source = db.source(*file_id);
+        let source = db.source(*file_id)?;
         lint_lines(source.text(), &mut diagnostics);
-        let parsed = db.parse(*file_id);
+        let parsed = db.parse(*file_id)?;
 
         if parsed.errors().is_empty() {
             let ast = parsed.ast();
@@ -41,7 +53,7 @@ where
             diagnostics.extend(parsed.errors().iter().map(std::string::ToString::to_string));
         }
 
-        Diagnostics::from(diagnostics)
+        Ok(Diagnostics::from(diagnostics))
     })
 }
@@ -63,16 +75,16 @@ fn lint_lines(source: &str, diagnostics: &mut Vec<String>) {
 }
 
 #[tracing::instrument(level = "debug", skip(db))]
-pub(crate) fn lint_semantic<Db>(db: &Db, file_id: FileId) -> Diagnostics
+pub(crate) fn lint_semantic<Db>(db: &Db, file_id: FileId) -> QueryResult<Diagnostics>
 where
     Db: SemanticDb + HasJar<SemanticJar>,
 {
-    let storage = &db.jar().lint_semantic;
+    let storage = &db.jar()?.lint_semantic;
 
     storage.get(&file_id, |file_id| {
-        let source = db.source(*file_id);
-        let parsed = db.parse(*file_id);
-        let symbols = db.symbol_table(*file_id);
+        let source = db.source(*file_id)?;
+        let parsed = db.parse(*file_id)?;
+        let symbols = db.symbol_table(*file_id)?;
 
         let context = SemanticLintContext {
             file_id: *file_id,
@@ -83,25 +95,25 @@ where
             diagnostics: RefCell::new(Vec::new()),
         };
 
-        lint_unresolved_imports(&context);
+        lint_unresolved_imports(&context)?;
 
-        Diagnostics::from(context.diagnostics.take())
+        Ok(Diagnostics::from(context.diagnostics.take()))
     })
 }
 
-fn lint_unresolved_imports(context: &SemanticLintContext) {
+fn lint_unresolved_imports(context: &SemanticLintContext) -> QueryResult<()> {
     // TODO: Consider iterating over the dependencies (imports) only instead of all definitions.
     for (symbol, definition) in context.symbols().all_definitions() {
         match definition {
             Definition::Import(import) => {
-                let ty = context.eval_symbol(symbol);
+                let ty = context.infer_symbol_type(symbol)?;
 
                 if ty.is_unknown() {
                     context.push_diagnostic(format!("Unresolved module {}", import.module));
                 }
             }
             Definition::ImportFrom(import) => {
-                let ty = context.eval_symbol(symbol);
+                let ty = context.infer_symbol_type(symbol)?;
 
                 if ty.is_unknown() {
                     let module_name = import.module().map(Deref::deref).unwrap_or_default();
@@ -126,6 +138,8 @@ fn lint_unresolved_imports(context: &SemanticLintContext) {
             _ => {}
         }
     }
+
+    Ok(())
 }
 
 pub struct SemanticLintContext<'a> {
@@ -154,7 +168,7 @@ impl<'a> SemanticLintContext<'a> {
         &self.symbols
     }
 
-    pub fn eval_symbol(&self, symbol_id: SymbolId) -> Type {
+    pub fn infer_symbol_type(&self, symbol_id: SymbolId) -> QueryResult<Type> {
        self.db.infer_symbol_type(self.file_id, symbol_id)
    }


@@ -4,6 +4,7 @@ use std::collections::hash_map::Entry;
 use std::path::Path;
 use std::sync::Mutex;
 
+use crossbeam::channel as crossbeam_channel;
 use rustc_hash::FxHashMap;
 use tracing::subscriber::Interest;
 use tracing::{Level, Metadata};
@@ -12,11 +13,10 @@ use tracing_subscriber::layer::{Context, Filter, SubscriberExt};
 use tracing_subscriber::{Layer, Registry};
 use tracing_tree::time::Uptime;
 
-use red_knot::cancellation::CancellationTokenSource;
-use red_knot::db::{HasJar, SourceDb, SourceJar};
+use red_knot::db::{HasJar, ParallelDatabase, QueryError, SemanticDb, SourceDb, SourceJar};
 use red_knot::files::FileId;
 use red_knot::module::{ModuleSearchPath, ModuleSearchPathKind};
-use red_knot::program::check::{CheckError, RayonCheckScheduler};
+use red_knot::program::check::RayonCheckScheduler;
 use red_knot::program::{FileChange, FileChangeKind, Program};
 use red_knot::watch::FileWatcher;
 use red_knot::Workspace;
@@ -51,7 +51,8 @@ fn main() -> anyhow::Result<()> {
         workspace.root().to_path_buf(),
         ModuleSearchPathKind::FirstParty,
     );
-    let mut program = Program::new(workspace, vec![workspace_search_path]);
+    let mut program = Program::new(workspace);
+    program.set_module_search_paths(vec![workspace_search_path]);
 
     let entry_id = program.file_id(entry_point);
     program.workspace_mut().open_file(entry_id);
@@ -82,7 +83,7 @@ fn main() -> anyhow::Result<()> {
 
     main_loop.run(&mut program);
 
-    let source_jar: &SourceJar = program.jar();
+    let source_jar: &SourceJar = program.jar().unwrap();
 
     dbg!(source_jar.parsed.statistics());
     dbg!(source_jar.sources.statistics());
@@ -101,10 +102,9 @@ impl MainLoop {
         let (main_loop_sender, main_loop_receiver) = crossbeam_channel::bounded(1);
 
         let mut orchestrator = Orchestrator {
-            pending_analysis: None,
             receiver: orchestrator_receiver,
             sender: main_loop_sender.clone(),
-            aggregated_changes: AggregatedChanges::default(),
+            revision: 0,
         };
 
         std::thread::spawn(move || {
@@ -137,34 +137,32 @@
             tracing::trace!("Main Loop: Tick");
 
             match message {
-                MainLoopMessage::CheckProgram => {
-                    // Remove mutability from program.
-                    let program = &*program;
-                    let run_cancellation_token_source = CancellationTokenSource::new();
-                    let run_cancellation_token = run_cancellation_token_source.token();
-                    let sender = &self.orchestrator_sender;
-
-                    sender
-                        .send(OrchestratorMessage::CheckProgramStarted {
-                            cancellation_token: run_cancellation_token_source,
-                        })
-                        .unwrap();
-
-                    rayon::in_place_scope(|scope| {
-                        let scheduler = RayonCheckScheduler::new(program, scope);
-                        let result = program.check(&scheduler, run_cancellation_token);
-                        match result {
-                            Ok(result) => sender
-                                .send(OrchestratorMessage::CheckProgramCompleted(result))
-                                .unwrap(),
-                            Err(CheckError::Cancelled) => sender
-                                .send(OrchestratorMessage::CheckProgramCancelled)
-                                .unwrap(),
-                        }
-                    });
+                MainLoopMessage::CheckProgram { revision } => {
+                    let program = program.snapshot();
+                    let sender = self.orchestrator_sender.clone();
+
+                    // Spawn a new task that checks the program. This needs to be done in a separate thread
+                    // to prevent blocking the main loop here.
+                    rayon::spawn(move || {
+                        rayon::in_place_scope(|scope| {
+                            let scheduler = RayonCheckScheduler::new(&program, scope);
+
+                            match program.check(&scheduler) {
+                                Ok(result) => {
+                                    sender
+                                        .send(OrchestratorMessage::CheckProgramCompleted {
+                                            diagnostics: result,
+                                            revision,
+                                        })
+                                        .unwrap();
+                                }
+                                Err(QueryError::Cancelled) => {}
+                            }
+                        });
+                    });
                 }
 
                 MainLoopMessage::ApplyChanges(changes) => {
+                    // Automatically cancels any pending queries and waits for them to complete.
                     program.apply_changes(changes.iter());
                 }
                 MainLoopMessage::CheckCompleted(diagnostics) => {
@@ -211,13 +209,11 @@ impl MainLoopCancellationToken {
 }
 
 struct Orchestrator {
-    aggregated_changes: AggregatedChanges,
-    pending_analysis: Option<PendingAnalysisState>,
     /// Sends messages to the main loop.
     sender: crossbeam_channel::Sender<MainLoopMessage>,
     /// Receives messages from the main loop.
     receiver: crossbeam_channel::Receiver<OrchestratorMessage>,
+    revision: usize,
 }
 
 impl Orchestrator {
@@ -225,51 +221,33 @@ impl Orchestrator {
         while let Ok(message) = self.receiver.recv() {
             match message {
                 OrchestratorMessage::Run => {
-                    self.pending_analysis = None;
-                    self.sender.send(MainLoopMessage::CheckProgram).unwrap();
-                }
-
-                OrchestratorMessage::CheckProgramStarted { cancellation_token } => {
-                    debug_assert!(self.pending_analysis.is_none());
-
-                    self.pending_analysis = Some(PendingAnalysisState { cancellation_token });
-                }
-
-                OrchestratorMessage::CheckProgramCompleted(diagnostics) => {
-                    self.pending_analysis
-                        .take()
-                        .expect("Expected a pending analysis.");
-
                     self.sender
-                        .send(MainLoopMessage::CheckCompleted(diagnostics))
+                        .send(MainLoopMessage::CheckProgram {
+                            revision: self.revision,
+                        })
                         .unwrap();
                 }
 
-                OrchestratorMessage::CheckProgramCancelled => {
-                    self.pending_analysis
-                        .take()
-                        .expect("Expected a pending analysis.");
-
-                    self.debounce_changes();
+                OrchestratorMessage::CheckProgramCompleted {
+                    diagnostics,
+                    revision,
+                } => {
+                    // Only take the diagnostics if they are for the latest revision.
+                    if self.revision == revision {
+                        self.sender
+                            .send(MainLoopMessage::CheckCompleted(diagnostics))
+                            .unwrap();
+                    } else {
+                        tracing::debug!("Discarding diagnostics for outdated revision {revision} (current: {}).", self.revision);
+                    }
                 }
 
                 OrchestratorMessage::FileChanges(changes) => {
                     // Request cancellation, but wait until all analysis tasks have completed to
                     // avoid stale messages in the next main loop.
-                    let pending = if let Some(pending_state) = self.pending_analysis.as_ref() {
-                        pending_state.cancellation_token.cancel();
-                        true
-                    } else {
-                        false
-                    };
 
-                    self.aggregated_changes.extend(changes);
-
-                    // If there are no pending analysis tasks, apply the file changes. Otherwise
-                    // keep running until all file checks have completed.
-                    if !pending {
-                        self.debounce_changes();
-                    }
+                    self.revision += 1;
+                    self.debounce_changes(changes);
                 }
                 OrchestratorMessage::Shutdown => {
                     return self.shutdown();
@@ -278,8 +256,9 @@ impl Orchestrator {
         }
     }
 
-    fn debounce_changes(&mut self) {
-        debug_assert!(self.pending_analysis.is_none());
+    fn debounce_changes(&self, changes: Vec<FileChange>) {
+        let mut aggregated_changes = AggregatedChanges::default();
+        aggregated_changes.extend(changes);
 
         loop {
             // Consume possibly incoming file change messages before running a new analysis, but don't wait for more than 100ms.
@@ -290,10 +269,12 @@ impl Orchestrator {
                         return self.shutdown();
                     }
                     Ok(OrchestratorMessage::FileChanges(file_changes)) => {
-                        self.aggregated_changes.extend(file_changes);
+                        aggregated_changes.extend(file_changes);
                    }
 
-                    Ok(OrchestratorMessage::CheckProgramStarted { .. } | OrchestratorMessage::CheckProgramCompleted(_) | OrchestratorMessage::CheckProgramCancelled) => unreachable!("No program check should be running while debouncing changes."),
+                    Ok(OrchestratorMessage::CheckProgramCompleted { .. }) => {
+                        // Disregard any outdated completion message.
+                    }
                     Ok(OrchestratorMessage::Run) => unreachable!("The orchestrator is already running."),
 
                     Err(_) => {
@@ -302,10 +283,10 @@ impl Orchestrator {
                         }
                     }
                 },
-                default(std::time::Duration::from_millis(100)) => {
-                    // No more file changes after 100 ms, send the changes and schedule a new analysis
-                    self.sender.send(MainLoopMessage::ApplyChanges(std::mem::take(&mut self.aggregated_changes))).unwrap();
-                    self.sender.send(MainLoopMessage::CheckProgram).unwrap();
+                default(std::time::Duration::from_millis(10)) => {
+                    // No more file changes after 10 ms, send the changes and schedule a new analysis
+                    self.sender.send(MainLoopMessage::ApplyChanges(aggregated_changes)).unwrap();
+                    self.sender.send(MainLoopMessage::CheckProgram { revision: self.revision }).unwrap();
                     return;
                 }
             }
@@ -318,15 +299,10 @@ impl Orchestrator {
         }
     }
 }
 
-#[derive(Debug)]
-struct PendingAnalysisState {
-    cancellation_token: CancellationTokenSource,
-}
-
 /// Message sent from the orchestrator to the main loop.
 #[derive(Debug)]
 enum MainLoopMessage {
-    CheckProgram,
+    CheckProgram { revision: usize },
     CheckCompleted(Vec<String>),
     ApplyChanges(AggregatedChanges),
     Exit,
@@ -337,11 +313,10 @@ enum OrchestratorMessage {
     Run,
     Shutdown,
 
-    CheckProgramStarted {
-        cancellation_token: CancellationTokenSource,
-    },
-    CheckProgramCompleted(Vec<String>),
-    CheckProgramCancelled,
+    CheckProgramCompleted {
+        diagnostics: Vec<String>,
+        revision: usize,
+    },
 
     FileChanges(Vec<FileChange>),
 }
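With snapshots doing the cancellation bookkeeping, the orchestrator only needs a monotonically increasing revision to tell fresh results from stale ones. A compressed, self-contained sketch of that protocol (hypothetical message types, no threads):

```rust
// Results are stamped with the revision they were computed for, and stale
// ones are dropped on receipt. This replaces the old explicit
// "check started / check cancelled" bookkeeping.
struct Orchestrator {
    revision: usize,
}

enum Message {
    FileChanged,
    CheckCompleted { diagnostics: Vec<String>, revision: usize },
}

impl Orchestrator {
    fn handle(&mut self, message: Message) {
        match message {
            Message::FileChanged => {
                // Any in-flight check is now outdated; bumping the revision
                // makes its eventual completion message easy to recognize.
                self.revision += 1;
            }
            Message::CheckCompleted { diagnostics, revision } => {
                if revision == self.revision {
                    println!("publish {diagnostics:?}");
                } else {
                    println!("discard stale result for revision {revision}");
                }
            }
        }
    }
}

fn main() {
    let mut orchestrator = Orchestrator { revision: 0 };
    orchestrator.handle(Message::FileChanged); // revision becomes 1
    orchestrator.handle(Message::CheckCompleted { diagnostics: vec![], revision: 0 }); // stale
    orchestrator.handle(Message::CheckCompleted { diagnostics: vec!["ok".into()], revision: 1 });
}
```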


@@ -7,7 +7,7 @@ use std::sync::Arc;
 use dashmap::mapref::entry::Entry;
 use smol_str::SmolStr;
 
-use crate::db::{HasJar, SemanticDb, SemanticJar};
+use crate::db::{HasJar, QueryResult, SemanticDb, SemanticJar};
 use crate::files::FileId;
 use crate::symbols::Dependency;
 use crate::FxDashMap;
@@ -17,44 +17,48 @@
 pub struct Module(u32);
 
 impl Module {
-    pub fn name<Db>(&self, db: &Db) -> ModuleName
+    pub fn name<Db>(&self, db: &Db) -> QueryResult<ModuleName>
     where
         Db: HasJar<SemanticJar>,
     {
-        let modules = &db.jar().module_resolver;
+        let modules = &db.jar()?.module_resolver;
 
-        modules.modules.get(self).unwrap().name.clone()
+        Ok(modules.modules.get(self).unwrap().name.clone())
     }
 
-    pub fn path<Db>(&self, db: &Db) -> ModulePath
+    pub fn path<Db>(&self, db: &Db) -> QueryResult<ModulePath>
     where
         Db: HasJar<SemanticJar>,
     {
-        let modules = &db.jar().module_resolver;
+        let modules = &db.jar()?.module_resolver;
 
-        modules.modules.get(self).unwrap().path.clone()
+        Ok(modules.modules.get(self).unwrap().path.clone())
    }
 
-    pub fn kind<Db>(&self, db: &Db) -> ModuleKind
+    pub fn kind<Db>(&self, db: &Db) -> QueryResult<ModuleKind>
     where
         Db: HasJar<SemanticJar>,
     {
-        let modules = &db.jar().module_resolver;
+        let modules = &db.jar()?.module_resolver;
 
-        modules.modules.get(self).unwrap().kind
+        Ok(modules.modules.get(self).unwrap().kind)
     }
 
-    pub fn resolve_dependency<Db>(&self, db: &Db, dependency: &Dependency) -> Option<ModuleName>
+    pub fn resolve_dependency<Db>(
+        &self,
+        db: &Db,
+        dependency: &Dependency,
+    ) -> QueryResult<Option<ModuleName>>
     where
         Db: HasJar<SemanticJar>,
     {
         let (level, module) = match dependency {
-            Dependency::Module(module) => return Some(module.clone()),
+            Dependency::Module(module) => return Ok(Some(module.clone())),
             Dependency::Relative { level, module } => (*level, module.as_deref()),
         };
 
-        let name = self.name(db);
-        let kind = self.kind(db);
+        let name = self.name(db)?;
+        let kind = self.kind(db)?;
 
         let mut components = name.components().peekable();
@@ -67,7 +71,9 @@ impl Module {
 
         // Skip over the relative parts.
         for _ in start..level.get() {
-            components.next_back()?;
+            if components.next_back().is_none() {
+                return Ok(None);
+            }
         }
 
         let mut name = String::new();
@@ -80,11 +86,11 @@ impl Module {
             name.push_str(part);
         }
 
-        if name.is_empty() {
+        Ok(if name.is_empty() {
             None
         } else {
             Some(ModuleName(SmolStr::new(name)))
-        }
+        })
     }
 }
@@ -238,20 +244,25 @@ pub struct ModuleData {
 /// TODO: This would not work with Salsa because `ModuleName` isn't an ingredient and, therefore, cannot be used as part of a query.
 /// For this to work with Salsa, it would be necessary to intern all `ModuleName`s.
 #[tracing::instrument(level = "debug", skip(db))]
-pub fn resolve_module<Db>(db: &Db, name: ModuleName) -> Option<Module>
+pub fn resolve_module<Db>(db: &Db, name: ModuleName) -> QueryResult<Option<Module>>
 where
     Db: SemanticDb + HasJar<SemanticJar>,
 {
     let jar = db.jar();
-    let modules = &jar.module_resolver;
+    let modules = &jar?.module_resolver;
 
     let entry = modules.by_name.entry(name.clone());
 
     match entry {
-        Entry::Occupied(entry) => Some(*entry.get()),
+        Entry::Occupied(entry) => Ok(Some(*entry.get())),
         Entry::Vacant(entry) => {
-            let (root_path, absolute_path, kind) = resolve_name(&name, &modules.search_paths)?;
-            let normalized = absolute_path.canonicalize().ok()?;
+            let Some((root_path, absolute_path, kind)) = resolve_name(&name, &modules.search_paths)
+            else {
+                return Ok(None);
+            };
+            let Ok(normalized) = absolute_path.canonicalize() else {
+                return Ok(None);
+            };
 
             let file_id = db.file_id(&normalized);
             let path = ModulePath::new(root_path.clone(), file_id);
@@ -277,7 +288,7 @@ where
 
             entry.insert_entry(id);
 
-            Some(id)
+            Ok(Some(id))
         }
     }
 }
@@ -286,7 +297,7 @@ where
 ///
 /// Returns `None` if the file is not a module in `sys.path`.
 #[tracing::instrument(level = "debug", skip(db))]
-pub fn file_to_module<Db>(db: &Db, file: FileId) -> Option<Module>
+pub fn file_to_module<Db>(db: &Db, file: FileId) -> QueryResult<Option<Module>>
 where
     Db: SemanticDb + HasJar<SemanticJar>,
 {
@@ -298,34 +309,42 @@ where
 ///
 /// Returns `None` if the path is not a module in `sys.path`.
 #[tracing::instrument(level = "debug", skip(db))]
-pub fn path_to_module<Db>(db: &Db, path: &Path) -> Option<Module>
+pub fn path_to_module<Db>(db: &Db, path: &Path) -> QueryResult<Option<Module>>
 where
     Db: SemanticDb + HasJar<SemanticJar>,
 {
-    let jar = db.jar();
+    let jar = db.jar()?;
     let modules = &jar.module_resolver;
 
     debug_assert!(path.is_absolute());
 
     if let Some(existing) = modules.by_path.get(path) {
-        return Some(*existing);
+        return Ok(Some(*existing));
    }
 
-    let (root_path, relative_path) = modules.search_paths.iter().find_map(|root| {
+    let Some((root_path, relative_path)) = modules.search_paths.iter().find_map(|root| {
         let relative_path = path.strip_prefix(root.path()).ok()?;
         Some((root.clone(), relative_path))
-    })?;
+    }) else {
+        return Ok(None);
+    };
 
-    let module_name = ModuleName::from_relative_path(relative_path)?;
+    let Some(module_name) = ModuleName::from_relative_path(relative_path) else {
+        return Ok(None);
+    };
 
     // Resolve the module name to see if Python would resolve the name to the same path.
     // If it doesn't, then that means that multiple modules have the same name in different
    // root paths, but the module corresponding to the passed path is in a lower priority path,
    // in which case we ignore it.
-    let module_id = resolve_module(db, module_name)?;
-    let module_path = module_id.path(db);
+    let Some(module_id) = resolve_module(db, module_name)? else {
+        return Ok(None);
+    };
+    let module_path = module_id.path(db)?;
 
     if module_path.root() == &root_path {
-        let normalized = path.canonicalize().ok()?;
+        let Ok(normalized) = path.canonicalize() else {
+            return Ok(None);
+        };
        let interned_normalized = db.file_id(&normalized);
 
        if interned_normalized != module_path.file() {
@@ -336,15 +355,15 @@ where
             // ```
             // The module name of `src/foo.py` is `foo`, but the module loaded by Python is `src/foo/__init__.py`.
             // That means we need to ignore `src/foo.py` even though it resolves to the same module name.
-            return None;
+            return Ok(None);
         }
 
         // Path has been inserted by `resolved`
-        Some(module_id)
+        Ok(Some(module_id))
     } else {
         // This path is for a module with the same name but in a module search path with a lower priority.
         // Ignore it.
-        None
+        Ok(None)
     }
 }
@@ -378,7 +397,7 @@ where
     // TODO This needs tests
 
     // Note: Intentionally bypass caching here. Module should not be in the cache yet.
-    let module = path_to_module(db, path)?;
+    let module = path_to_module(db, path).ok()??;
 
     // The code below is to handle the addition of `__init__.py` files.
     // When an `__init__.py` file is added, we need to remove all modules that are part of the same package.
@@ -392,7 +411,7 @@ where
         return Some((module, Vec::new()));
     }
 
-    let Some(parent_name) = module.name(db).parent() else {
+    let Some(parent_name) = module.name(db).ok()?.parent() else {
         return Some((module, Vec::new()));
     };
@@ -691,7 +710,7 @@ mod tests {
     }
 
     #[test]
-    fn first_party_module() -> std::io::Result<()> {
+    fn first_party_module() -> anyhow::Result<()> {
         let TestCase {
             db,
             src,
@@ -702,22 +721,22 @@ mod tests {
         let foo_path = src.path().join("foo.py");
         std::fs::write(&foo_path, "print('Hello, world!')")?;
 
-        let foo_module = db.resolve_module(ModuleName::new("foo")).unwrap();
+        let foo_module = db.resolve_module(ModuleName::new("foo"))?.unwrap();
 
-        assert_eq!(Some(foo_module), db.resolve_module(ModuleName::new("foo")));
+        assert_eq!(Some(foo_module), db.resolve_module(ModuleName::new("foo"))?);
 
-        assert_eq!(ModuleName::new("foo"), foo_module.name(&db));
-        assert_eq!(&src, foo_module.path(&db).root());
-        assert_eq!(ModuleKind::Module, foo_module.kind(&db));
-        assert_eq!(&foo_path, &*db.file_path(foo_module.path(&db).file()));
+        assert_eq!(ModuleName::new("foo"), foo_module.name(&db)?);
+        assert_eq!(&src, foo_module.path(&db)?.root());
+        assert_eq!(ModuleKind::Module, foo_module.kind(&db)?);
+        assert_eq!(&foo_path, &*db.file_path(foo_module.path(&db)?.file()));
 
-        assert_eq!(Some(foo_module), db.path_to_module(&foo_path));
+        assert_eq!(Some(foo_module), db.path_to_module(&foo_path)?);
 
         Ok(())
     }
 
     #[test]
-    fn resolve_package() -> std::io::Result<()> {
+    fn resolve_package() -> anyhow::Result<()> {
         let TestCase {
             src,
             db,
@@ -730,22 +749,22 @@ mod tests {
         std::fs::create_dir(&foo_dir)?;
         std::fs::write(&foo_path, "print('Hello, world!')")?;
 
-        let foo_module = db.resolve_module(ModuleName::new("foo")).unwrap();
+        let foo_module = db.resolve_module(ModuleName::new("foo"))?.unwrap();
 
-        assert_eq!(ModuleName::new("foo"), foo_module.name(&db));
-        assert_eq!(&src, foo_module.path(&db).root());
-        assert_eq!(&foo_path, &*db.file_path(foo_module.path(&db).file()));
+        assert_eq!(ModuleName::new("foo"), foo_module.name(&db)?);
+        assert_eq!(&src, foo_module.path(&db)?.root());
+        assert_eq!(&foo_path, &*db.file_path(foo_module.path(&db)?.file()));
 
-        assert_eq!(Some(foo_module), db.path_to_module(&foo_path));
+        assert_eq!(Some(foo_module), db.path_to_module(&foo_path)?);
 
         // Resolving by directory doesn't resolve to the init file.
-        assert_eq!(None, db.path_to_module(&foo_dir));
+        assert_eq!(None, db.path_to_module(&foo_dir)?);
 
         Ok(())
     }
 
     #[test]
-    fn package_priority_over_module() -> std::io::Result<()> {
+    fn package_priority_over_module() -> anyhow::Result<()> {
         let TestCase {
             db,
             temp_dir: _temp_dir,
@@ -761,20 +780,20 @@ mod tests {
         let foo_py = src.path().join("foo.py");
         std::fs::write(&foo_py, "print('Hello, world!')")?;
 
-        let foo_module = db.resolve_module(ModuleName::new("foo")).unwrap();
+        let foo_module = db.resolve_module(ModuleName::new("foo"))?.unwrap();
 
-        assert_eq!(&src, foo_module.path(&db).root());
-        assert_eq!(&foo_init, &*db.file_path(foo_module.path(&db).file()));
-        assert_eq!(ModuleKind::Package, foo_module.kind(&db));
+        assert_eq!(&src, foo_module.path(&db)?.root());
+        assert_eq!(&foo_init, &*db.file_path(foo_module.path(&db)?.file()));
+        assert_eq!(ModuleKind::Package, foo_module.kind(&db)?);
 
-        assert_eq!(Some(foo_module), db.path_to_module(&foo_init));
-        assert_eq!(None, db.path_to_module(&foo_py));
+        assert_eq!(Some(foo_module), db.path_to_module(&foo_init)?);
+        assert_eq!(None, db.path_to_module(&foo_py)?);
 
         Ok(())
     }
 
     #[test]
-    fn typing_stub_over_module() -> std::io::Result<()> {
+    fn typing_stub_over_module() -> anyhow::Result<()> {
         let TestCase {
             db,
             src,
@@ -787,19 +806,19 @@ mod tests {
         std::fs::write(&foo_stub, "x: int")?;
         std::fs::write(&foo_py, "print('Hello, world!')")?;
 
-        let foo = db.resolve_module(ModuleName::new("foo")).unwrap();
+        let foo = db.resolve_module(ModuleName::new("foo"))?.unwrap();
 
-        assert_eq!(&src, foo.path(&db).root());
-        assert_eq!(&foo_stub, &*db.file_path(foo.path(&db).file()));
+        assert_eq!(&src, foo.path(&db)?.root());
+        assert_eq!(&foo_stub, &*db.file_path(foo.path(&db)?.file()));
 
-        assert_eq!(Some(foo), db.path_to_module(&foo_stub));
-        assert_eq!(None, db.path_to_module(&foo_py));
+        assert_eq!(Some(foo), db.path_to_module(&foo_stub)?);
+        assert_eq!(None, db.path_to_module(&foo_py)?);
 
         Ok(())
     }
 
     #[test]
-    fn sub_packages() -> std::io::Result<()> {
+    fn sub_packages() -> anyhow::Result<()> {
         let TestCase {
             db,
             src,
@@ -816,18 +835,18 @@ mod tests {
         std::fs::write(bar.join("__init__.py"), "")?;
         std::fs::write(&baz, "print('Hello, world!')")?;
 
-        let baz_module = db.resolve_module(ModuleName::new("foo.bar.baz")).unwrap();
+        let baz_module = db.resolve_module(ModuleName::new("foo.bar.baz"))?.unwrap();
 
-        assert_eq!(&src, baz_module.path(&db).root());
-        assert_eq!(&baz, &*db.file_path(baz_module.path(&db).file()));
+        assert_eq!(&src, baz_module.path(&db)?.root());
+        assert_eq!(&baz, &*db.file_path(baz_module.path(&db)?.file()));
 
-        assert_eq!(Some(baz_module), db.path_to_module(&baz));
+        assert_eq!(Some(baz_module), db.path_to_module(&baz)?);
 
         Ok(())
     }
 
     #[test]
-    fn namespace_package() -> std::io::Result<()> {
+    fn namespace_package() -> anyhow::Result<()> {
         let TestCase {
             db,
             temp_dir: _,
@@ -863,21 +882,21 @@ mod tests {
         std::fs::write(&two, "print('Hello, world!')")?;
 
         let one_module = db
-            .resolve_module(ModuleName::new("parent.child.one"))
+            .resolve_module(ModuleName::new("parent.child.one"))?
             .unwrap();
 
-        assert_eq!(Some(one_module), db.path_to_module(&one));
+        assert_eq!(Some(one_module), db.path_to_module(&one)?);
 
         let two_module = db
-            .resolve_module(ModuleName::new("parent.child.two"))
+            .resolve_module(ModuleName::new("parent.child.two"))?
             .unwrap();
 
-        assert_eq!(Some(two_module), db.path_to_module(&two));
+        assert_eq!(Some(two_module), db.path_to_module(&two)?);
 
         Ok(())
     }
 
     #[test]
-    fn regular_package_in_namespace_package() -> std::io::Result<()> {
+    fn regular_package_in_namespace_package() -> anyhow::Result<()> {
         let TestCase {
             db,
             temp_dir: _,
@@ -914,17 +933,20 @@ mod tests {
         std::fs::write(two, "print('Hello, world!')")?;
 
         let one_module = db
-            .resolve_module(ModuleName::new("parent.child.one"))
+            .resolve_module(ModuleName::new("parent.child.one"))?
             .unwrap();
 
-        assert_eq!(Some(one_module), db.path_to_module(&one));
+        assert_eq!(Some(one_module), db.path_to_module(&one)?);
 
-        assert_eq!(None, db.resolve_module(ModuleName::new("parent.child.two")));
+        assert_eq!(
+            None,
+            db.resolve_module(ModuleName::new("parent.child.two"))?
+        );
         Ok(())
     }
 
     #[test]
-    fn module_search_path_priority() -> std::io::Result<()> {
+    fn module_search_path_priority() -> anyhow::Result<()> {
         let TestCase {
             db,
             src,
@@ -938,20 +960,20 @@ mod tests {
         std::fs::write(&foo_src, "")?;
         std::fs::write(&foo_site_packages, "")?;
 
-        let foo_module = db.resolve_module(ModuleName::new("foo")).unwrap();
+        let foo_module = db.resolve_module(ModuleName::new("foo"))?.unwrap();
 
-        assert_eq!(&src, foo_module.path(&db).root());
-        assert_eq!(&foo_src, &*db.file_path(foo_module.path(&db).file()));
+        assert_eq!(&src, foo_module.path(&db)?.root());
+        assert_eq!(&foo_src, &*db.file_path(foo_module.path(&db)?.file()));
 
-        assert_eq!(Some(foo_module), db.path_to_module(&foo_src));
-        assert_eq!(None, db.path_to_module(&foo_site_packages));
+        assert_eq!(Some(foo_module), db.path_to_module(&foo_src)?);
+        assert_eq!(None, db.path_to_module(&foo_site_packages)?);
 
         Ok(())
     }
 
     #[test]
     #[cfg(target_family = "unix")]
-    fn symlink() -> std::io::Result<()> {
+    fn symlink() -> anyhow::Result<()> {
         let TestCase {
             db,
             src,
@@ -965,28 +987,28 @@ mod tests {
         std::fs::write(&foo, "")?;
         std::os::unix::fs::symlink(&foo, &bar)?;
 
-        let foo_module = db.resolve_module(ModuleName::new("foo")).unwrap();
-        let bar_module = db.resolve_module(ModuleName::new("bar")).unwrap();
+        let foo_module = db.resolve_module(ModuleName::new("foo"))?.unwrap();
+        let bar_module = db.resolve_module(ModuleName::new("bar"))?.unwrap();
 
         assert_ne!(foo_module, bar_module);
 
-        assert_eq!(&src, foo_module.path(&db).root());
-        assert_eq!(&foo, &*db.file_path(foo_module.path(&db).file()));
+        assert_eq!(&src, foo_module.path(&db)?.root());
+        assert_eq!(&foo, &*db.file_path(foo_module.path(&db)?.file()));
 
         // Bar has a different name but it should point to the same file.
-        assert_eq!(&src, bar_module.path(&db).root());
-        assert_eq!(foo_module.path(&db).file(), bar_module.path(&db).file());
-        assert_eq!(&foo, &*db.file_path(bar_module.path(&db).file()));
+        assert_eq!(&src, bar_module.path(&db)?.root());
+        assert_eq!(foo_module.path(&db)?.file(), bar_module.path(&db)?.file());
+        assert_eq!(&foo, &*db.file_path(bar_module.path(&db)?.file()));
 
-        assert_eq!(Some(foo_module), db.path_to_module(&foo));
-        assert_eq!(Some(bar_module), db.path_to_module(&bar));
+        assert_eq!(Some(foo_module), db.path_to_module(&foo)?);
+        assert_eq!(Some(bar_module), db.path_to_module(&bar)?);
 
         Ok(())
     }
 
     #[test]
-    fn resolve_dependency() -> std::io::Result<()> {
+    fn resolve_dependency() -> anyhow::Result<()> {
         let TestCase {
             src,
             db,
@@ -1002,8 +1024,8 @@ mod tests {
         std::fs::write(foo_path, "from .bar import test")?;
         std::fs::write(bar_path, "test = 'Hello world'")?;
 
-        let foo_module = db.resolve_module(ModuleName::new("foo")).unwrap();
-        let bar_module = db.resolve_module(ModuleName::new("foo.bar")).unwrap();
+        let foo_module = db.resolve_module(ModuleName::new("foo"))?.unwrap();
+        let bar_module = db.resolve_module(ModuleName::new("foo.bar"))?.unwrap();
 
         // `from . import bar` in `foo/__init__.py` resolves to `foo`
         assert_eq!(
@@ -1014,13 +1036,13 @@ mod tests {
                 level: NonZeroU32::new(1).unwrap(),
                 module: None,
             }
-        )
+        )?
         );
 
         // `from baz import bar` in `foo/__init__.py` should resolve to `baz.py`
         assert_eq!(
             Some(ModuleName::new("baz")),
-            foo_module.resolve_dependency(&db, &Dependency::Module(ModuleName::new("baz")))
+            foo_module.resolve_dependency(&db, &Dependency::Module(ModuleName::new("baz")))?
         );
 
         // from .bar import test in `foo/__init__.py` should resolve to `foo/bar.py`
@@ -1032,7 +1054,7 @@ mod tests {
                 level: NonZeroU32::new(1).unwrap(),
                 module: Some(ModuleName::new("bar"))
             }
-        )
+        )?
        );
 
         // from .. import test in `foo/__init__.py` resolves to `` which is not a module
@@ -1044,7 +1066,7 @@ mod tests {
                 level: NonZeroU32::new(2).unwrap(),
                 module: None
             }
-        )
+        )?
         );
 
         // `from . import test` in `foo/bar.py` resolves to `foo`
@ -1056,13 +1078,13 @@ mod tests {
level: NonZeroU32::new(1).unwrap(), level: NonZeroU32::new(1).unwrap(),
module: None module: None
} }
) )?
); );
// `from baz import test` in `foo/bar.py` resolves to `baz` // `from baz import test` in `foo/bar.py` resolves to `baz`
assert_eq!( assert_eq!(
Some(ModuleName::new("baz")), Some(ModuleName::new("baz")),
bar_module.resolve_dependency(&db, &Dependency::Module(ModuleName::new("baz"))) bar_module.resolve_dependency(&db, &Dependency::Module(ModuleName::new("baz")))?
); );
// `from .baz import test` in `foo/bar.py` resolves to `foo.baz`. // `from .baz import test` in `foo/bar.py` resolves to `foo.baz`.
@ -1074,7 +1096,7 @@ mod tests {
level: NonZeroU32::new(1).unwrap(), level: NonZeroU32::new(1).unwrap(),
module: Some(ModuleName::new("baz")) module: Some(ModuleName::new("baz"))
} }
) )?
); );
Ok(()) Ok(())
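
A compact way to read the relative-import assertions above: resolution pops one dotted component per level, with a package's `__init__.py` counting as already inside the package. The sketch below is a toy model of that rule, not the crate's `resolve_dependency`; where the crate produces the empty name `` (which then fails to resolve as a module), the toy returns `None`.

// Toy model of relative-dependency resolution, under the assumptions above.
fn resolve_relative(module: &str, is_package: bool, level: u32, tail: Option<&str>) -> Option<String> {
    let mut parts: Vec<&str> = module.split('.').collect();
    // Inside foo/bar.py, `from . import x` targets foo: drop `bar` once,
    // then drop one more component per extra level. A package's __init__.py
    // is already "inside" the package, so its first level drops nothing.
    let drops = if is_package { level - 1 } else { level };
    for _ in 0..drops {
        parts.pop()?;
    }
    if parts.is_empty() {
        return None; // resolved above the top-level package
    }
    let mut name = parts.join(".");
    if let Some(tail) = tail {
        name.push('.');
        name.push_str(tail);
    }
    Some(name)
}

fn main() {
    // Mirrors the test: foo is a package (foo/__init__.py), foo.bar is a module.
    assert_eq!(resolve_relative("foo", true, 1, None).as_deref(), Some("foo"));
    assert_eq!(resolve_relative("foo", true, 1, Some("bar")).as_deref(), Some("foo.bar"));
    assert_eq!(resolve_relative("foo", true, 2, None), None);
    assert_eq!(resolve_relative("foo.bar", false, 1, None).as_deref(), Some("foo"));
    assert_eq!(resolve_relative("foo.bar", false, 1, Some("baz")).as_deref(), Some("foo.baz"));
}
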

View File

@ -6,7 +6,7 @@ use ruff_python_parser::{Mode, ParseError};
use ruff_text_size::{Ranged, TextRange}; use ruff_text_size::{Ranged, TextRange};
use crate::cache::KeyValueCache; use crate::cache::KeyValueCache;
use crate::db::{HasJar, SourceDb, SourceJar}; use crate::db::{HasJar, QueryResult, SourceDb, SourceJar};
use crate::files::FileId; use crate::files::FileId;
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
@ -64,16 +64,16 @@ impl Parsed {
} }
#[tracing::instrument(level = "debug", skip(db))] #[tracing::instrument(level = "debug", skip(db))]
pub(crate) fn parse<Db>(db: &Db, file_id: FileId) -> Parsed pub(crate) fn parse<Db>(db: &Db, file_id: FileId) -> QueryResult<Parsed>
where where
Db: SourceDb + HasJar<SourceJar>, Db: SourceDb + HasJar<SourceJar>,
{ {
let parsed = db.jar(); let parsed = db.jar()?;
parsed.parsed.get(&file_id, |file_id| { parsed.parsed.get(&file_id, |file_id| {
let source = db.source(*file_id); let source = db.source(*file_id)?;
Parsed::from_text(source.text()) Ok(Parsed::from_text(source.text()))
}) })
} }
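
The reshaped `parse` above shows the new query discipline: `db.jar()?` can observe a pending cancellation, and the cache's compute closure now returns `QueryResult` so that nested queries (here `db.source`) propagate it. A simplified sketch of a fallible-compute cache in that style (toy, single-threaded types, not the real `KeyValueCache`):

use std::collections::HashMap;

#[derive(Debug)]
enum QueryError { Cancelled }
type QueryResult<T> = Result<T, QueryError>;

struct Cache<K, V> {
    map: HashMap<K, V>,
}

impl<K: std::hash::Hash + Eq + Clone, V: Clone> Cache<K, V> {
    // The compute closure is fallible: a cancelled nested query aborts the
    // lookup via `?` instead of caching a half-built value.
    fn get(&mut self, key: &K, compute: impl FnOnce(&K) -> QueryResult<V>) -> QueryResult<V> {
        if let Some(value) = self.map.get(key) {
            return Ok(value.clone());
        }
        let value = compute(key)?;
        self.map.insert(key.clone(), value.clone());
        Ok(value)
    }
}

fn main() -> QueryResult<()> {
    let mut parsed = Cache { map: HashMap::new() };
    let ast = parsed.get(&1u32, |_file_id| Ok("Module(body=[])"))?;
    assert_eq!(ast, "Module(body=[])");
    Ok(())
}
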

View File

@ -1,10 +1,9 @@
use std::num::NonZeroUsize; use std::num::NonZeroUsize;
use rayon::max_num_threads; use rayon::{current_num_threads, yield_local};
use rustc_hash::FxHashSet; use rustc_hash::FxHashSet;
use crate::cancellation::CancellationToken; use crate::db::{Database, QueryError, QueryResult, SemanticDb, SourceDb};
use crate::db::{SemanticDb, SourceDb};
use crate::files::FileId; use crate::files::FileId;
use crate::lint::Diagnostics; use crate::lint::Diagnostics;
use crate::program::Program; use crate::program::Program;
@ -13,42 +12,37 @@ use crate::symbols::Dependency;
impl Program { impl Program {
/// Checks all open files in the workspace and its dependencies. /// Checks all open files in the workspace and its dependencies.
#[tracing::instrument(level = "debug", skip_all)] #[tracing::instrument(level = "debug", skip_all)]
pub fn check( pub fn check(&self, scheduler: &dyn CheckScheduler) -> QueryResult<Vec<String>> {
&self, self.cancelled()?;
scheduler: &dyn CheckScheduler,
cancellation_token: CancellationToken, let check_loop = CheckFilesLoop::new(scheduler);
) -> Result<Vec<String>, CheckError> {
let check_loop = CheckFilesLoop::new(scheduler, cancellation_token);
check_loop.run(self.workspace().open_files.iter().copied()) check_loop.run(self.workspace().open_files.iter().copied())
} }
/// Checks a single file and its dependencies. /// Checks a single file and its dependencies.
#[tracing::instrument(level = "debug", skip(self, scheduler, cancellation_token))] #[tracing::instrument(level = "debug", skip(self, scheduler))]
pub fn check_file( pub fn check_file(
&self, &self,
file: FileId, file: FileId,
scheduler: &dyn CheckScheduler, scheduler: &dyn CheckScheduler,
cancellation_token: CancellationToken, ) -> QueryResult<Vec<String>> {
) -> Result<Vec<String>, CheckError> { self.cancelled()?;
let check_loop = CheckFilesLoop::new(scheduler, cancellation_token);
let check_loop = CheckFilesLoop::new(scheduler);
check_loop.run([file].into_iter()) check_loop.run([file].into_iter())
} }
#[tracing::instrument(level = "debug", skip(self, context))] #[tracing::instrument(level = "debug", skip(self, context))]
fn do_check_file( fn do_check_file(&self, file: FileId, context: &CheckContext) -> QueryResult<Diagnostics> {
&self, self.cancelled()?;
file: FileId,
context: &CheckContext,
) -> Result<Diagnostics, CheckError> {
context.cancelled_ok()?;
let symbol_table = self.symbol_table(file); let symbol_table = self.symbol_table(file)?;
let dependencies = symbol_table.dependencies(); let dependencies = symbol_table.dependencies();
if !dependencies.is_empty() { if !dependencies.is_empty() {
let module = self.file_to_module(file); let module = self.file_to_module(file)?;
// TODO scheduling all dependencies here is wasteful if we don't infer any types on them // TODO scheduling all dependencies here is wasteful if we don't infer any types on them
// but I think that's unlikely, so it is okay? // but I think that's unlikely, so it is okay?
@ -57,18 +51,19 @@ impl Program {
for dependency in dependencies { for dependency in dependencies {
let dependency_name = match dependency { let dependency_name = match dependency {
Dependency::Module(name) => Some(name.clone()), Dependency::Module(name) => Some(name.clone()),
Dependency::Relative { .. } => module Dependency::Relative { .. } => match &module {
.as_ref() Some(module) => module.resolve_dependency(self, dependency)?,
.and_then(|module| module.resolve_dependency(self, dependency)), None => None,
},
}; };
if let Some(dependency_name) = dependency_name { if let Some(dependency_name) = dependency_name {
// TODO We may want to have different check functions for non-first-party // TODO We may want to have different check functions for non-first-party
// files because we only need to index them and not check them. // files because we only need to index them and not check them.
// Supporting non-first-party code also requires supporting typing stubs. // Supporting non-first-party code also requires supporting typing stubs.
if let Some(dependency) = self.resolve_module(dependency_name) { if let Some(dependency) = self.resolve_module(dependency_name)? {
if dependency.path(self).root().kind().is_first_party() { if dependency.path(self)?.root().kind().is_first_party() {
context.schedule_check_file(dependency.path(self).file()); context.schedule_check_file(dependency.path(self)?.file());
} }
} }
} }
@ -78,8 +73,8 @@ impl Program {
let mut diagnostics = Vec::new(); let mut diagnostics = Vec::new();
if self.workspace().is_file_open(file) { if self.workspace().is_file_open(file) {
diagnostics.extend_from_slice(&self.lint_syntax(file)); diagnostics.extend_from_slice(&self.lint_syntax(file)?);
diagnostics.extend_from_slice(&self.lint_semantic(file)); diagnostics.extend_from_slice(&self.lint_semantic(file)?);
} }
Ok(Diagnostics::from(diagnostics)) Ok(Diagnostics::from(diagnostics))
@ -128,10 +123,18 @@ where
self.scope self.scope
.spawn(move |_| child_span.in_scope(|| check_file_task.run(program))); .spawn(move |_| child_span.in_scope(|| check_file_task.run(program)));
if current_num_threads() == 1 {
yield_local();
}
} }
fn max_concurrency(&self) -> Option<NonZeroUsize> { fn max_concurrency(&self) -> Option<NonZeroUsize> {
Some(NonZeroUsize::new(max_num_threads()).unwrap_or(NonZeroUsize::MIN)) if current_num_threads() == 1 {
return None;
}
Some(NonZeroUsize::new(current_num_threads()).unwrap_or(NonZeroUsize::MIN))
} }
} }
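
The `current_num_threads() == 1` special cases above guard against a starved pool: with a single rayon worker, a task spawned from inside another task cannot run until the spawner yields, and a bounded result channel could fill up with nobody draining it. Returning `None` from `max_concurrency` switches the check loop to an unbounded channel (see the `run` method below). A minimal sketch of the channel side of that reasoning, assuming crossbeam 0.8's channel API:

fn main() {
    // Bounded channel, one thread doing both ends: a send beyond the
    // capacity would block forever because the receiver never gets to run.
    let (tx, rx) = crossbeam::channel::bounded::<u32>(1);
    tx.send(1).unwrap();
    // tx.send(2).unwrap(); // would deadlock: buffer full, no concurrent receiver

    // Unbounded channel: a single-threaded producer can queue arbitrarily
    // many messages before the consumer drains them.
    let (tx2, rx2) = crossbeam::channel::unbounded::<u32>();
    for i in 0..1_000 {
        tx2.send(i).unwrap();
    }
    drop(tx2);
    assert_eq!(rx.recv().unwrap(), 1);
    assert_eq!(rx2.iter().count(), 1_000);
}
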
@ -156,11 +159,6 @@ impl CheckScheduler for SameThreadCheckScheduler<'_> {
} }
} }
#[derive(Debug, Clone)]
pub enum CheckError {
Cancelled,
}
#[derive(Debug)] #[derive(Debug)]
pub struct CheckFileTask { pub struct CheckFileTask {
file_id: FileId, file_id: FileId,
@ -176,7 +174,7 @@ impl CheckFileTask {
.sender .sender
.send(CheckFileMessage::Completed(diagnostics)) .send(CheckFileMessage::Completed(diagnostics))
.unwrap(), .unwrap(),
Err(CheckError::Cancelled) => self Err(QueryError::Cancelled) => self
.context .context
.sender .sender
.send(CheckFileMessage::Cancelled) .send(CheckFileMessage::Cancelled)
@ -187,19 +185,12 @@ impl CheckFileTask {
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
struct CheckContext { struct CheckContext {
cancellation_token: CancellationToken, sender: crossbeam::channel::Sender<CheckFileMessage>,
sender: crossbeam_channel::Sender<CheckFileMessage>,
} }
impl CheckContext { impl CheckContext {
fn new( fn new(sender: crossbeam::channel::Sender<CheckFileMessage>) -> Self {
cancellation_token: CancellationToken, Self { sender }
sender: crossbeam_channel::Sender<CheckFileMessage>,
) -> Self {
Self {
cancellation_token,
sender,
}
} }
/// Queues a new file for checking using the [`CheckScheduler`]. /// Queues a new file for checking using the [`CheckScheduler`].
@ -207,52 +198,36 @@ impl CheckContext {
fn schedule_check_file(&self, file_id: FileId) { fn schedule_check_file(&self, file_id: FileId) {
self.sender.send(CheckFileMessage::Queue(file_id)).unwrap(); self.sender.send(CheckFileMessage::Queue(file_id)).unwrap();
} }
/// Returns `true` if the check has been cancelled.
fn is_cancelled(&self) -> bool {
self.cancellation_token.is_cancelled()
}
fn cancelled_ok(&self) -> Result<(), CheckError> {
if self.is_cancelled() {
Err(CheckError::Cancelled)
} else {
Ok(())
}
}
} }
struct CheckFilesLoop<'a> { struct CheckFilesLoop<'a> {
scheduler: &'a dyn CheckScheduler, scheduler: &'a dyn CheckScheduler,
cancellation_token: CancellationToken,
pending: usize, pending: usize,
queued_files: FxHashSet<FileId>, queued_files: FxHashSet<FileId>,
} }
impl<'a> CheckFilesLoop<'a> { impl<'a> CheckFilesLoop<'a> {
fn new(scheduler: &'a dyn CheckScheduler, cancellation_token: CancellationToken) -> Self { fn new(scheduler: &'a dyn CheckScheduler) -> Self {
Self { Self {
scheduler, scheduler,
cancellation_token,
queued_files: FxHashSet::default(), queued_files: FxHashSet::default(),
pending: 0, pending: 0,
} }
} }
fn run(mut self, files: impl Iterator<Item = FileId>) -> Result<Vec<String>, CheckError> { fn run(mut self, files: impl Iterator<Item = FileId>) -> QueryResult<Vec<String>> {
let (sender, receiver) = if let Some(max_concurrency) = self.scheduler.max_concurrency() { let (sender, receiver) = if let Some(max_concurrency) = self.scheduler.max_concurrency() {
crossbeam_channel::bounded(max_concurrency.get()) crossbeam::channel::bounded(max_concurrency.get())
} else { } else {
// The checks run on the current thread. That means it is necessary to store all messages // The checks run on the current thread. That means it is necessary to store all messages
// or we risk deadlocking when the main loop never gets a chance to read the messages. // or we risk deadlocking when the main loop never gets a chance to read the messages.
crossbeam_channel::unbounded() crossbeam::channel::unbounded()
}; };
let context = CheckContext::new(self.cancellation_token.clone(), sender.clone()); let context = CheckContext::new(sender.clone());
for file in files { for file in files {
self.queue_file(file, context.clone())?; self.queue_file(file, context.clone());
} }
self.run_impl(receiver, &context) self.run_impl(receiver, &context)
@ -260,14 +235,11 @@ impl<'a> CheckFilesLoop<'a> {
fn run_impl( fn run_impl(
mut self, mut self,
receiver: crossbeam_channel::Receiver<CheckFileMessage>, receiver: crossbeam::channel::Receiver<CheckFileMessage>,
context: &CheckContext, context: &CheckContext,
) -> Result<Vec<String>, CheckError> { ) -> QueryResult<Vec<String>> {
if self.cancellation_token.is_cancelled() {
return Err(CheckError::Cancelled);
}
let mut result = Vec::default(); let mut result = Vec::default();
let mut cancelled = false;
for message in receiver { for message in receiver {
match message { match message {
@ -281,30 +253,35 @@ impl<'a> CheckFilesLoop<'a> {
} }
} }
CheckFileMessage::Queue(id) => { CheckFileMessage::Queue(id) => {
self.queue_file(id, context.clone())?; if !cancelled {
self.queue_file(id, context.clone());
}
} }
CheckFileMessage::Cancelled => { CheckFileMessage::Cancelled => {
return Err(CheckError::Cancelled); self.pending -= 1;
cancelled = true;
if self.pending == 0 {
break;
}
} }
} }
} }
if cancelled {
Err(QueryError::Cancelled)
} else {
Ok(result) Ok(result)
} }
fn queue_file(&mut self, file_id: FileId, context: CheckContext) -> Result<(), CheckError> {
if context.is_cancelled() {
return Err(CheckError::Cancelled);
} }
fn queue_file(&mut self, file_id: FileId, context: CheckContext) {
if self.queued_files.insert(file_id) { if self.queued_files.insert(file_id) {
self.pending += 1; self.pending += 1;
self.scheduler self.scheduler
.check_file(CheckFileTask { file_id, context }); .check_file(CheckFileTask { file_id, context });
} }
Ok(())
} }
} }
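
The rewritten loop above no longer aborts on the first `Cancelled` message; it flips a flag, stops queueing new files, and keeps decrementing `pending` until every in-flight task has reported back, only then returning `QueryError::Cancelled`. A condensed sketch of that drain-before-return shape (the worker bodies are invented for illustration, and the real loop also handles `Queue` messages):

use crossbeam::channel;

enum CheckFileMessage {
    Completed(String),
    Cancelled,
}

fn main() {
    let (sender, receiver) = channel::unbounded();
    let mut pending = 3usize;
    for i in 0..pending {
        let sender = sender.clone();
        std::thread::spawn(move || {
            let message = if i == 1 {
                CheckFileMessage::Cancelled
            } else {
                CheckFileMessage::Completed(format!("diagnostics for file {i}"))
            };
            sender.send(message).unwrap();
        });
    }
    drop(sender);

    let mut cancelled = false;
    let mut result = Vec::new();
    for message in receiver {
        pending -= 1;
        match message {
            CheckFileMessage::Completed(diagnostics) if !cancelled => result.push(diagnostics),
            CheckFileMessage::Completed(_) => {} // late result after cancellation: dropped
            CheckFileMessage::Cancelled => cancelled = true,
        }
        if pending == 0 {
            break; // every spawned task has reported back
        }
    }
    assert!(cancelled);
    // Completed messages that raced ahead of the Cancelled one were kept;
    // anything arriving after it was dropped, so 0..=2 results survive.
    assert!(result.len() <= 2);
}
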

View File

@ -1,45 +1,35 @@
pub mod check;
use std::path::Path; use std::path::Path;
use std::sync::Arc; use std::sync::Arc;
use crate::db::{Db, HasJar, SemanticDb, SemanticJar, SourceDb, SourceJar}; use crate::db::{
use crate::files::{FileId, Files}; Database, Db, DbRuntime, HasJar, HasJars, JarsStorage, ParallelDatabase, QueryResult,
use crate::lint::{ SemanticDb, SemanticJar, Snapshot, SourceDb, SourceJar,
lint_semantic, lint_syntax, Diagnostics, LintSemanticStorage, LintSyntaxStorage,
}; };
use crate::files::{FileId, Files};
use crate::lint::{lint_semantic, lint_syntax, Diagnostics};
use crate::module::{ use crate::module::{
add_module, file_to_module, path_to_module, resolve_module, set_module_search_paths, Module, add_module, file_to_module, path_to_module, resolve_module, set_module_search_paths, Module,
ModuleData, ModuleName, ModuleResolver, ModuleSearchPath, ModuleData, ModuleName, ModuleSearchPath,
}; };
use crate::parse::{parse, Parsed, ParsedStorage}; use crate::parse::{parse, Parsed};
use crate::source::{source_text, Source, SourceStorage}; use crate::source::{source_text, Source};
use crate::symbols::{symbol_table, SymbolId, SymbolTable, SymbolTablesStorage}; use crate::symbols::{symbol_table, SymbolId, SymbolTable};
use crate::types::{infer_symbol_type, Type, TypeStore}; use crate::types::{infer_symbol_type, Type};
use crate::Workspace; use crate::Workspace;
pub mod check;
#[derive(Debug)] #[derive(Debug)]
pub struct Program { pub struct Program {
jars: JarsStorage<Program>,
files: Files, files: Files,
source: SourceJar,
semantic: SemanticJar,
workspace: Workspace, workspace: Workspace,
} }
impl Program { impl Program {
pub fn new(workspace: Workspace, module_search_paths: Vec<ModuleSearchPath>) -> Self { pub fn new(workspace: Workspace) -> Self {
Self { Self {
source: SourceJar { jars: JarsStorage::default(),
sources: SourceStorage::default(),
parsed: ParsedStorage::default(),
lint_syntax: LintSyntaxStorage::default(),
},
semantic: SemanticJar {
module_resolver: ModuleResolver::new(module_search_paths),
symbol_tables: SymbolTablesStorage::default(),
type_store: TypeStore::default(),
lint_semantic: LintSemanticStorage::default(),
},
files: Files::default(), files: Files::default(),
workspace, workspace,
} }
@ -49,17 +39,19 @@ impl Program {
where where
I: IntoIterator<Item = FileChange>, I: IntoIterator<Item = FileChange>,
{ {
let files = self.files.clone();
let (source, semantic) = self.jars_mut();
for change in changes { for change in changes {
self.semantic let file_path = files.path(change.id);
.module_resolver
.remove_module(&self.file_path(change.id)); semantic.module_resolver.remove_module(&file_path);
self.semantic.symbol_tables.remove(&change.id); semantic.symbol_tables.remove(&change.id);
self.source.sources.remove(&change.id); source.sources.remove(&change.id);
self.source.parsed.remove(&change.id); source.parsed.remove(&change.id);
self.source.lint_syntax.remove(&change.id); source.lint_syntax.remove(&change.id);
// TODO: remove all dependent modules as well // TODO: remove all dependent modules as well
self.semantic.type_store.remove_module(change.id); semantic.type_store.remove_module(change.id);
self.semantic.lint_semantic.remove(&change.id); semantic.lint_semantic.remove(&change.id);
} }
} }
@ -85,41 +77,41 @@ impl SourceDb for Program {
self.files.path(file_id) self.files.path(file_id)
} }
fn source(&self, file_id: FileId) -> Source { fn source(&self, file_id: FileId) -> QueryResult<Source> {
source_text(self, file_id) source_text(self, file_id)
} }
fn parse(&self, file_id: FileId) -> Parsed { fn parse(&self, file_id: FileId) -> QueryResult<Parsed> {
parse(self, file_id) parse(self, file_id)
} }
fn lint_syntax(&self, file_id: FileId) -> Diagnostics { fn lint_syntax(&self, file_id: FileId) -> QueryResult<Diagnostics> {
lint_syntax(self, file_id) lint_syntax(self, file_id)
} }
} }
impl SemanticDb for Program { impl SemanticDb for Program {
fn resolve_module(&self, name: ModuleName) -> Option<Module> { fn resolve_module(&self, name: ModuleName) -> QueryResult<Option<Module>> {
resolve_module(self, name) resolve_module(self, name)
} }
fn file_to_module(&self, file_id: FileId) -> Option<Module> { fn file_to_module(&self, file_id: FileId) -> QueryResult<Option<Module>> {
file_to_module(self, file_id) file_to_module(self, file_id)
} }
fn path_to_module(&self, path: &Path) -> Option<Module> { fn path_to_module(&self, path: &Path) -> QueryResult<Option<Module>> {
path_to_module(self, path) path_to_module(self, path)
} }
fn symbol_table(&self, file_id: FileId) -> Arc<SymbolTable> { fn symbol_table(&self, file_id: FileId) -> QueryResult<Arc<SymbolTable>> {
symbol_table(self, file_id) symbol_table(self, file_id)
} }
fn infer_symbol_type(&self, file_id: FileId, symbol_id: SymbolId) -> Type { fn infer_symbol_type(&self, file_id: FileId, symbol_id: SymbolId) -> QueryResult<Type> {
infer_symbol_type(self, file_id, symbol_id) infer_symbol_type(self, file_id, symbol_id)
} }
fn lint_semantic(&self, file_id: FileId) -> Diagnostics { fn lint_semantic(&self, file_id: FileId) -> QueryResult<Diagnostics> {
lint_semantic(self, file_id) lint_semantic(self, file_id)
} }
@ -135,23 +127,55 @@ impl SemanticDb for Program {
impl Db for Program {} impl Db for Program {}
impl Database for Program {
fn runtime(&self) -> &DbRuntime {
self.jars.runtime()
}
fn runtime_mut(&mut self) -> &mut DbRuntime {
self.jars.runtime_mut()
}
}
impl ParallelDatabase for Program {
fn snapshot(&self) -> Snapshot<Self> {
Snapshot::new(Self {
jars: self.jars.snapshot(),
files: self.files.clone(),
workspace: self.workspace.clone(),
})
}
}
impl HasJars for Program {
type Jars = (SourceJar, SemanticJar);
fn jars(&self) -> QueryResult<&Self::Jars> {
self.jars.jars()
}
fn jars_mut(&mut self) -> &mut Self::Jars {
self.jars.jars_mut()
}
}
impl HasJar<SourceJar> for Program { impl HasJar<SourceJar> for Program {
fn jar(&self) -> &SourceJar { fn jar(&self) -> QueryResult<&SourceJar> {
&self.source Ok(&self.jars()?.0)
} }
fn jar_mut(&mut self) -> &mut SourceJar { fn jar_mut(&mut self) -> &mut SourceJar {
&mut self.source &mut self.jars_mut().0
} }
} }
impl HasJar<SemanticJar> for Program { impl HasJar<SemanticJar> for Program {
fn jar(&self) -> &SemanticJar { fn jar(&self) -> QueryResult<&SemanticJar> {
&self.semantic Ok(&self.jars()?.1)
} }
fn jar_mut(&mut self) -> &mut SemanticJar { fn jar_mut(&mut self) -> &mut SemanticJar {
&mut self.semantic &mut self.jars_mut().1
} }
} }
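
What makes the snapshot above "cheap" is visible in its body: `files` and `workspace` are clones of handle types, and `jars.snapshot()` hands out a shared view of the query caches rather than copying them. Below is a toy model of that idea using `Arc` with copy-on-write; this is an assumption for illustration only, since the real `JarsStorage` presumably coordinates through the `DbRuntime` and query cancellation rather than (or in addition to) cloning:

use std::sync::Arc;

#[derive(Clone, Default)]
struct Jars {
    sources: Vec<String>,
}

#[derive(Default)]
struct Storage {
    jars: Arc<Jars>,
}

impl Storage {
    // O(1): a snapshot bumps a reference count instead of copying caches.
    fn snapshot(&self) -> Storage {
        Storage { jars: Arc::clone(&self.jars) }
    }

    // Mutation clones the jars only while a snapshot still shares them.
    fn jars_mut(&mut self) -> &mut Jars {
        Arc::make_mut(&mut self.jars)
    }
}

fn main() {
    let mut db = Storage::default();
    db.jars_mut().sources.push("print('hi')".to_string());

    let snapshot = db.snapshot();  // cheap
    db.jars_mut().sources.clear(); // copy-on-write: the snapshot is unaffected

    assert_eq!(snapshot.jars.sources.len(), 1);
    assert!(db.jars.sources.is_empty());
}
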

View File

@ -1,5 +1,5 @@
use crate::cache::KeyValueCache; use crate::cache::KeyValueCache;
use crate::db::{HasJar, SourceDb, SourceJar}; use crate::db::{HasJar, QueryResult, SourceDb, SourceJar};
use ruff_notebook::Notebook; use ruff_notebook::Notebook;
use ruff_python_ast::PySourceType; use ruff_python_ast::PySourceType;
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
@ -8,11 +8,11 @@ use std::sync::Arc;
use crate::files::FileId; use crate::files::FileId;
#[tracing::instrument(level = "debug", skip(db))] #[tracing::instrument(level = "debug", skip(db))]
pub(crate) fn source_text<Db>(db: &Db, file_id: FileId) -> Source pub(crate) fn source_text<Db>(db: &Db, file_id: FileId) -> QueryResult<Source>
where where
Db: SourceDb + HasJar<SourceJar>, Db: SourceDb + HasJar<SourceJar>,
{ {
let sources = &db.jar().sources; let sources = &db.jar()?.sources;
sources.get(&file_id, |file_id| { sources.get(&file_id, |file_id| {
let path = db.file_path(*file_id); let path = db.file_path(*file_id);
@ -43,7 +43,7 @@ where
} }
}; };
Source { kind } Ok(Source { kind })
}) })
} }

View File

@ -16,22 +16,22 @@ use ruff_python_ast::visitor::preorder::PreorderVisitor;
use crate::ast_ids::TypedNodeKey; use crate::ast_ids::TypedNodeKey;
use crate::cache::KeyValueCache; use crate::cache::KeyValueCache;
use crate::db::{HasJar, SemanticDb, SemanticJar}; use crate::db::{HasJar, QueryResult, SemanticDb, SemanticJar};
use crate::files::FileId; use crate::files::FileId;
use crate::module::ModuleName; use crate::module::ModuleName;
use crate::Name; use crate::Name;
#[allow(unreachable_pub)] #[allow(unreachable_pub)]
#[tracing::instrument(level = "debug", skip(db))] #[tracing::instrument(level = "debug", skip(db))]
pub fn symbol_table<Db>(db: &Db, file_id: FileId) -> Arc<SymbolTable> pub fn symbol_table<Db>(db: &Db, file_id: FileId) -> QueryResult<Arc<SymbolTable>>
where where
Db: SemanticDb + HasJar<SemanticJar>, Db: SemanticDb + HasJar<SemanticJar>,
{ {
let jar = db.jar(); let jar = db.jar()?;
jar.symbol_tables.get(&file_id, |_| { jar.symbol_tables.get(&file_id, |_| {
let parsed = db.parse(file_id); let parsed = db.parse(file_id)?;
Arc::from(SymbolTable::from_ast(parsed.ast())) Ok(Arc::from(SymbolTable::from_ast(parsed.ast())))
}) })
} }

View File

@ -2,7 +2,7 @@
use ruff_python_ast::AstNode; use ruff_python_ast::AstNode;
use crate::db::{HasJar, SemanticDb, SemanticJar}; use crate::db::{HasJar, QueryResult, SemanticDb, SemanticJar};
use crate::module::ModuleName; use crate::module::ModuleName;
use crate::symbols::{Definition, ImportFromDefinition, SymbolId}; use crate::symbols::{Definition, ImportFromDefinition, SymbolId};
use crate::types::Type; use crate::types::Type;
@ -11,23 +11,24 @@ use ruff_python_ast as ast;
// FIXME: Figure out proper deadlock-free synchronisation now that this takes `&db` instead of `&mut db`. // FIXME: Figure out proper deadlock-free synchronisation now that this takes `&db` instead of `&mut db`.
#[tracing::instrument(level = "trace", skip(db))] #[tracing::instrument(level = "trace", skip(db))]
pub fn infer_symbol_type<Db>(db: &Db, file_id: FileId, symbol_id: SymbolId) -> Type pub fn infer_symbol_type<Db>(db: &Db, file_id: FileId, symbol_id: SymbolId) -> QueryResult<Type>
where where
Db: SemanticDb + HasJar<SemanticJar>, Db: SemanticDb + HasJar<SemanticJar>,
{ {
let symbols = db.symbol_table(file_id); let symbols = db.symbol_table(file_id)?;
let defs = symbols.definitions(symbol_id); let defs = symbols.definitions(symbol_id);
if let Some(ty) = db if let Some(ty) = db
.jar() .jar()?
.type_store .type_store
.get_cached_symbol_type(file_id, symbol_id) .get_cached_symbol_type(file_id, symbol_id)
{ {
return ty; return Ok(ty);
} }
// TODO handle multiple defs, conditional defs... // TODO handle multiple defs, conditional defs...
assert_eq!(defs.len(), 1); assert_eq!(defs.len(), 1);
let type_store = &db.jar()?.type_store;
let ty = match &defs[0] { let ty = match &defs[0] {
Definition::ImportFrom(ImportFromDefinition { Definition::ImportFrom(ImportFromDefinition {
@ -38,11 +39,11 @@ where
// TODO relative imports // TODO relative imports
assert!(matches!(level, 0)); assert!(matches!(level, 0));
let module_name = ModuleName::new(module.as_ref().expect("TODO relative imports")); let module_name = ModuleName::new(module.as_ref().expect("TODO relative imports"));
if let Some(module) = db.resolve_module(module_name) { if let Some(module) = db.resolve_module(module_name)? {
let remote_file_id = module.path(db).file(); let remote_file_id = module.path(db)?.file();
let remote_symbols = db.symbol_table(remote_file_id); let remote_symbols = db.symbol_table(remote_file_id)?;
if let Some(remote_symbol_id) = remote_symbols.root_symbol_id_by_name(name) { if let Some(remote_symbol_id) = remote_symbols.root_symbol_id_by_name(name) {
db.infer_symbol_type(remote_file_id, remote_symbol_id) db.infer_symbol_type(remote_file_id, remote_symbol_id)?
} else { } else {
Type::Unknown Type::Unknown
} }
@ -50,71 +51,68 @@ where
Type::Unknown Type::Unknown
} }
} }
Definition::ClassDef(node_key) => db Definition::ClassDef(node_key) => {
.jar() if let Some(ty) = type_store.get_cached_node_type(file_id, node_key.erased()) {
.type_store ty
.get_cached_node_type(file_id, node_key.erased()) } else {
.unwrap_or_else(|| { let parsed = db.parse(file_id)?;
let parsed = db.parse(file_id);
let ast = parsed.ast(); let ast = parsed.ast();
let node = node_key.resolve_unwrap(ast.as_any_node_ref()); let node = node_key.resolve_unwrap(ast.as_any_node_ref());
let bases: Vec<_> = node let mut bases = Vec::with_capacity(node.bases().len());
.bases()
.iter()
.map(|base_expr| infer_expr_type(db, file_id, base_expr))
.collect();
let store = &db.jar().type_store; for base in node.bases() {
let ty = Type::Class(store.add_class(file_id, &node.name.id, bases)); bases.push(infer_expr_type(db, file_id, base)?);
store.cache_node_type(file_id, *node_key.erased(), ty); }
let ty = Type::Class(type_store.add_class(file_id, &node.name.id, bases));
type_store.cache_node_type(file_id, *node_key.erased(), ty);
ty ty
}), }
Definition::FunctionDef(node_key) => db }
.jar() Definition::FunctionDef(node_key) => {
.type_store if let Some(ty) = type_store.get_cached_node_type(file_id, node_key.erased()) {
.get_cached_node_type(file_id, node_key.erased()) ty
.unwrap_or_else(|| { } else {
let parsed = db.parse(file_id); let parsed = db.parse(file_id)?;
let ast = parsed.ast(); let ast = parsed.ast();
let node = node_key let node = node_key
.resolve(ast.as_any_node_ref()) .resolve(ast.as_any_node_ref())
.expect("node key should resolve"); .expect("node key should resolve");
let store = &db.jar().type_store; let ty = type_store.add_function(file_id, &node.name.id).into();
let ty = store.add_function(file_id, &node.name.id).into(); type_store.cache_node_type(file_id, *node_key.erased(), ty);
store.cache_node_type(file_id, *node_key.erased(), ty);
ty ty
}), }
}
Definition::Assignment(node_key) => { Definition::Assignment(node_key) => {
let parsed = db.parse(file_id); let parsed = db.parse(file_id)?;
let ast = parsed.ast(); let ast = parsed.ast();
let node = node_key.resolve_unwrap(ast.as_any_node_ref()); let node = node_key.resolve_unwrap(ast.as_any_node_ref());
// TODO handle unpacking assignment correctly // TODO handle unpacking assignment correctly
infer_expr_type(db, file_id, &node.value) infer_expr_type(db, file_id, &node.value)?
} }
_ => todo!("other kinds of definitions"), _ => todo!("other kinds of definitions"),
}; };
db.jar() type_store.cache_symbol_type(file_id, symbol_id, ty);
.type_store
.cache_symbol_type(file_id, symbol_id, ty);
// TODO record dependencies // TODO record dependencies
ty Ok(ty)
} }
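
One reason the ClassDef and FunctionDef arms above grew from `unwrap_or_else` chains into explicit `if let`/`else` blocks: the fallback computation is now fallible, and `?` inside a closure returns from the closure rather than from `infer_symbol_type`. A minimal illustration of the constraint (toy names):

#[derive(Debug)]
struct Cancelled;

fn compute() -> Result<i32, Cancelled> {
    Ok(42)
}

fn cached_or_computed(cached: Option<i32>) -> Result<i32, Cancelled> {
    // This does not compile once the fallback can fail:
    // cached.unwrap_or_else(|| compute()?) // `?` cannot escape the closure
    match cached {
        Some(ty) => Ok(ty),
        None => compute(), // plain control flow lets the error propagate
    }
}

fn main() -> Result<(), Cancelled> {
    assert_eq!(cached_or_computed(Some(1))?, 1);
    assert_eq!(cached_or_computed(None)?, 42);
    Ok(())
}
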
fn infer_expr_type<Db>(db: &Db, file_id: FileId, expr: &ast::Expr) -> Type fn infer_expr_type<Db>(db: &Db, file_id: FileId, expr: &ast::Expr) -> QueryResult<Type>
where where
Db: SemanticDb + HasJar<SemanticJar>, Db: SemanticDb + HasJar<SemanticJar>,
{ {
// TODO cache the resolution of the type on the node // TODO cache the resolution of the type on the node
let symbols = db.symbol_table(file_id); let symbols = db.symbol_table(file_id)?;
match expr { match expr {
ast::Expr::Name(name) => { ast::Expr::Name(name) => {
if let Some(symbol_id) = symbols.root_symbol_id_by_name(&name.id) { if let Some(symbol_id) = symbols.root_symbol_id_by_name(&name.id) {
db.infer_symbol_type(file_id, symbol_id) db.infer_symbol_type(file_id, symbol_id)
} else { } else {
Type::Unknown Ok(Type::Unknown)
} }
} }
_ => todo!("full expression type resolution"), _ => todo!("full expression type resolution"),
@ -154,7 +152,7 @@ mod tests {
} }
#[test] #[test]
fn follow_import_to_class() -> std::io::Result<()> { fn follow_import_to_class() -> anyhow::Result<()> {
let case = create_test()?; let case = create_test()?;
let db = &case.db; let db = &case.db;
@ -163,18 +161,18 @@ mod tests {
std::fs::write(a_path, "from b import C as D; E = D")?; std::fs::write(a_path, "from b import C as D; E = D")?;
std::fs::write(b_path, "class C: pass")?; std::fs::write(b_path, "class C: pass")?;
let a_file = db let a_file = db
.resolve_module(ModuleName::new("a")) .resolve_module(ModuleName::new("a"))?
.expect("module should be found") .expect("module should be found")
.path(db) .path(db)?
.file(); .file();
let a_syms = db.symbol_table(a_file); let a_syms = db.symbol_table(a_file)?;
let e_sym = a_syms let e_sym = a_syms
.root_symbol_id_by_name("E") .root_symbol_id_by_name("E")
.expect("E symbol should be found"); .expect("E symbol should be found");
let ty = db.infer_symbol_type(a_file, e_sym); let ty = db.infer_symbol_type(a_file, e_sym)?;
let jar = HasJar::<SemanticJar>::jar(db); let jar = HasJar::<SemanticJar>::jar(db)?;
assert!(matches!(ty, Type::Class(_))); assert!(matches!(ty, Type::Class(_)));
assert_eq!(format!("{}", ty.display(&jar.type_store)), "Literal[C]"); assert_eq!(format!("{}", ty.display(&jar.type_store)), "Literal[C]");
@ -182,28 +180,28 @@ mod tests {
} }
#[test] #[test]
fn resolve_base_class_by_name() -> std::io::Result<()> { fn resolve_base_class_by_name() -> anyhow::Result<()> {
let case = create_test()?; let case = create_test()?;
let db = &case.db; let db = &case.db;
let path = case.src.path().join("mod.py"); let path = case.src.path().join("mod.py");
std::fs::write(path, "class Base: pass\nclass Sub(Base): pass")?; std::fs::write(path, "class Base: pass\nclass Sub(Base): pass")?;
let file = db let file = db
.resolve_module(ModuleName::new("mod")) .resolve_module(ModuleName::new("mod"))?
.expect("module should be found") .expect("module should be found")
.path(db) .path(db)?
.file(); .file();
let syms = db.symbol_table(file); let syms = db.symbol_table(file)?;
let sym = syms let sym = syms
.root_symbol_id_by_name("Sub") .root_symbol_id_by_name("Sub")
.expect("Sub symbol should be found"); .expect("Sub symbol should be found");
let ty = db.infer_symbol_type(file, sym); let ty = db.infer_symbol_type(file, sym)?;
let Type::Class(class_id) = ty else { let Type::Class(class_id) = ty else {
panic!("Sub is not a Class") panic!("Sub is not a Class")
}; };
let jar = HasJar::<SemanticJar>::jar(db); let jar = HasJar::<SemanticJar>::jar(db)?;
let base_names: Vec<_> = jar let base_names: Vec<_> = jar
.type_store .type_store
.get_class(class_id) .get_class(class_id)

View File

@ -26,7 +26,7 @@ ruff_text_size = { path = "../ruff_text_size" }
ruff_workspace = { path = "../ruff_workspace" } ruff_workspace = { path = "../ruff_workspace" }
anyhow = { workspace = true } anyhow = { workspace = true }
crossbeam-channel = { workspace = true } crossbeam = { workspace = true }
jod-thread = { workspace = true } jod-thread = { workspace = true }
libc = { workspace = true } libc = { workspace = true }
lsp-server = { workspace = true } lsp-server = { workspace = true }

View File

@ -6,7 +6,7 @@ use serde_json::Value;
use super::schedule::Task; use super::schedule::Task;
pub(crate) type ClientSender = crossbeam_channel::Sender<lsp_server::Message>; pub(crate) type ClientSender = crossbeam::channel::Sender<lsp_server::Message>;
type ResponseBuilder<'s> = Box<dyn FnOnce(lsp_server::Response) -> Task<'s>>; type ResponseBuilder<'s> = Box<dyn FnOnce(lsp_server::Response) -> Task<'s>>;

View File

@ -1,6 +1,6 @@
use std::num::NonZeroUsize; use std::num::NonZeroUsize;
use crossbeam_channel::Sender; use crossbeam::channel::Sender;
use crate::session::Session; use crate::session::Session;

View File

@ -21,7 +21,7 @@ use std::{
}, },
}; };
use crossbeam_channel::{Receiver, Sender}; use crossbeam::channel::{Receiver, Sender};
use super::{Builder, JoinHandle, ThreadPriority}; use super::{Builder, JoinHandle, ThreadPriority};
@ -52,7 +52,7 @@ impl Pool {
let threads = usize::from(threads); let threads = usize::from(threads);
// Channel buffer capacity is between 2 and 4, depending on the pool size. // Channel buffer capacity is between 2 and 4, depending on the pool size.
let (job_sender, job_receiver) = crossbeam_channel::bounded(std::cmp::min(threads * 2, 4)); let (job_sender, job_receiver) = crossbeam::channel::bounded(std::cmp::min(threads * 2, 4));
let extant_tasks = Arc::new(AtomicUsize::new(0)); let extant_tasks = Arc::new(AtomicUsize::new(0));
let mut handles = Vec::with_capacity(threads); let mut handles = Vec::with_capacity(threads);
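
For the buffer sizing in the last hunk, `std::cmp::min(threads * 2, 4)` gives a one-thread pool a capacity of 2 while every larger pool caps at 4, matching the comment above it. A quick check of the arithmetic:

fn capacity(threads: usize) -> usize {
    std::cmp::min(threads * 2, 4)
}

fn main() {
    assert_eq!(capacity(1), 2); // smallest pool: room for two queued jobs
    assert_eq!(capacity(2), 4); // from two threads up, the cap kicks in
    assert_eq!(capacity(8), 4);
}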