From 56684e4c24eb8e80616d5970ab590c637c2ed016 Mon Sep 17 00:00:00 2001 From: konsti Date: Mon, 3 Feb 2025 16:41:17 +0100 Subject: [PATCH] Respect concurrency limits in parallel index fetch (#11182) With the parallel simple index fetching, we would only acquire one download concurrency token, meaning that in the worst case we could make as many requests as the user-requested limit multiplied by the number of indexes. We fix this by passing the semaphore down to the simple API method. --- crates/uv-client/src/registry_client.rs | 5 ++++- .../uv-distribution/src/distribution_database.rs | 14 ++++++++++++++ crates/uv-publish/src/lib.rs | 13 +++++++++++-- crates/uv-resolver/src/resolver/provider.rs | 4 +++- crates/uv/src/commands/pip/latest.rs | 8 +++++++- crates/uv/src/commands/pip/list.rs | 6 +++++- crates/uv/src/commands/pip/tree.rs | 7 ++++++- crates/uv/src/commands/project/tree.rs | 8 ++++++-- crates/uv/src/commands/publish.rs | 8 +++++++- 9 files changed, 63 insertions(+), 10 deletions(-) diff --git a/crates/uv-client/src/registry_client.rs b/crates/uv-client/src/registry_client.rs index 55eabf85f..02a0a9949 100644 --- a/crates/uv-client/src/registry_client.rs +++ b/crates/uv-client/src/registry_client.rs @@ -10,6 +10,7 @@ use http::HeaderMap; use itertools::Either; use reqwest::{Client, Response, StatusCode}; use reqwest_middleware::ClientWithMiddleware; +use tokio::sync::Semaphore; use tracing::{info_span, instrument, trace, warn, Instrument}; use url::Url; @@ -235,6 +236,7 @@ impl RegistryClient { package_name: &PackageName, index: Option<&'index IndexUrl>, capabilities: &IndexCapabilities, + download_concurrency: &Semaphore, ) -> Result)>, Error> { let indexes = if let Some(index) = index { Either::Left(std::iter::once(index)) @@ -253,6 +255,7 @@ impl RegistryClient { // If we're searching for the first index that contains the package, fetch serially. 
IndexStrategy::FirstIndex => { for index in it { + let _permit = download_concurrency.acquire().await; if let Some(metadata) = self .simple_single_index(package_name, index, capabilities) .await? @@ -265,9 +268,9 @@ impl RegistryClient { // Otherwise, fetch concurrently. IndexStrategy::UnsafeBestMatch | IndexStrategy::UnsafeFirstMatch => { - // TODO(charlie): Respect concurrency limits. results = futures::stream::iter(it) .map(|index| async move { + let _permit = download_concurrency.acquire().await; let metadata = self .simple_single_index(package_name, index, capabilities) .await?; diff --git a/crates/uv-distribution/src/distribution_database.rs b/crates/uv-distribution/src/distribution_database.rs index 9a74f9d2a..14fb52dba 100644 --- a/crates/uv-distribution/src/distribution_database.rs +++ b/crates/uv-distribution/src/distribution_database.rs @@ -992,6 +992,20 @@ impl<'a> ManagedClient<'a> { let _permit = self.control.acquire().await.unwrap(); f(self.unmanaged).await } + + /// Perform a request using the unmanaged client, leaving the concurrency limit to the callback. + /// + /// The callback is passed the client and a semaphore. It must acquire a permit before any + /// request through the client and drop that permit afterwards. + /// + /// This method serves as an escape hatch for functions that may want to send multiple requests + /// in parallel. + pub async fn manual(&'a self, f: impl FnOnce(&'a RegistryClient, &'a Semaphore) -> F) -> T + where + F: Future, + { + f(self.unmanaged, &self.control).await + } } /// Returns the value of the `Content-Length` header from the [`reqwest::Response`], if present. 
diff --git a/crates/uv-publish/src/lib.rs b/crates/uv-publish/src/lib.rs index 19c40abda..4d65fc11d 100644 --- a/crates/uv-publish/src/lib.rs +++ b/crates/uv-publish/src/lib.rs @@ -20,6 +20,7 @@ use std::time::{Duration, SystemTime}; use std::{env, fmt, io}; use thiserror::Error; use tokio::io::{AsyncReadExt, BufReader}; +use tokio::sync::Semaphore; use tokio_util::io::ReaderStream; use tracing::{debug, enabled, trace, warn, Level}; use url::Url; @@ -369,6 +370,7 @@ pub async fn upload( username: Option<&str>, password: Option<&str>, check_url_client: Option<&CheckUrlClient<'_>>, + download_concurrency: &Semaphore, reporter: Arc, ) -> Result { let form_metadata = form_metadata(file, filename) @@ -428,7 +430,8 @@ pub async fn upload( PublishSendError::Status(..) | PublishSendError::StatusNoBody(..) ) { if let Some(check_url_client) = &check_url_client { - if check_url(check_url_client, file, filename).await? { + if check_url(check_url_client, file, filename, download_concurrency).await? + { // There was a raced upload of the same file, so even though our upload failed, // the right file now exists in the registry. 
return Ok(false); @@ -450,6 +453,7 @@ pub async fn check_url( check_url_client: &CheckUrlClient<'_>, file: &Path, filename: &DistFilename, + download_concurrency: &Semaphore, ) -> Result { let CheckUrlClient { index_url, @@ -470,7 +474,12 @@ pub async fn check_url( debug!("Checking for {filename} in the registry"); let response = match registry_client - .simple(filename.name(), Some(index_url), index_capabilities) + .simple( + filename.name(), + Some(index_url), + index_capabilities, + download_concurrency, + ) .await { Ok(response) => response, diff --git a/crates/uv-resolver/src/resolver/provider.rs b/crates/uv-resolver/src/resolver/provider.rs index dc9b38e08..0c076b858 100644 --- a/crates/uv-resolver/src/resolver/provider.rs +++ b/crates/uv-resolver/src/resolver/provider.rs @@ -155,7 +155,9 @@ impl ResolverProvider for DefaultResolverProvider<'_, Con let result = self .fetcher .client() - .managed(|client| client.simple(package_name, index, self.capabilities)) + .manual(|client, semaphore| { + client.simple(package_name, index, self.capabilities, semaphore) + }) .await; match result { diff --git a/crates/uv/src/commands/pip/latest.rs b/crates/uv/src/commands/pip/latest.rs index abc028a00..12f563629 100644 --- a/crates/uv/src/commands/pip/latest.rs +++ b/crates/uv/src/commands/pip/latest.rs @@ -1,3 +1,4 @@ +use tokio::sync::Semaphore; use tracing::debug; use uv_client::{RegistryClient, VersionFiles}; use uv_distribution_filename::DistFilename; @@ -27,10 +28,15 @@ impl LatestClient<'_> { &self, package: &PackageName, index: Option<&IndexUrl>, + download_concurrency: &Semaphore, ) -> anyhow::Result, uv_client::Error> { debug!("Fetching latest version of: `{package}`"); - let archives = match self.client.simple(package, index, self.capabilities).await { + let archives = match self + .client + .simple(package, index, self.capabilities, download_concurrency) + .await + { Ok(archives) => archives, Err(err) => { return match err.into_kind() { diff --git 
a/crates/uv/src/commands/pip/list.rs b/crates/uv/src/commands/pip/list.rs index 6ec866404..d9127ecae 100644 --- a/crates/uv/src/commands/pip/list.rs +++ b/crates/uv/src/commands/pip/list.rs @@ -8,6 +8,7 @@ use itertools::Itertools; use owo_colors::OwoColorize; use rustc_hash::FxHashMap; use serde::Serialize; +use tokio::sync::Semaphore; use unicode_width::UnicodeWidthStr; use uv_cache::{Cache, Refresh}; @@ -94,6 +95,7 @@ pub(crate) async fn pip_list( .markers(environment.interpreter().markers()) .platform(environment.interpreter().platform()) .build(); + let download_concurrency = Semaphore::new(concurrency.downloads); // Determine the platform tags. let interpreter = environment.interpreter(); @@ -116,7 +118,9 @@ pub(crate) async fn pip_list( // Fetch the latest version for each package. let mut fetches = futures::stream::iter(&results) .map(|dist| async { - let latest = client.find_latest(dist.name(), None).await?; + let latest = client + .find_latest(dist.name(), None, &download_concurrency) + .await?; Ok::<(&PackageName, Option), uv_client::Error>((dist.name(), latest)) }) .buffer_unordered(concurrency.downloads); diff --git a/crates/uv/src/commands/pip/tree.rs b/crates/uv/src/commands/pip/tree.rs index 077dcf6f5..da0266719 100644 --- a/crates/uv/src/commands/pip/tree.rs +++ b/crates/uv/src/commands/pip/tree.rs @@ -8,6 +8,7 @@ use petgraph::graph::{EdgeIndex, NodeIndex}; use petgraph::prelude::EdgeRef; use petgraph::Direction; use rustc_hash::{FxHashMap, FxHashSet}; +use tokio::sync::Semaphore; use uv_cache::{Cache, Refresh}; use uv_cache_info::Timestamp; @@ -95,6 +96,7 @@ pub(crate) async fn pip_tree( .markers(environment.interpreter().markers()) .platform(environment.interpreter().platform()) .build(); + let download_concurrency = Semaphore::new(concurrency.downloads); // Determine the platform tags. let interpreter = environment.interpreter(); @@ -117,7 +119,10 @@ pub(crate) async fn pip_tree( // Fetch the latest version for each package. 
let mut fetches = futures::stream::iter(&packages) .map(|(name, ..)| async { - let Some(filename) = client.find_latest(name, None).await? else { + let Some(filename) = client + .find_latest(name, None, &download_concurrency) + .await? + else { return Ok(None); }; Ok::, uv_client::Error>(Some((*name, filename.into_version()))) diff --git a/crates/uv/src/commands/project/tree.rs b/crates/uv/src/commands/project/tree.rs index 83b4234df..14d07722e 100644 --- a/crates/uv/src/commands/project/tree.rs +++ b/crates/uv/src/commands/project/tree.rs @@ -3,7 +3,7 @@ use std::path::Path; use anstream::print; use anyhow::{Error, Result}; use futures::StreamExt; - +use tokio::sync::Semaphore; use uv_cache::{Cache, Refresh}; use uv_cache_info::Timestamp; use uv_client::{Connectivity, RegistryClientBuilder}; @@ -225,6 +225,7 @@ pub(crate) async fn tree( .keyring(*keyring_provider) .allow_insecure_host(allow_insecure_host.to_vec()) .build(); + let download_concurrency = Semaphore::new(concurrency.downloads); // Initialize the client to fetch the latest version of each package. let client = LatestClient { @@ -239,9 +240,12 @@ pub(crate) async fn tree( let reporter = LatestVersionReporter::from(printer).with_length(packages.len() as u64); // Fetch the latest version for each package. + let download_concurrency = &download_concurrency; let mut fetches = futures::stream::iter(packages) .map(|(package, index)| async move { - let Some(filename) = client.find_latest(package.name(), Some(&index)).await? + let Some(filename) = client + .find_latest(package.name(), Some(&index), download_concurrency) + .await? 
else { return Ok(None); }; diff --git a/crates/uv/src/commands/publish.rs b/crates/uv/src/commands/publish.rs index cef9c7833..898348c64 100644 --- a/crates/uv/src/commands/publish.rs +++ b/crates/uv/src/commands/publish.rs @@ -8,6 +8,7 @@ use std::fmt::Write; use std::iter; use std::sync::Arc; use std::time::Duration; +use tokio::sync::Semaphore; use tracing::{debug, info}; use url::Url; use uv_cache::Cache; @@ -69,6 +70,8 @@ pub(crate) async fn publish( let oidc_client = BaseClientBuilder::new() .auth_integration(AuthIntegration::NoAuthMiddleware) .wrap_existing(&upload_client); + // We're only checking a single URL and one at a time, so 1 permit is sufficient + let download_concurrency = Arc::new(Semaphore::new(1)); let (publish_url, username, password) = gather_credentials( publish_url, @@ -110,7 +113,9 @@ pub(crate) async fn publish( for (file, raw_filename, filename) in files { if let Some(check_url_client) = &check_url_client { - if uv_publish::check_url(check_url_client, &file, &filename).await? { + if uv_publish::check_url(check_url_client, &file, &filename, &download_concurrency) + .await? + { writeln!(printer.stderr(), "File {filename} already exists, skipping")?; continue; } @@ -134,6 +139,7 @@ pub(crate) async fn publish( username.as_deref(), password.as_deref(), check_url_client.as_ref(), + &download_concurrency, // Needs to be an `Arc` because the reqwest `Body` static lifetime requirement Arc::new(reporter), )