From 550b8be5520e9dd780ddd3d58f833f1379cdeee5 Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Mon, 12 May 2025 15:07:55 -0400 Subject: [PATCH] Avoid initializing progress bars early (#18049) ## Summary Resolves https://github.com/astral-sh/ty/issues/324. --- crates/ruff_benchmark/benches/ty.rs | 10 ++++----- crates/ty/src/lib.rs | 34 ++++++++++++++++------------- crates/ty/tests/file_watching.rs | 12 +++------- crates/ty_project/src/db.rs | 17 ++++++++++++--- crates/ty_project/src/lib.rs | 15 ++++++++----- crates/ty_wasm/src/lib.rs | 4 ++-- 6 files changed, 53 insertions(+), 39 deletions(-) diff --git a/crates/ruff_benchmark/benches/ty.rs b/crates/ruff_benchmark/benches/ty.rs index 7a0fbbf8a1..3bb28fe0ce 100644 --- a/crates/ruff_benchmark/benches/ty.rs +++ b/crates/ruff_benchmark/benches/ty.rs @@ -16,7 +16,7 @@ use ruff_python_ast::PythonVersion; use ty_project::metadata::options::{EnvironmentOptions, Options}; use ty_project::metadata::value::RangedValue; use ty_project::watch::{ChangeEvent, ChangedKind}; -use ty_project::{Db, DummyReporter, ProjectDatabase, ProjectMetadata}; +use ty_project::{Db, ProjectDatabase, ProjectMetadata}; struct Case { db: ProjectDatabase, @@ -164,7 +164,7 @@ fn benchmark_incremental(criterion: &mut Criterion) { fn setup() -> Case { let case = setup_tomllib_case(); - let result: Vec<_> = case.db.check(&DummyReporter).unwrap(); + let result: Vec<_> = case.db.check().unwrap(); assert_diagnostics(&case.db, &result, EXPECTED_TOMLLIB_DIAGNOSTICS); @@ -192,7 +192,7 @@ fn benchmark_incremental(criterion: &mut Criterion) { None, ); - let result = db.check(&DummyReporter).unwrap(); + let result = db.check().unwrap(); assert_eq!(result.len(), EXPECTED_TOMLLIB_DIAGNOSTICS.len()); } @@ -212,7 +212,7 @@ fn benchmark_cold(criterion: &mut Criterion) { setup_tomllib_case, |case| { let Case { db, .. } = case; - let result: Vec<_> = db.check(&DummyReporter).unwrap(); + let result: Vec<_> = db.check().unwrap(); assert_diagnostics(db, &result, EXPECTED_TOMLLIB_DIAGNOSTICS); }, @@ -326,7 +326,7 @@ fn benchmark_many_string_assignments(criterion: &mut Criterion) { }, |case| { let Case { db, .. } = case; - let result = db.check(&DummyReporter).unwrap(); + let result = db.check().unwrap(); assert_eq!(result.len(), 0); }, BatchSize::SmallInput, diff --git a/crates/ty/src/lib.rs b/crates/ty/src/lib.rs index 584f9acb1b..a0b2e1037e 100644 --- a/crates/ty/src/lib.rs +++ b/crates/ty/src/lib.rs @@ -216,7 +216,10 @@ impl MainLoop { self.run_with_progress::(db) } - fn run_with_progress(mut self, db: &mut ProjectDatabase) -> Result { + fn run_with_progress(mut self, db: &mut ProjectDatabase) -> Result + where + R: Reporter + Default + 'static, + { self.sender.send(MainLoopMessage::CheckWorkspace).unwrap(); let result = self.main_loop::(db); @@ -226,7 +229,10 @@ impl MainLoop { result } - fn main_loop(&mut self, db: &mut ProjectDatabase) -> Result { + fn main_loop(&mut self, db: &mut ProjectDatabase) -> Result + where + R: Reporter + Default + 'static, + { // Schedule the first check. tracing::debug!("Starting main loop"); @@ -237,12 +243,12 @@ impl MainLoop { MainLoopMessage::CheckWorkspace => { let db = db.clone(); let sender = self.sender.clone(); - let reporter = R::default(); + let mut reporter = R::default(); // Spawn a new task that checks the project. This needs to be done in a separate thread // to prevent blocking the main loop here. rayon::spawn(move || { - match db.check(&reporter) { + match db.check_with_reporter(&mut reporter) { Ok(result) => { // Send the result back to the main loop for printing. sender @@ -353,11 +359,12 @@ impl MainLoop { } /// A progress reporter for `ty check`. -struct IndicatifReporter(indicatif::ProgressBar); +#[derive(Default)] +struct IndicatifReporter(Option); -impl Default for IndicatifReporter { - fn default() -> IndicatifReporter { - let progress = indicatif::ProgressBar::new(0); +impl ty_project::Reporter for IndicatifReporter { + fn set_files(&mut self, files: usize) { + let progress = indicatif::ProgressBar::new(files as u64); progress.set_style( indicatif::ProgressStyle::with_template( "{msg:8.dim} {bar:60.green/dim} {pos}/{len} files", @@ -366,17 +373,14 @@ impl Default for IndicatifReporter { .progress_chars("--"), ); progress.set_message("Checking"); - IndicatifReporter(progress) - } -} -impl ty_project::Reporter for IndicatifReporter { - fn set_files(&self, files: usize) { - self.0.set_length(files as u64); + self.0 = Some(progress); } fn report_file(&self, _file: &ruff_db::files::File) { - self.0.inc(1); + if let Some(ref progress_bar) = self.0 { + progress_bar.inc(1); + } } } diff --git a/crates/ty/tests/file_watching.rs b/crates/ty/tests/file_watching.rs index 56ff5e2bfc..c8205c9251 100644 --- a/crates/ty/tests/file_watching.rs +++ b/crates/ty/tests/file_watching.rs @@ -14,7 +14,7 @@ use ty_project::metadata::options::{EnvironmentOptions, Options}; use ty_project::metadata::pyproject::{PyProject, Tool}; use ty_project::metadata::value::{RangedValue, RelativePathBuf}; use ty_project::watch::{directory_watcher, ChangeEvent, ProjectWatcher}; -use ty_project::{Db, DummyReporter, ProjectDatabase, ProjectMetadata}; +use ty_project::{Db, ProjectDatabase, ProjectMetadata}; use ty_python_semantic::{resolve_module, ModuleName, PythonPlatform}; struct TestCase { @@ -1117,10 +1117,7 @@ print(sys.last_exc, os.getegid()) Ok(()) })?; - let diagnostics = case - .db - .check(&DummyReporter) - .context("Failed to check project.")?; + let diagnostics = case.db.check().context("Failed to check project.")?; assert_eq!(diagnostics.len(), 2); assert_eq!( @@ -1145,10 +1142,7 @@ print(sys.last_exc, os.getegid()) }) .expect("Search path settings to be valid"); - let diagnostics = case - .db - .check(&DummyReporter) - .context("Failed to check project.")?; + let diagnostics = case.db.check().context("Failed to check project.")?; assert!(diagnostics.is_empty()); Ok(()) diff --git a/crates/ty_project/src/db.rs b/crates/ty_project/src/db.rs index c15f0c99c4..07ae404e22 100644 --- a/crates/ty_project/src/db.rs +++ b/crates/ty_project/src/db.rs @@ -1,7 +1,7 @@ -use std::panic::RefUnwindSafe; +use std::panic::{AssertUnwindSafe, RefUnwindSafe}; use std::sync::Arc; -use crate::DEFAULT_LINT_REGISTRY; +use crate::{DummyReporter, DEFAULT_LINT_REGISTRY}; use crate::{Project, ProjectMetadata, Reporter}; use ruff_db::diagnostic::Diagnostic; use ruff_db::files::{File, Files}; @@ -68,7 +68,18 @@ impl ProjectDatabase { } /// Checks all open files in the project and its dependencies. - pub fn check(&self, reporter: &impl Reporter) -> Result, Cancelled> { + pub fn check(&self) -> Result, Cancelled> { + let mut reporter = DummyReporter; + let reporter = AssertUnwindSafe(&mut reporter as &mut dyn Reporter); + self.with_db(|db| db.project().check(db, reporter)) + } + + /// Checks all open files in the project and its dependencies, using the given reporter. + pub fn check_with_reporter( + &self, + reporter: &mut dyn Reporter, + ) -> Result, Cancelled> { + let reporter = AssertUnwindSafe(reporter); self.with_db(|db| db.project().check(db, reporter)) } diff --git a/crates/ty_project/src/lib.rs b/crates/ty_project/src/lib.rs index 8385004cf2..4e36f693d5 100644 --- a/crates/ty_project/src/lib.rs +++ b/crates/ty_project/src/lib.rs @@ -18,7 +18,7 @@ use rustc_hash::FxHashSet; use salsa::Durability; use salsa::Setter; use std::backtrace::BacktraceStatus; -use std::panic::{AssertUnwindSafe, RefUnwindSafe, UnwindSafe}; +use std::panic::{AssertUnwindSafe, UnwindSafe}; use std::sync::Arc; use thiserror::Error; use tracing::error; @@ -107,9 +107,9 @@ pub struct Project { } /// A progress reporter. -pub trait Reporter: Default + Send + Sync + RefUnwindSafe + 'static { +pub trait Reporter: Send + Sync { /// Initialize the reporter with the number of files. - fn set_files(&self, files: usize); + fn set_files(&mut self, files: usize); /// Report the completion of a given file. fn report_file(&self, file: &File); @@ -120,7 +120,7 @@ pub trait Reporter: Default + Send + Sync + RefUnwindSafe + 'static { pub struct DummyReporter; impl Reporter for DummyReporter { - fn set_files(&self, _files: usize) {} + fn set_files(&mut self, _files: usize) {} fn report_file(&self, _file: &File) {} } @@ -186,7 +186,11 @@ impl Project { } /// Checks all open files in the project and its dependencies. - pub(crate) fn check(self, db: &ProjectDatabase, reporter: &impl Reporter) -> Vec { + pub(crate) fn check( + self, + db: &ProjectDatabase, + mut reporter: AssertUnwindSafe<&mut dyn Reporter>, + ) -> Vec { let project_span = tracing::debug_span!("Project::check"); let _span = project_span.enter(); @@ -215,6 +219,7 @@ impl Project { let db = db.clone(); let file_diagnostics = &file_diagnostics; let project_span = &project_span; + let reporter = &reporter; rayon::scope(move |scope| { for file in &files { diff --git a/crates/ty_wasm/src/lib.rs b/crates/ty_wasm/src/lib.rs index 2a7cad6f8a..9ed86299e8 100644 --- a/crates/ty_wasm/src/lib.rs +++ b/crates/ty_wasm/src/lib.rs @@ -18,8 +18,8 @@ use ty_ide::{goto_type_definition, hover, inlay_hints, MarkupKind}; use ty_project::metadata::options::Options; use ty_project::metadata::value::ValueSource; use ty_project::watch::{ChangeEvent, ChangedKind, CreatedKind, DeletedKind}; +use ty_project::ProjectMetadata; use ty_project::{Db, ProjectDatabase}; -use ty_project::{DummyReporter, ProjectMetadata}; use ty_python_semantic::Program; use wasm_bindgen::prelude::*; @@ -186,7 +186,7 @@ impl Workspace { /// Checks all open files pub fn check(&self) -> Result, Error> { - let result = self.db.check(&DummyReporter).map_err(into_error)?; + let result = self.db.check().map_err(into_error)?; Ok(result.into_iter().map(Diagnostic::wrap).collect()) }