diff --git a/crates/red_knot/src/cancellation.rs b/crates/red_knot/src/cancellation.rs index 9f5e57cf56..05610ce279 100644 --- a/crates/red_knot/src/cancellation.rs +++ b/crates/red_knot/src/cancellation.rs @@ -1,17 +1,18 @@ use std::sync::{Arc, Condvar, Mutex}; #[derive(Debug, Default)] -pub struct CancellationSource { +pub struct CancellationTokenSource { signal: Arc<(Mutex, Condvar)>, } -impl CancellationSource { +impl CancellationTokenSource { pub fn new() -> Self { Self { signal: Arc::new((Mutex::new(false), Condvar::default())), } } + #[tracing::instrument(level = "trace")] pub fn cancel(&self) { let (cancelled, condvar) = &*self.signal; diff --git a/crates/red_knot/src/files.rs b/crates/red_knot/src/files.rs index fc7f18115f..defe4a2ee4 100644 --- a/crates/red_knot/src/files.rs +++ b/crates/red_knot/src/files.rs @@ -23,7 +23,7 @@ pub struct Files { } impl Files { - #[tracing::instrument(level = "trace", skip(path))] + #[tracing::instrument(level = "debug", skip(path))] pub fn intern(&self, path: &Path) -> FileId { self.inner.write().intern(path) } @@ -32,7 +32,7 @@ impl Files { self.inner.read().try_get(path) } - // TODO Can we avoid using an `Arc` here? salsa can return references for some reason. + #[tracing::instrument(level = "debug")] pub fn path(&self, id: FileId) -> Arc { self.inner.read().path(id) } diff --git a/crates/red_knot/src/lint.rs b/crates/red_knot/src/lint.rs index d42bb51a54..d787e00d8c 100644 --- a/crates/red_knot/src/lint.rs +++ b/crates/red_knot/src/lint.rs @@ -8,6 +8,7 @@ use crate::cache::KeyValueCache; use crate::db::{HasJar, SourceDb, SourceJar}; use crate::files::FileId; +#[tracing::instrument(level = "debug", skip(db))] pub(crate) fn lint_syntax(db: &Db, file_id: FileId) -> Diagnostics where Db: SourceDb + HasJar, diff --git a/crates/red_knot/src/main.rs b/crates/red_knot/src/main.rs index 9bfb27bd0b..016cb5a3ea 100644 --- a/crates/red_knot/src/main.rs +++ b/crates/red_knot/src/main.rs @@ -1,8 +1,8 @@ +#![allow(clippy::dbg_macro)] + use std::collections::hash_map::Entry; -use std::num::NonZeroUsize; use std::path::Path; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::{Arc, Mutex}; +use std::sync::Mutex; use rustc_hash::FxHashMap; use tracing::subscriber::Interest; @@ -12,20 +12,16 @@ use tracing_subscriber::layer::{Context, Filter, SubscriberExt}; use tracing_subscriber::{Layer, Registry}; use tracing_tree::time::Uptime; -use red_knot::cancellation::CancellationSource; +use red_knot::cancellation::CancellationTokenSource; use red_knot::db::{HasJar, SourceDb, SourceJar}; use red_knot::files::FileId; use red_knot::module::{ModuleSearchPath, ModuleSearchPathKind}; +use red_knot::program::check::{CheckError, RayonCheckScheduler}; use red_knot::program::{FileChange, FileChangeKind, Program}; use red_knot::watch::FileWatcher; -use red_knot::{files, Workspace}; +use red_knot::Workspace; -#[allow( - clippy::dbg_macro, - clippy::print_stdout, - clippy::unnecessary_wraps, - clippy::print_stderr -)] +#[allow(clippy::print_stdout, clippy::unnecessary_wraps, clippy::print_stderr)] fn main() -> anyhow::Result<()> { setup_tracing(); @@ -48,212 +44,43 @@ fn main() -> anyhow::Result<()> { return Err(anyhow::anyhow!("Invalid arguments")); } - let files = files::Files::default(); let workspace_folder = entry_point.parent().unwrap(); - let mut workspace = Workspace::new(workspace_folder.to_path_buf()); + let workspace = Workspace::new(workspace_folder.to_path_buf()); let workspace_search_path = ModuleSearchPath::new( workspace.root().to_path_buf(), ModuleSearchPathKind::FirstParty, ); + let mut program = Program::new(workspace, vec![workspace_search_path]); - let entry_id = files.intern(entry_point); + let entry_id = program.file_id(entry_point); + program.workspace_mut().open_file(entry_id); - let mut program = Program::new(vec![workspace_search_path], files.clone()); - - workspace.open_file(entry_id); - - let (sender, receiver) = crossbeam_channel::bounded( - std::thread::available_parallelism() - .map(NonZeroUsize::get) - .unwrap_or(50) - .max(4), // TODO: Both these numbers are very arbitrary. Pick sensible defaults. - ); + let (main_loop, main_loop_cancellation_token) = MainLoop::new(); // Listen to Ctrl+C and abort the watch mode. - let abort_sender = Mutex::new(Some(sender.clone())); + let main_loop_cancellation_token = Mutex::new(Some(main_loop_cancellation_token)); ctrlc::set_handler(move || { - let mut lock = abort_sender.lock().unwrap(); + let mut lock = main_loop_cancellation_token.lock().unwrap(); - if let Some(sender) = lock.take() { - sender.send(Message::Exit).unwrap(); + if let Some(token) = lock.take() { + token.stop(); } })?; - // Watch for file changes and re-trigger the analysis. - let file_changes_sender = sender.clone(); + let file_changes_notifier = main_loop.file_changes_notifier(); + // Watch for file changes and re-trigger the analysis. let mut file_watcher = FileWatcher::new( move |changes| { - file_changes_sender - .send(Message::FileChanges(changes)) - .unwrap(); + file_changes_notifier.notify(changes); }, - files.clone(), + program.files().clone(), )?; file_watcher.watch_folder(workspace_folder)?; - let files_to_check = vec![entry_id]; - - // Main loop that runs until the user exits the program - // Runs the analysis for each changed file. Cancels the analysis if a new change is detected. - loop { - let changes = { - tracing::trace!("Main Loop: Tick"); - - // Token to cancel the analysis if a new change is detected. - let run_cancellation_token_source = CancellationSource::new(); - let run_cancellation_token = run_cancellation_token_source.token(); - - // Tracks the number of pending analysis runs. - let pending_analysis = Arc::new(AtomicUsize::new(0)); - - // Take read-only references that are copy and Send. - let program = &program; - let workspace = &workspace; - - let receiver = receiver.clone(); - let started_analysis = pending_analysis.clone(); - - // Orchestration task. Ideally, we would run this on main but we should start it as soon as possible so that - // we avoid scheduling tasks when we already know that we're about to exit or cancel the analysis because of a file change. - // This uses `std::thread::spawn` because we don't want it to run inside of the thread pool - // or this code deadlocks when using a thread pool of the size 1. - let orchestration_handle = std::thread::spawn(move || { - fn consume_pending_messages( - receiver: &crossbeam_channel::Receiver, - mut aggregated_changes: AggregatedChanges, - ) -> NextTickCommand { - loop { - // Consume possibly incoming file change messages before running a new analysis, but don't wait for more than 100ms. - crossbeam_channel::select! { - recv(receiver) -> message => { - match message { - Ok(Message::Exit) => { - return NextTickCommand::Exit; - } - Ok(Message::FileChanges(file_changes)) => { - aggregated_changes.extend(file_changes); - } - - Ok(Message::AnalysisCancelled | Message::AnalysisCompleted(_)) => { - unreachable!( - "All analysis should have been completed at this time" - ); - }, - - Err(_) => { - // There are no more senders, no point in waiting for more messages - break; - } - } - }, - default(std::time::Duration::from_millis(100)) => { - break; - } - } - } - - NextTickCommand::FileChanges(aggregated_changes) - } - - let mut diagnostics = Vec::new(); - let mut aggregated_changes = AggregatedChanges::default(); - - for message in &receiver { - match message { - Message::AnalysisCompleted(file_diagnostics) => { - diagnostics.extend_from_slice(&file_diagnostics); - - if pending_analysis.fetch_sub(1, Ordering::SeqCst) == 1 { - // Analysis completed, print the diagnostics. - dbg!(&diagnostics); - } - } - - Message::AnalysisCancelled => { - if pending_analysis.fetch_sub(1, Ordering::SeqCst) == 1 { - return consume_pending_messages(&receiver, aggregated_changes); - } - } - - Message::Exit => { - run_cancellation_token_source.cancel(); - - // Don't consume any outstanding messages because we're exiting anyway. - return NextTickCommand::Exit; - } - - Message::FileChanges(changes) => { - // Request cancellation, but wait until all analysis tasks have completed to - // avoid stale messages in the next main loop. - run_cancellation_token_source.cancel(); - - aggregated_changes.extend(changes); - - if pending_analysis.load(Ordering::SeqCst) == 0 { - return consume_pending_messages(&receiver, aggregated_changes); - } - } - } - } - - // This can be reached if there's no Ctrl+C and no file watcher handler. - // In that case, assume that we don't run in watch mode and exit. - NextTickCommand::Exit - }); - - // Star the analysis task on the thread pool and wait until they complete. - rayon::scope(|scope| { - for file in &files_to_check { - let cancellation_token = run_cancellation_token.clone(); - if cancellation_token.is_cancelled() { - break; - } - - let sender = sender.clone(); - - started_analysis.fetch_add(1, Ordering::SeqCst); - - // TODO: How do we allow the host to control the number of threads used? - // Or should we just assume that each host implements its own main loop, - // I don't think that's entirely unreasonable but we should avoid - // having different main loops per host AND command (e.g. format vs check vs lint) - scope.spawn(move |_| { - if cancellation_token.is_cancelled() { - tracing::trace!("Exit analysis because cancellation was requested."); - sender.send(Message::AnalysisCancelled).unwrap(); - return; - } - - // TODO schedule the dependencies. - let mut diagnostics = Vec::new(); - - if workspace.is_file_open(*file) { - diagnostics.extend_from_slice(&program.lint_syntax(*file)); - } - - sender - .send(Message::AnalysisCompleted(diagnostics)) - .unwrap(); - }); - } - }); - - // Wait for the orchestration task to complete. This either returns the file changes - // or instructs the main loop to exit. - match orchestration_handle.join().unwrap() { - NextTickCommand::FileChanges(changes) => changes, - NextTickCommand::Exit => { - break; - } - } - }; - - // We have a mutable reference here and can perform all necessary invalidations. - program.apply_changes(changes.iter()); - } + main_loop.run(&mut program); let source_jar: &SourceJar = program.jar(); @@ -263,10 +90,259 @@ fn main() -> anyhow::Result<()> { Ok(()) } -enum Message { - AnalysisCompleted(Vec), - AnalysisCancelled, +struct MainLoop { + orchestrator_sender: crossbeam_channel::Sender, + main_loop_receiver: crossbeam_channel::Receiver, +} + +impl MainLoop { + fn new() -> (Self, MainLoopCancellationToken) { + let (orchestrator_sender, orchestrator_receiver) = crossbeam_channel::bounded(1); + let (main_loop_sender, main_loop_receiver) = crossbeam_channel::bounded(1); + + let mut orchestrator = Orchestrator { + pending_analysis: None, + receiver: orchestrator_receiver, + sender: main_loop_sender.clone(), + aggregated_changes: AggregatedChanges::default(), + }; + + std::thread::spawn(move || { + orchestrator.run(); + }); + + ( + Self { + orchestrator_sender, + main_loop_receiver, + }, + MainLoopCancellationToken { + sender: main_loop_sender, + }, + ) + } + + fn file_changes_notifier(&self) -> FileChangesNotifier { + FileChangesNotifier { + sender: self.orchestrator_sender.clone(), + } + } + + fn run(self, program: &mut Program) { + self.orchestrator_sender + .send(OrchestratorMessage::Run) + .unwrap(); + + for message in &self.main_loop_receiver { + tracing::trace!("Main Loop: Tick"); + + match message { + MainLoopMessage::CheckProgram => { + // Remove mutability from program. + let program = &*program; + let run_cancellation_token_source = CancellationTokenSource::new(); + let run_cancellation_token = run_cancellation_token_source.token(); + let sender = &self.orchestrator_sender; + + sender + .send(OrchestratorMessage::CheckProgramStarted { + cancellation_token: run_cancellation_token_source, + }) + .unwrap(); + + rayon::in_place_scope(|scope| { + let scheduler = RayonCheckScheduler { program, scope }; + + let result = program.check(&scheduler, run_cancellation_token); + match result { + Ok(result) => sender + .send(OrchestratorMessage::CheckProgramCompleted(result)) + .unwrap(), + Err(CheckError::Cancelled) => sender + .send(OrchestratorMessage::CheckProgramCancelled) + .unwrap(), + } + }); + } + MainLoopMessage::ApplyChanges(changes) => { + program.apply_changes(changes.iter()); + } + MainLoopMessage::CheckCompleted(diagnostics) => { + dbg!(diagnostics); + } + MainLoopMessage::Exit => { + return; + } + } + } + } +} + +impl Drop for MainLoop { + fn drop(&mut self) { + self.orchestrator_sender + .send(OrchestratorMessage::Shutdown) + .unwrap(); + } +} + +#[derive(Debug, Clone)] +struct FileChangesNotifier { + sender: crossbeam_channel::Sender, +} + +impl FileChangesNotifier { + fn notify(&self, changes: Vec) { + self.sender + .send(OrchestratorMessage::FileChanges(changes)) + .unwrap(); + } +} + +#[derive(Debug)] +struct MainLoopCancellationToken { + sender: crossbeam_channel::Sender, +} + +impl MainLoopCancellationToken { + fn stop(self) { + self.sender.send(MainLoopMessage::Exit).unwrap(); + } +} + +struct Orchestrator { + aggregated_changes: AggregatedChanges, + pending_analysis: Option, + + /// Sends messages to the main loop. + sender: crossbeam_channel::Sender, + /// Receives messages from the main loop. + receiver: crossbeam_channel::Receiver, +} + +impl Orchestrator { + fn run(&mut self) { + while let Ok(message) = self.receiver.recv() { + match message { + OrchestratorMessage::Run => { + self.pending_analysis = None; + self.sender.send(MainLoopMessage::CheckProgram).unwrap(); + } + + OrchestratorMessage::CheckProgramStarted { cancellation_token } => { + debug_assert!(self.pending_analysis.is_none()); + + self.pending_analysis = Some(PendingAnalysisState { cancellation_token }); + } + + OrchestratorMessage::CheckProgramCompleted(diagnostics) => { + self.pending_analysis + .take() + .expect("Expected a pending analysis."); + + self.sender + .send(MainLoopMessage::CheckCompleted(diagnostics)) + .unwrap(); + } + + OrchestratorMessage::CheckProgramCancelled => { + self.pending_analysis + .take() + .expect("Expected a pending analysis."); + + self.debounce_changes(); + } + + OrchestratorMessage::FileChanges(changes) => { + // Request cancellation, but wait until all analysis tasks have completed to + // avoid stale messages in the next main loop. + let pending = if let Some(pending_state) = self.pending_analysis.as_ref() { + pending_state.cancellation_token.cancel(); + true + } else { + false + }; + + self.aggregated_changes.extend(changes); + + // If there are no pending analysis tasks, apply the file changes. Otherwise + // keep running until all file checks have completed. + if !pending { + self.debounce_changes(); + } + } + OrchestratorMessage::Shutdown => { + return self.shutdown(); + } + } + } + } + + fn debounce_changes(&mut self) { + debug_assert!(self.pending_analysis.is_none()); + + loop { + // Consume possibly incoming file change messages before running a new analysis, but don't wait for more than 100ms. + crossbeam_channel::select! { + recv(self.receiver) -> message => { + match message { + Ok(OrchestratorMessage::Shutdown) => { + return self.shutdown(); + } + Ok(OrchestratorMessage::FileChanges(file_changes)) => { + self.aggregated_changes.extend(file_changes); + } + + Ok(OrchestratorMessage::CheckProgramStarted {..}| OrchestratorMessage::CheckProgramCompleted(_) | OrchestratorMessage::CheckProgramCancelled) => unreachable!("No program check should be running while debouncing changes."), + Ok(OrchestratorMessage::Run) => unreachable!("The orchestrator is already running."), + + Err(_) => { + // There are no more senders, no point in waiting for more messages + return; + } + } + }, + default(std::time::Duration::from_millis(100)) => { + // No more file changes after 100 ms, send the changes and schedule a new analysis + self.sender.send(MainLoopMessage::ApplyChanges(std::mem::take(&mut self.aggregated_changes))).unwrap(); + self.sender.send(MainLoopMessage::CheckProgram).unwrap(); + return; + } + } + } + } + + #[allow(clippy::unused_self)] + fn shutdown(&self) { + tracing::trace!("Shutting down orchestrator."); + } +} + +#[derive(Debug)] +struct PendingAnalysisState { + cancellation_token: CancellationTokenSource, +} + +/// Message sent from the orchestrator to the main loop. +#[derive(Debug)] +enum MainLoopMessage { + CheckProgram, + CheckCompleted(Vec), + ApplyChanges(AggregatedChanges), Exit, +} + +#[derive(Debug)] +enum OrchestratorMessage { + Run, + Shutdown, + + CheckProgramStarted { + cancellation_token: CancellationTokenSource, + }, + CheckProgramCompleted(Vec), + CheckProgramCancelled, + FileChanges(Vec), } @@ -340,13 +416,6 @@ impl AggregatedChanges { } } -enum NextTickCommand { - /// Exit the main loop in the next tick - Exit, - /// Apply the given changes in the next main loop tick. - FileChanges(AggregatedChanges), -} - fn setup_tracing() { let subscriber = Registry::default().with( tracing_tree::HierarchicalLayer::default() diff --git a/crates/red_knot/src/module.rs b/crates/red_knot/src/module.rs index 5ce422f40e..62f48e5330 100644 --- a/crates/red_knot/src/module.rs +++ b/crates/red_knot/src/module.rs @@ -164,7 +164,7 @@ pub struct ModuleData { /// Resolves a module name to a module id /// TODO: This would not work with Salsa because `ModuleName` isn't an ingredient and, therefore, cannot be used as part of a query. /// For this to work with salsa, it would be necessary to intern all `ModuleName`s. -#[tracing::instrument(level = "trace", skip(db))] +#[tracing::instrument(level = "debug", skip(db))] pub fn resolve_module(db: &Db, name: ModuleName) -> Option where Db: SemanticDb + HasJar, diff --git a/crates/red_knot/src/parse.rs b/crates/red_knot/src/parse.rs index 641181fb93..20df9f4c8c 100644 --- a/crates/red_knot/src/parse.rs +++ b/crates/red_knot/src/parse.rs @@ -63,7 +63,7 @@ impl Parsed { } } -#[tracing::instrument(level = "trace", skip(db))] +#[tracing::instrument(level = "debug", skip(db))] pub(crate) fn parse(db: &Db, file_id: FileId) -> Parsed where Db: SourceDb + HasJar, diff --git a/crates/red_knot/src/program/check.rs b/crates/red_knot/src/program/check.rs new file mode 100644 index 0000000000..1d56caa5bf --- /dev/null +++ b/crates/red_knot/src/program/check.rs @@ -0,0 +1,281 @@ +use crate::cancellation::CancellationToken; +use crate::db::SourceDb; +use crate::files::FileId; +use crate::lint::Diagnostics; +use crate::program::Program; +use rayon::max_num_threads; +use rustc_hash::FxHashSet; +use std::num::NonZeroUsize; + +impl Program { + /// Checks all open files in the workspace and its dependencies. + #[tracing::instrument(level = "debug", skip_all)] + pub fn check( + &self, + scheduler: &dyn CheckScheduler, + cancellation_token: CancellationToken, + ) -> Result, CheckError> { + let check_loop = CheckFilesLoop::new(scheduler, cancellation_token); + + check_loop.run(self.workspace().open_files.iter().copied()) + } + + /// Checks a single file and its dependencies. + #[tracing::instrument(level = "debug", skip(self, scheduler, cancellation_token))] + pub fn check_file( + &self, + file: FileId, + scheduler: &dyn CheckScheduler, + cancellation_token: CancellationToken, + ) -> Result, CheckError> { + let check_loop = CheckFilesLoop::new(scheduler, cancellation_token); + + check_loop.run([file].into_iter()) + } + + #[tracing::instrument(level = "debug", skip(self, context))] + fn do_check_file( + &self, + file: FileId, + context: &CheckContext, + ) -> Result { + context.cancelled_ok()?; + + // TODO schedule the dependencies. + let mut diagnostics = Vec::new(); + + if self.workspace().is_file_open(file) { + diagnostics.extend_from_slice(&self.lint_syntax(file)); + } + + Ok(Diagnostics::from(diagnostics)) + } +} + +/// Schedules checks for files. +pub trait CheckScheduler { + /// Schedules a check for a file. + /// + /// The check can either be run immediately on the current thread or the check can be queued + /// in a thread pool and ran asynchronously. + /// + /// The order in which scheduled checks are executed is not guaranteed. + /// + /// The implementation should call [`CheckFileTask::run`] to execute the check. + fn check_file(&self, file_task: CheckFileTask); + + /// The maximum number of checks that can be run concurrently. + /// + /// Returns `None` if the checks run on the current thread (no concurrency). + fn max_concurrency(&self) -> Option; +} + +/// Scheduler that runs checks on a rayon thread pool. +pub struct RayonCheckScheduler<'program, 'scope_ref, 'scope> { + pub program: &'program Program, + pub scope: &'scope_ref rayon::Scope<'scope>, +} + +impl<'program, 'scope_ref, 'scope> RayonCheckScheduler<'program, 'scope_ref, 'scope> { + pub fn new(program: &'program Program, scope: &'scope_ref rayon::Scope<'scope>) -> Self { + Self { program, scope } + } +} + +impl<'program, 'scope_ref, 'scope> CheckScheduler + for RayonCheckScheduler<'program, 'scope_ref, 'scope> +where + 'program: 'scope, +{ + fn check_file(&self, check_file_task: CheckFileTask) { + let child_span = + tracing::trace_span!("check_file", file_id = check_file_task.file_id.as_u32()); + let program = self.program; + + self.scope + .spawn(move |_| child_span.in_scope(|| check_file_task.run(program))); + } + + fn max_concurrency(&self) -> Option { + Some(NonZeroUsize::new(max_num_threads()).unwrap_or(NonZeroUsize::MIN)) + } +} + +/// Scheduler that runs all checks on the current thread. +pub struct SameThreadCheckScheduler<'a> { + program: &'a Program, +} + +impl<'a> SameThreadCheckScheduler<'a> { + pub fn new(program: &'a Program) -> Self { + Self { program } + } +} + +impl CheckScheduler for SameThreadCheckScheduler<'_> { + fn check_file(&self, task: CheckFileTask) { + task.run(self.program); + } + + fn max_concurrency(&self) -> Option { + None + } +} + +#[derive(Debug, Clone)] +pub enum CheckError { + Cancelled, +} + +#[derive(Debug)] +pub struct CheckFileTask { + file_id: FileId, + context: CheckContext, +} + +impl CheckFileTask { + /// Runs the check and communicates the result to the orchestrator. + pub fn run(self, program: &Program) { + match program.do_check_file(self.file_id, &self.context) { + Ok(diagnostics) => self + .context + .sender + .send(CheckFileMessage::Completed(diagnostics)) + .unwrap(), + Err(CheckError::Cancelled) => self + .context + .sender + .send(CheckFileMessage::Cancelled) + .unwrap(), + } + } +} + +#[derive(Clone, Debug)] +struct CheckContext { + cancellation_token: CancellationToken, + sender: crossbeam_channel::Sender, +} + +impl CheckContext { + fn new( + cancellation_token: CancellationToken, + sender: crossbeam_channel::Sender, + ) -> Self { + Self { + cancellation_token, + sender, + } + } + + /// Queues a new file for checking using the [`CheckScheduler`]. + #[allow(unused)] + fn schedule_check_file(&self, file_id: FileId) { + self.sender.send(CheckFileMessage::Queue(file_id)).unwrap(); + } + + /// Returns `true` if the check has been cancelled. + fn is_cancelled(&self) -> bool { + self.cancellation_token.is_cancelled() + } + + fn cancelled_ok(&self) -> Result<(), CheckError> { + if self.is_cancelled() { + Err(CheckError::Cancelled) + } else { + Ok(()) + } + } +} + +struct CheckFilesLoop<'a> { + scheduler: &'a dyn CheckScheduler, + cancellation_token: CancellationToken, + pending: usize, + queued_files: FxHashSet, +} + +impl<'a> CheckFilesLoop<'a> { + fn new(scheduler: &'a dyn CheckScheduler, cancellation_token: CancellationToken) -> Self { + Self { + scheduler, + cancellation_token, + + queued_files: FxHashSet::default(), + pending: 0, + } + } + + fn run(mut self, files: impl Iterator) -> Result, CheckError> { + let (sender, receiver) = if let Some(max_concurrency) = self.scheduler.max_concurrency() { + crossbeam_channel::bounded(max_concurrency.get()) + } else { + // The checks run on the current thread. That means it is necessary to store all messages + // or we risk deadlocking when the main loop never gets a chance to read the messages. + crossbeam_channel::unbounded() + }; + + let context = CheckContext::new(self.cancellation_token.clone(), sender.clone()); + + for file in files { + self.queue_file(file, context.clone())?; + } + + self.run_impl(receiver, &context) + } + + fn run_impl( + mut self, + receiver: crossbeam_channel::Receiver, + context: &CheckContext, + ) -> Result, CheckError> { + if self.cancellation_token.is_cancelled() { + return Err(CheckError::Cancelled); + } + + let mut result = Vec::default(); + + for message in receiver { + match message { + CheckFileMessage::Completed(diagnostics) => { + result.extend_from_slice(&diagnostics); + + self.pending -= 1; + + if self.pending == 0 { + break; + } + } + CheckFileMessage::Queue(id) => { + self.queue_file(id, context.clone())?; + } + CheckFileMessage::Cancelled => { + return Err(CheckError::Cancelled); + } + } + } + + Ok(result) + } + + fn queue_file(&mut self, file_id: FileId, context: CheckContext) -> Result<(), CheckError> { + if context.is_cancelled() { + return Err(CheckError::Cancelled); + } + + if self.queued_files.insert(file_id) { + self.pending += 1; + + self.scheduler + .check_file(CheckFileTask { file_id, context }); + } + + Ok(()) + } +} + +enum CheckFileMessage { + Completed(Diagnostics), + Queue(FileId), + Cancelled, +} diff --git a/crates/red_knot/src/program/mod.rs b/crates/red_knot/src/program/mod.rs index e9055938dc..f0dac73014 100644 --- a/crates/red_knot/src/program/mod.rs +++ b/crates/red_knot/src/program/mod.rs @@ -1,3 +1,5 @@ +pub mod check; + use std::path::Path; use std::sync::Arc; @@ -12,16 +14,18 @@ use crate::parse::{parse, Parsed, ParsedStorage}; use crate::source::{source_text, Source, SourceStorage}; use crate::symbols::{symbol_table, SymbolId, SymbolTable, SymbolTablesStorage}; use crate::types::{infer_symbol_type, Type, TypeStore}; +use crate::Workspace; #[derive(Debug)] pub struct Program { files: Files, source: SourceJar, semantic: SemanticJar, + workspace: Workspace, } impl Program { - pub fn new(module_search_paths: Vec, files: Files) -> Self { + pub fn new(workspace: Workspace, module_search_paths: Vec) -> Self { Self { source: SourceJar { sources: SourceStorage::default(), @@ -33,7 +37,8 @@ impl Program { symbol_tables: SymbolTablesStorage::default(), type_store: TypeStore::default(), }, - files, + files: Files::default(), + workspace, } } @@ -53,6 +58,18 @@ impl Program { self.semantic.type_store.remove_module(change.id); } } + + pub fn files(&self) -> &Files { + &self.files + } + + pub fn workspace(&self) -> &Workspace { + &self.workspace + } + + pub fn workspace_mut(&mut self) -> &mut Workspace { + &mut self.workspace + } } impl SourceDb for Program { diff --git a/crates/red_knot/src/source.rs b/crates/red_knot/src/source.rs index 7dd6ed9285..08ad2d8aba 100644 --- a/crates/red_knot/src/source.rs +++ b/crates/red_knot/src/source.rs @@ -7,7 +7,7 @@ use std::sync::Arc; use crate::files::FileId; -#[tracing::instrument(level = "trace", skip(db))] +#[tracing::instrument(level = "debug", skip(db))] pub(crate) fn source_text(db: &Db, file_id: FileId) -> Source where Db: SourceDb + HasJar, @@ -15,8 +15,6 @@ where let sources = &db.jar().sources; sources.get(&file_id, |file_id| { - tracing::trace!("Reading source text for file_id={:?}.", file_id); - let path = db.file_path(*file_id); let source_text = std::fs::read_to_string(&path).unwrap_or_else(|err| { diff --git a/crates/red_knot/src/symbols.rs b/crates/red_knot/src/symbols.rs index 96fa6b120f..0277e51d3b 100644 --- a/crates/red_knot/src/symbols.rs +++ b/crates/red_knot/src/symbols.rs @@ -19,7 +19,7 @@ use crate::files::FileId; use crate::Name; #[allow(unreachable_pub)] -#[tracing::instrument(level = "trace", skip(db))] +#[tracing::instrument(level = "debug", skip(db))] pub fn symbol_table(db: &Db, file_id: FileId) -> Arc where Db: SemanticDb + HasJar,