Respect concurrency limits in parallel index fetch (#11182)

With the parallel simple index fetching, we would only acquire a single
download concurrency token, meaning that in the worst case we could make
(number of indexes) times more requests than the user-requested limit.
We fix this by passing the semaphore down to the simple API method, which now acquires a permit per index request.
konsti 2025-02-03 16:41:17 +01:00 committed by GitHub
parent c54dbcbcc2
commit 56684e4c24
9 changed files with 63 additions and 10 deletions
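
To make the fix concrete, here is a minimal, self-contained sketch of the pattern the change relies on (illustrative only, not the uv code; the index URLs and the stand-in fetch are made up). Acquiring a single permit before building the stream would gate nothing once `buffer_unordered` starts polling; acquiring a permit inside each per-index future keeps the number of in-flight requests at or below the semaphore's size.

```rust
use std::sync::Arc;

use futures::StreamExt;
use tokio::sync::Semaphore;

#[tokio::main]
async fn main() {
    // Allow at most two index requests in flight at any time.
    let limit = Arc::new(Semaphore::new(2));
    let indexes = [
        "https://a.example/simple",
        "https://b.example/simple",
        "https://c.example/simple",
    ];

    let results: Vec<_> = futures::stream::iter(indexes)
        .map(|index| {
            let limit = Arc::clone(&limit);
            async move {
                // The permit is acquired inside each future, so every
                // request counts against the limit individually.
                let _permit = limit.acquire().await.expect("semaphore is never closed");
                // Stand-in for the real simple-index request.
                format!("fetched {index}")
            }
        })
        .buffer_unordered(8)
        .collect()
        .await;

    println!("{results:?}");
}
```

The diff below applies this shape to `RegistryClient::simple` and threads the semaphore through every caller.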

View File

@@ -10,6 +10,7 @@ use http::HeaderMap;
use itertools::Either;
use reqwest::{Client, Response, StatusCode};
use reqwest_middleware::ClientWithMiddleware;
use tokio::sync::Semaphore;
use tracing::{info_span, instrument, trace, warn, Instrument};
use url::Url;
@@ -235,6 +236,7 @@ impl RegistryClient {
package_name: &PackageName,
index: Option<&'index IndexUrl>,
capabilities: &IndexCapabilities,
download_concurrency: &Semaphore,
) -> Result<Vec<(&'index IndexUrl, OwnedArchive<SimpleMetadata>)>, Error> {
let indexes = if let Some(index) = index {
Either::Left(std::iter::once(index))
@@ -253,6 +255,7 @@
// If we're searching for the first index that contains the package, fetch serially.
IndexStrategy::FirstIndex => {
for index in it {
let _permit = download_concurrency.acquire().await;
if let Some(metadata) = self
.simple_single_index(package_name, index, capabilities)
.await?
@@ -265,9 +268,9 @@
// Otherwise, fetch concurrently.
IndexStrategy::UnsafeBestMatch | IndexStrategy::UnsafeFirstMatch => {
// TODO(charlie): Respect concurrency limits.
results = futures::stream::iter(it)
.map(|index| async move {
let _permit = download_concurrency.acquire().await;
let metadata = self
.simple_single_index(package_name, index, capabilities)
.await?;

View File

@@ -992,6 +992,20 @@ impl<'a> ManagedClient<'a> {
let _permit = self.control.acquire().await.unwrap();
f(self.unmanaged).await
}
/// Perform a request using a client that internally manages the concurrency limit.
///
/// The callback is passed the client and a semaphore. It must acquire the semaphore before
/// any request through the client and drop it after.
///
/// This method serves as an escape hatch for functions that may want to send multiple requests
/// in parallel.
pub async fn manual<F, T>(&'a self, f: impl FnOnce(&'a RegistryClient, &'a Semaphore) -> F) -> T
where
F: Future<Output = T>,
{
f(self.unmanaged, &self.control).await
}
}
/// Returns the value of the `Content-Length` header from the [`reqwest::Response`], if present.
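
A sketch of how a caller might use the new `manual` escape hatch to send several requests in parallel while still respecting the download limit. This is hypothetical caller code, not part of the diff; it assumes `client: &ManagedClient`, `package_names: &[PackageName]`, and `capabilities: &IndexCapabilities` are in scope and that the surrounding function returns a `Result` with a compatible error type:

```rust
// Fan out one `simple` request per package; each call draws permits from the
// shared semaphore internally, one per index it queries.
let all_metadata = client
    .manual(|client, semaphore| async move {
        futures::future::try_join_all(
            package_names
                .iter()
                .map(|name| client.simple(name, None, capabilities, semaphore)),
        )
        .await
    })
    .await?;
```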

View File

@@ -20,6 +20,7 @@ use std::time::{Duration, SystemTime};
use std::{env, fmt, io};
use thiserror::Error;
use tokio::io::{AsyncReadExt, BufReader};
use tokio::sync::Semaphore;
use tokio_util::io::ReaderStream;
use tracing::{debug, enabled, trace, warn, Level};
use url::Url;
@@ -369,6 +370,7 @@ pub async fn upload(
username: Option<&str>,
password: Option<&str>,
check_url_client: Option<&CheckUrlClient<'_>>,
download_concurrency: &Semaphore,
reporter: Arc<impl Reporter>,
) -> Result<bool, PublishError> {
let form_metadata = form_metadata(file, filename)
@@ -428,7 +430,8 @@
PublishSendError::Status(..) | PublishSendError::StatusNoBody(..)
) {
if let Some(check_url_client) = &check_url_client {
if check_url(check_url_client, file, filename).await? {
if check_url(check_url_client, file, filename, download_concurrency).await?
{
// There was a raced upload of the same file, so even though our upload failed,
// the right file now exists in the registry.
return Ok(false);
@@ -450,6 +453,7 @@ pub async fn check_url(
check_url_client: &CheckUrlClient<'_>,
file: &Path,
filename: &DistFilename,
download_concurrency: &Semaphore,
) -> Result<bool, PublishError> {
let CheckUrlClient {
index_url,
@@ -470,7 +474,12 @@
debug!("Checking for {filename} in the registry");
let response = match registry_client
.simple(filename.name(), Some(index_url), index_capabilities)
.simple(
filename.name(),
Some(index_url),
index_capabilities,
download_concurrency,
)
.await
{
Ok(response) => response,

View File

@@ -155,7 +155,9 @@ impl<Context: BuildContext> ResolverProvider for DefaultResolverProvider<'_, Con
let result = self
.fetcher
.client()
.managed(|client| client.simple(package_name, index, self.capabilities))
.manual(|client, semaphore| {
client.simple(package_name, index, self.capabilities, semaphore)
})
.await;
match result {

View File

@@ -1,3 +1,4 @@
use tokio::sync::Semaphore;
use tracing::debug;
use uv_client::{RegistryClient, VersionFiles};
use uv_distribution_filename::DistFilename;
@@ -27,10 +28,15 @@ impl LatestClient<'_> {
&self,
package: &PackageName,
index: Option<&IndexUrl>,
download_concurrency: &Semaphore,
) -> anyhow::Result<Option<DistFilename>, uv_client::Error> {
debug!("Fetching latest version of: `{package}`");
let archives = match self.client.simple(package, index, self.capabilities).await {
let archives = match self
.client
.simple(package, index, self.capabilities, download_concurrency)
.await
{
Ok(archives) => archives,
Err(err) => {
return match err.into_kind() {

View File

@@ -8,6 +8,7 @@ use itertools::Itertools;
use owo_colors::OwoColorize;
use rustc_hash::FxHashMap;
use serde::Serialize;
use tokio::sync::Semaphore;
use unicode_width::UnicodeWidthStr;
use uv_cache::{Cache, Refresh};
@@ -94,6 +95,7 @@ pub(crate) async fn pip_list(
.markers(environment.interpreter().markers())
.platform(environment.interpreter().platform())
.build();
let download_concurrency = Semaphore::new(concurrency.downloads);
// Determine the platform tags.
let interpreter = environment.interpreter();
@@ -116,7 +118,9 @@
// Fetch the latest version for each package.
let mut fetches = futures::stream::iter(&results)
.map(|dist| async {
let latest = client.find_latest(dist.name(), None).await?;
let latest = client
.find_latest(dist.name(), None, &download_concurrency)
.await?;
Ok::<(&PackageName, Option<DistFilename>), uv_client::Error>((dist.name(), latest))
})
.buffer_unordered(concurrency.downloads);
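
The `buffer_unordered(concurrency.downloads)` call alone is not enough here: it bounds how many packages are processed at once, but each `find_latest` call can query several indexes, so only the shared semaphore bounds the total number of in-flight requests. A rough, self-contained sketch of that interplay (hypothetical names and URLs, not the uv implementation):

```rust
use futures::StreamExt;
use tokio::sync::Semaphore;

/// Hypothetical stand-in for a per-package lookup that fans out to every index.
async fn find_latest(package: &str, indexes: &[&str], limit: &Semaphore) -> Vec<String> {
    futures::future::join_all(indexes.iter().map(|index| async move {
        // One permit per index request, not one per package.
        let _permit = limit.acquire().await.expect("semaphore is never closed");
        format!("{package} @ {index}")
    }))
    .await
}

#[tokio::main]
async fn main() {
    // No more than four requests in flight, no matter how many
    // packages and indexes are being queried.
    let limit = Semaphore::new(4);
    let indexes = ["https://a.example/simple", "https://b.example/simple"];
    let packages = ["anyio", "idna", "sniffio"];

    let limit = &limit;
    let indexes = &indexes[..];
    let latest: Vec<_> = futures::stream::iter(packages)
        .map(|package| find_latest(package, indexes, limit))
        .buffer_unordered(4)
        .collect()
        .await;

    println!("{latest:?}");
}
```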

View File

@@ -8,6 +8,7 @@ use petgraph::graph::{EdgeIndex, NodeIndex};
use petgraph::prelude::EdgeRef;
use petgraph::Direction;
use rustc_hash::{FxHashMap, FxHashSet};
use tokio::sync::Semaphore;
use uv_cache::{Cache, Refresh};
use uv_cache_info::Timestamp;
@@ -95,6 +96,7 @@ pub(crate) async fn pip_tree(
.markers(environment.interpreter().markers())
.platform(environment.interpreter().platform())
.build();
let download_concurrency = Semaphore::new(concurrency.downloads);
// Determine the platform tags.
let interpreter = environment.interpreter();
@@ -117,7 +119,10 @@
// Fetch the latest version for each package.
let mut fetches = futures::stream::iter(&packages)
.map(|(name, ..)| async {
let Some(filename) = client.find_latest(name, None).await? else {
let Some(filename) = client
.find_latest(name, None, &download_concurrency)
.await?
else {
return Ok(None);
};
Ok::<Option<_>, uv_client::Error>(Some((*name, filename.into_version())))

View File

@@ -3,7 +3,7 @@ use std::path::Path;
use anstream::print;
use anyhow::{Error, Result};
use futures::StreamExt;
use tokio::sync::Semaphore;
use uv_cache::{Cache, Refresh};
use uv_cache_info::Timestamp;
use uv_client::{Connectivity, RegistryClientBuilder};
@ -225,6 +225,7 @@ pub(crate) async fn tree(
.keyring(*keyring_provider)
.allow_insecure_host(allow_insecure_host.to_vec())
.build();
let download_concurrency = Semaphore::new(concurrency.downloads);
// Initialize the client to fetch the latest version of each package.
let client = LatestClient {
@@ -239,9 +240,12 @@
let reporter = LatestVersionReporter::from(printer).with_length(packages.len() as u64);
// Fetch the latest version for each package.
let download_concurrency = &download_concurrency;
let mut fetches = futures::stream::iter(packages)
.map(|(package, index)| async move {
let Some(filename) = client.find_latest(package.name(), Some(&index)).await?
let Some(filename) = client
.find_latest(package.name(), Some(&index), download_concurrency)
.await?
else {
return Ok(None);
};
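
The `let download_concurrency = &download_concurrency;` line above is the usual trick for using an owned value from inside a repeatedly-created `async move` block: shadowing the semaphore with a shared reference lets each closure invocation copy the reference rather than trying to move the semaphore itself into the first block. A minimal sketch of the same pattern (hypothetical, not the uv code):

```rust
use futures::StreamExt;
use tokio::sync::Semaphore;

#[tokio::main]
async fn main() {
    let download_concurrency = Semaphore::new(4);

    // Shadow the owned semaphore with a reference; `&Semaphore` is `Copy`,
    // so every `async move` block below gets its own copy of the reference.
    let download_concurrency = &download_concurrency;

    let doubled: Vec<_> = futures::stream::iter(0..8u32)
        .map(|i| async move {
            let _permit = download_concurrency
                .acquire()
                .await
                .expect("semaphore is never closed");
            i * 2
        })
        .buffer_unordered(4)
        .collect()
        .await;

    println!("{doubled:?}");
}
```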

View File

@@ -8,6 +8,7 @@ use std::fmt::Write;
use std::iter;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Semaphore;
use tracing::{debug, info};
use url::Url;
use uv_cache::Cache;
@@ -69,6 +70,8 @@ pub(crate) async fn publish(
let oidc_client = BaseClientBuilder::new()
.auth_integration(AuthIntegration::NoAuthMiddleware)
.wrap_existing(&upload_client);
// We're only checking a single URL and one at a time, so 1 permit is sufficient
let download_concurrency = Arc::new(Semaphore::new(1));
let (publish_url, username, password) = gather_credentials(
publish_url,
@@ -110,7 +113,9 @@
for (file, raw_filename, filename) in files {
if let Some(check_url_client) = &check_url_client {
if uv_publish::check_url(check_url_client, &file, &filename).await? {
if uv_publish::check_url(check_url_client, &file, &filename, &download_concurrency)
.await?
{
writeln!(printer.stderr(), "File {filename} already exists, skipping")?;
continue;
}
@@ -134,6 +139,7 @@
username.as_deref(),
password.as_deref(),
check_url_client.as_ref(),
&download_concurrency,
// Needs to be an `Arc` because the reqwest `Body` static lifetime requirement
Arc::new(reporter),
)