uv/crates/uv-distribution/src/distribution_database.rs

use std::future::Future;
use std::io;
use std::path::Path;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use futures::{FutureExt, TryStreamExt};
use tempfile::TempDir;
use tokio::io::{AsyncRead, AsyncSeekExt, ReadBuf};
use tokio::sync::Semaphore;
use tokio_util::compat::FuturesAsyncReadCompatExt;
use tracing::{Instrument, debug, info_span, instrument, warn};
use url::Url;
use uv_cache::{ArchiveId, CacheBucket, CacheEntry, WheelCache};
use uv_cache_info::{CacheInfo, Timestamp};
use uv_client::{
CacheControl, CachedClientError, Connectivity, DataWithCachePolicy, RegistryClient,
};
use uv_distribution_filename::WheelFilename;
use uv_distribution_types::{
BuildInfo, BuildableSource, BuiltDist, CompatibleDist, Dist, File, HashPolicy, Hashed,
IndexUrl, InstalledDist, Name, RegistryBuiltDist, SourceDist, ToUrlError,
};
use uv_extract::hash::Hasher;
use uv_fs::write_atomic;
use uv_platform_tags::Tags;
use uv_pypi_types::{HashDigest, HashDigests, PyProjectToml};
use uv_redacted::DisplaySafeUrl;
use uv_types::{BuildContext, BuildStack};
use crate::archive::Archive;
use crate::metadata::{ArchiveMetadata, Metadata};
use crate::remote::RemoteCacheResolver;
use crate::source::SourceDistributionBuilder;
use crate::{Error, LocalWheel, Reporter, RequiresDist};
/// A cached high-level interface to convert distributions (a requirement resolved to a location)
/// to a wheel or wheel metadata.
///
/// For wheel metadata, this happens by either fetching the metadata from the remote wheel or by
/// building the source distribution. For wheel files, the wheel is either downloaded directly, or
/// a source distribution is downloaded and built, and the resulting wheel is returned.
///
/// All kinds of wheel sources (index, URL, path) and source distribution sources (index, URL,
/// path, Git) are supported.
///
/// This struct also has the task of acquiring locks around source distribution builds in general,
/// and Git operations especially, as well as respecting concurrency limits.
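///
/// # Example
///
/// A minimal sketch of typical usage, marked `ignore` because it assumes a concrete
/// [`BuildContext`] implementation and pre-constructed inputs (`client`, `context`,
/// `dist`, `tags`, and `hashes`) that this crate does not provide:
///
/// ```ignore
/// // Assumptions: `client: &RegistryClient`, `context: &impl BuildContext`,
/// // `dist: &Dist`, `tags: &Tags`, and `hashes: HashPolicy<'_>`.
/// let database = DistributionDatabase::new(client, context, 4);
/// let wheel = database.get_or_build_wheel(dist, tags, hashes).await?;
/// println!("Unpacked into: {}", wheel.archive.display());
/// ```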
pub struct DistributionDatabase<'a, Context: BuildContext> {
build_context: &'a Context,
builder: SourceDistributionBuilder<'a, Context>,
resolver: RemoteCacheResolver<'a, Context>,
client: ManagedClient<'a>,
reporter: Option<Arc<dyn Reporter>>,
}
impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
pub fn new(
client: &'a RegistryClient,
build_context: &'a Context,
concurrent_downloads: usize,
) -> Self {
Self {
build_context,
builder: SourceDistributionBuilder::new(build_context),
resolver: RemoteCacheResolver::new(build_context),
client: ManagedClient::new(client, concurrent_downloads),
reporter: None,
}
}
/// Set the build stack to use for the [`DistributionDatabase`].
#[must_use]
pub fn with_build_stack(self, build_stack: &'a BuildStack) -> Self {
Self {
builder: self.builder.with_build_stack(build_stack),
..self
}
}
/// Set the [`Reporter`] to use for the [`DistributionDatabase`].
#[must_use]
pub fn with_reporter(self, reporter: Arc<dyn Reporter>) -> Self {
Self {
builder: self.builder.with_reporter(reporter.clone()),
reporter: Some(reporter),
..self
}
}
/// Handle a specific `reqwest` error, and convert it to an [`io::Error`].
fn handle_response_errors(&self, err: reqwest::Error) -> io::Error {
if err.is_timeout() {
io::Error::new(
io::ErrorKind::TimedOut,
format!(
"Failed to download distribution due to network timeout. Try increasing UV_HTTP_TIMEOUT (current value: {}s).",
self.client.unmanaged.timeout().as_secs()
),
)
} else {
io::Error::other(err)
}
}
/// Either fetch the wheel or fetch and build the source distribution.
///
/// Returns a wheel that's compliant with the given platform tags.
///
/// While hashes will be generated in some cases, hash-checking is only enforced for source
/// distributions, and should be enforced by the caller for wheels.
#[instrument(skip_all, fields(%dist))]
pub async fn get_or_build_wheel(
&self,
dist: &Dist,
tags: &Tags,
hashes: HashPolicy<'_>,
) -> Result<LocalWheel, Error> {
match dist {
Dist::Built(built) => self.get_wheel(built, hashes).await,
Dist::Source(source) => self.build_wheel(source, tags, hashes).await,
}
}
/// Fetch the metadata for an installed distribution, preferring user-provided dependency
/// metadata when available.
#[instrument(skip_all, fields(%dist))]
pub async fn get_installed_metadata(
&self,
dist: &InstalledDist,
) -> Result<ArchiveMetadata, Error> {
// If the metadata was provided by the user directly, prefer it.
if let Some(metadata) = self
.build_context
.dependency_metadata()
.get(dist.name(), Some(dist.version()))
{
return Ok(ArchiveMetadata::from_metadata23(metadata.clone()));
}
let metadata = dist
.read_metadata()
.map_err(|err| Error::ReadInstalled(Box::new(dist.clone()), err))?;
Ok(ArchiveMetadata::from_metadata23(metadata.clone()))
}
/// Either fetch only the wheel metadata (directly from the index or via range requests) or
/// fetch and build the source distribution.
///
/// While hashes will be generated in some cases, hash-checking is only enforced for source
/// distributions, and should be enforced by the caller for wheels.
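///
/// A hedged sketch (assumes the same pre-constructed `dist` and `hashes` as in the
/// struct-level example):
///
/// ```ignore
/// // Resolve metadata without necessarily downloading or building the full wheel.
/// let metadata = database.get_or_build_wheel_metadata(dist, hashes).await?;
/// ```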
#[instrument(skip_all, fields(%dist))]
pub async fn get_or_build_wheel_metadata(
&self,
dist: &Dist,
hashes: HashPolicy<'_>,
) -> Result<ArchiveMetadata, Error> {
match dist {
Dist::Built(built) => self.get_wheel_metadata(built, hashes).await,
Dist::Source(source) => {
self.build_wheel_metadata(&BuildableSource::Dist(source), hashes)
.await
}
}
}
/// Fetch a wheel from the cache or download it from the index.
///
/// While hashes will be generated in all cases, hash-checking is _not_ enforced and should
/// instead be enforced by the caller.
async fn get_wheel(
&self,
dist: &BuiltDist,
hashes: HashPolicy<'_>,
) -> Result<LocalWheel, Error> {
match dist {
BuiltDist::Registry(wheels) => {
let wheel = wheels.best_wheel();
let WheelTarget {
url,
extension,
size,
} = WheelTarget::try_from(&*wheel.file)?;
// Create a cache entry for the wheel.
let wheel_entry = self.build_context.cache().entry(
CacheBucket::Wheels,
WheelCache::Index(&wheel.index).wheel_dir(wheel.name().as_ref()),
wheel.filename.cache_key(),
);
// If the URL is a file URL, load the wheel directly.
if url.scheme() == "file" {
let path = url
.to_file_path()
.map_err(|()| Error::NonFileUrl(url.clone()))?;
return self
.load_wheel(
&path,
&wheel.filename,
WheelExtension::Whl,
wheel_entry,
dist,
hashes,
)
.await;
}
// Download and unzip.
match self
.stream_wheel(
url.clone(),
dist.index(),
&wheel.filename,
extension,
size,
&wheel_entry,
dist,
hashes,
)
.await
{
Ok(archive) => Ok(LocalWheel {
dist: Dist::Built(dist.clone()),
archive: self
.build_context
.cache()
.archive(&archive.id)
.into_boxed_path(),
hashes: archive.hashes,
filename: wheel.filename.clone(),
cache: CacheInfo::default(),
build: None,
}),
Err(Error::Extract(name, err)) => {
if err.is_http_streaming_unsupported() {
warn!(
"Streaming unsupported for {dist}; downloading wheel to disk ({err})"
);
} else if err.is_http_streaming_failed() {
warn!("Streaming failed for {dist}; downloading wheel to disk ({err})");
} else {
return Err(Error::Extract(name, err));
}
// If the request failed because streaming is unsupported, download the
// wheel directly.
let archive = self
.download_wheel(
url,
dist.index(),
&wheel.filename,
extension,
size,
&wheel_entry,
dist,
hashes,
)
.await?;
Ok(LocalWheel {
dist: Dist::Built(dist.clone()),
archive: self
.build_context
.cache()
.archive(&archive.id)
.into_boxed_path(),
hashes: archive.hashes,
filename: wheel.filename.clone(),
cache: CacheInfo::default(),
build: None,
})
}
Err(err) => Err(err),
}
}
BuiltDist::DirectUrl(wheel) => {
// Create a cache entry for the wheel.
let wheel_entry = self.build_context.cache().entry(
CacheBucket::Wheels,
WheelCache::Url(&wheel.url).wheel_dir(wheel.name().as_ref()),
wheel.filename.cache_key(),
);
// Download and unzip.
match self
.stream_wheel(
wheel.url.raw().clone(),
None,
&wheel.filename,
WheelExtension::Whl,
None,
&wheel_entry,
dist,
hashes,
)
.await
{
Ok(archive) => Ok(LocalWheel {
dist: Dist::Built(dist.clone()),
archive: self
.build_context
.cache()
.archive(&archive.id)
.into_boxed_path(),
hashes: archive.hashes,
filename: wheel.filename.clone(),
cache: CacheInfo::default(),
build: None,
}),
Err(Error::Client(err)) if err.is_http_streaming_unsupported() => {
warn!(
"Streaming unsupported for {dist}; downloading wheel to disk ({err})"
);
// If the request failed because streaming is unsupported, download the
// wheel directly.
let archive = self
.download_wheel(
wheel.url.raw().clone(),
None,
&wheel.filename,
WheelExtension::Whl,
None,
&wheel_entry,
dist,
hashes,
)
.await?;
Ok(LocalWheel {
dist: Dist::Built(dist.clone()),
archive: self
.build_context
.cache()
.archive(&archive.id)
.into_boxed_path(),
hashes: archive.hashes,
filename: wheel.filename.clone(),
cache: CacheInfo::default(),
build: None,
})
}
Err(err) => Err(err),
}
}
BuiltDist::Path(wheel) => {
let cache_entry = self.build_context.cache().entry(
CacheBucket::Wheels,
WheelCache::Url(&wheel.url).wheel_dir(wheel.name().as_ref()),
wheel.filename.cache_key(),
);
self.load_wheel(
&wheel.install_path,
&wheel.filename,
WheelExtension::Whl,
cache_entry,
dist,
hashes,
)
.await
}
}
}
/// Convert a source distribution into a wheel, fetching it from the cache or building it if
/// necessary.
///
/// The returned wheel is guaranteed to come from a distribution with a matching hash, and
/// no build processes will be executed for distributions with mismatched hashes.
async fn build_wheel(
&self,
dist: &SourceDist,
tags: &Tags,
hashes: HashPolicy<'_>,
) -> Result<LocalWheel, Error> {
// If the wheel is available in a remote cache, fetch it.
if let Ok(Some(wheel)) = self.get_remote_wheel(dist, tags, hashes).await {
return Ok(wheel);
}
// Otherwise, build the wheel locally.
self.build_wheel_inner(dist, tags, hashes).await
}
/// Convert a source distribution into a wheel, fetching it from the cache or building it if
/// necessary.
async fn build_wheel_inner(
&self,
dist: &SourceDist,
tags: &Tags,
hashes: HashPolicy<'_>,
) -> Result<LocalWheel, Error> {
let built_wheel = self
.builder
.download_and_build(&BuildableSource::Dist(dist), tags, hashes, &self.client)
.boxed_local()
.await?;
// Acquire the advisory lock.
#[cfg(windows)]
let _lock = {
let lock_entry = CacheEntry::new(
built_wheel.target.parent().unwrap(),
format!(
"{}.lock",
built_wheel.target.file_name().unwrap().to_str().unwrap()
),
);
lock_entry.lock().await.map_err(Error::CacheWrite)?
};
// If the wheel was unzipped previously, respect it. Source distributions are
// cached under a unique revision ID, so unzipped directories are never stale.
match self.build_context.cache().resolve_link(&built_wheel.target) {
Ok(archive) => {
return Ok(LocalWheel {
dist: Dist::Source(dist.clone()),
archive: archive.into_boxed_path(),
filename: built_wheel.filename,
hashes: built_wheel.hashes,
cache: built_wheel.cache_info,
build: Some(built_wheel.build_info),
});
}
Err(err) if err.kind() == io::ErrorKind::NotFound => {}
Err(err) => return Err(Error::CacheRead(err)),
}
// Otherwise, unzip the wheel.
let id = self
.unzip_wheel(&built_wheel.path, &built_wheel.target)
.await?;
Ok(LocalWheel {
dist: Dist::Source(dist.clone()),
archive: self.build_context.cache().archive(&id).into_boxed_path(),
hashes: built_wheel.hashes,
filename: built_wheel.filename,
cache: built_wheel.cache_info,
build: Some(built_wheel.build_info),
})
}
/// Fetch the wheel metadata from the index, or from the cache if possible.
///
/// While hashes will be generated in some cases, hash-checking is _not_ enforced and should
/// instead be enforced by the caller.
async fn get_wheel_metadata(
&self,
dist: &BuiltDist,
hashes: HashPolicy<'_>,
) -> Result<ArchiveMetadata, Error> {
// If hash generation is enabled, and the distribution isn't hosted on a registry, get the
// entire wheel to ensure that the hashes are included in the response. If the distribution
// is hosted on an index, the hashes will be included in the simple metadata response.
// For hash _validation_, callers are expected to enforce the policy when retrieving the
// wheel.
//
// Historically, for `uv pip compile --universal`, we also generated hashes for
// registry-based distributions when the relevant registry didn't provide them. This was
// motivated by `--find-links`. We continue that behavior (under `HashGeneration::All`) for
// backwards compatibility, but it's a little dubious, since we're only hashing _one_
// distribution here (as opposed to hashing all distributions for the version), and it may
// not even be a compatible distribution!
//
// TODO(charlie): Request the hashes via a separate method, to reduce the coupling in this API.
if hashes.is_generate(dist) {
let wheel = self.get_wheel(dist, hashes).await?;
// If the metadata was provided by the user directly, prefer it.
let metadata = if let Some(metadata) = self
.build_context
.dependency_metadata()
.get(dist.name(), Some(dist.version()))
{
metadata.clone()
} else {
wheel.metadata()?
};
let hashes = wheel.hashes;
return Ok(ArchiveMetadata {
metadata: Metadata::from_metadata23(metadata),
hashes,
});
}
// If the metadata was provided by the user directly, prefer it.
if let Some(metadata) = self
.build_context
.dependency_metadata()
.get(dist.name(), Some(dist.version()))
{
return Ok(ArchiveMetadata::from_metadata23(metadata.clone()));
}
let result = self
.client
.managed(|client| {
client
.wheel_metadata(dist, self.build_context.capabilities())
.boxed_local()
})
.await;
match result {
Ok(metadata) => {
// Convert the metadata into an archive-level representation.
Ok(ArchiveMetadata::from_metadata23(metadata))
}
Err(err) if err.is_http_streaming_unsupported() => {
warn!(
"Streaming unsupported when fetching metadata for {dist}; downloading wheel directly ({err})"
);
// If the request failed due to an error that could be resolved by
// downloading the wheel directly, try that.
let wheel = self.get_wheel(dist, hashes).await?;
let metadata = wheel.metadata()?;
let hashes = wheel.hashes;
Ok(ArchiveMetadata {
metadata: Metadata::from_metadata23(metadata),
hashes,
})
}
Err(err) => Err(err.into()),
}
}
/// Build the wheel metadata for a source distribution, or fetch it from the cache if possible.
///
/// The returned metadata is guaranteed to come from a distribution with a matching hash, and
/// no build processes will be executed for distributions with mismatched hashes.
pub async fn build_wheel_metadata(
&self,
source: &BuildableSource<'_>,
hashes: HashPolicy<'_>,
) -> Result<ArchiveMetadata, Error> {
// Resolve the source distribution to a precise revision (i.e., a specific Git commit).
self.builder.resolve_revision(source, &self.client).await?;
// If the metadata was provided by the user directly, prefer it.
if let Some(dist) = source.as_dist() {
if let Some(metadata) = self
.build_context
.dependency_metadata()
.get(dist.name(), dist.version())
{
return Ok(ArchiveMetadata::from_metadata23(metadata.clone()));
}
}
// If the metadata is available in a remote cache, fetch it.
if let Ok(Some(metadata)) = self.get_remote_metadata(source, hashes).await {
return Ok(metadata);
}
// Otherwise, retrieve the metadata from the source distribution.
self.build_wheel_metadata_inner(source, hashes).await
}
/// Build the wheel metadata for a source distribution, or fetch it from the cache if possible.
async fn build_wheel_metadata_inner(
&self,
source: &BuildableSource<'_>,
hashes: HashPolicy<'_>,
) -> Result<ArchiveMetadata, Error> {
let metadata = self
.builder
.download_and_build_metadata(source, hashes, &self.client)
.boxed_local()
.await?;
Ok(metadata)
}
/// Fetch a wheel from a remote cache, if available.
async fn get_remote_wheel(
&self,
source: &SourceDist,
tags: &Tags,
hashes: HashPolicy<'_>,
) -> Result<Option<LocalWheel>, Error> {
let Some(index) = self
.resolver
.get_cached_distribution(&BuildableSource::Dist(source), Some(tags), &self.client)
.await?
else {
return Ok(None);
};
for prioritized_dist in index.iter() {
let Some(compatible_dist) = prioritized_dist.get() else {
continue;
};
match compatible_dist {
CompatibleDist::InstalledDist(..) => {}
CompatibleDist::SourceDist { sdist, .. } => {
debug!("Found cached remote source distribution for: {source}");
let dist = SourceDist::Registry(sdist.clone());
return self.build_wheel_inner(&dist, tags, hashes).await.map(Some);
}
CompatibleDist::CompatibleWheel { wheel, .. }
| CompatibleDist::IncompatibleWheel { wheel, .. } => {
debug!("Found cached remote built distribution for: {source}");
let dist = BuiltDist::Registry(RegistryBuiltDist {
wheels: vec![wheel.clone()],
best_wheel_index: 0,
sdist: None,
});
return self.get_wheel(&dist, hashes).await.map(Some);
}
}
}
Ok(None)
}
/// Fetch the wheel metadata from a remote cache, if available.
async fn get_remote_metadata(
&self,
source: &BuildableSource<'_>,
hashes: HashPolicy<'_>,
) -> Result<Option<ArchiveMetadata>, Error> {
let Some(index) = self
.resolver
.get_cached_distribution(source, None, &self.client)
.await?
else {
return Ok(None);
};
for prioritized_dist in index.iter() {
let Some(compatible_dist) = prioritized_dist.get() else {
continue;
};
match compatible_dist {
CompatibleDist::InstalledDist(..) => {}
CompatibleDist::SourceDist { sdist, .. } => {
debug!("Found cached remote source distribution for: {source}");
let dist = SourceDist::Registry(sdist.clone());
return self
.build_wheel_metadata_inner(&BuildableSource::Dist(&dist), hashes)
.await
.map(Some);
}
CompatibleDist::CompatibleWheel { wheel, .. }
| CompatibleDist::IncompatibleWheel { wheel, .. } => {
debug!("Found cached remote built distribution for: {source}");
let dist = BuiltDist::Registry(RegistryBuiltDist {
wheels: vec![wheel.clone()],
best_wheel_index: 0,
sdist: None,
});
return self.get_wheel_metadata(&dist, hashes).await.map(Some);
}
}
}
Ok(None)
}
/// Return the [`RequiresDist`] from a `pyproject.toml`, if it can be statically extracted.
pub async fn requires_dist(
&self,
path: &Path,
pyproject_toml: &PyProjectToml,
) -> Result<Option<RequiresDist>, Error> {
self.builder
.source_tree_requires_dist(path, pyproject_toml)
.await
}
/// Stream a wheel from a URL, unzipping it into the cache as it's downloaded.
async fn stream_wheel(
&self,
url: DisplaySafeUrl,
index: Option<&IndexUrl>,
filename: &WheelFilename,
extension: WheelExtension,
size: Option<u64>,
wheel_entry: &CacheEntry,
dist: &BuiltDist,
hashes: HashPolicy<'_>,
) -> Result<Archive, Error> {
// Acquire an advisory lock, to guard against concurrent writes.
#[cfg(windows)]
let _lock = {
let lock_entry = wheel_entry.with_file(format!("{}.lock", filename.stem()));
lock_entry.lock().await.map_err(Error::CacheWrite)?
};
// Create an entry for the HTTP cache.
let http_entry = wheel_entry.with_file(format!("{}.http", filename.cache_key()));
let download = |response: reqwest::Response| {
async {
let size = size.or_else(|| content_length(&response));
let progress = self
.reporter
.as_ref()
.map(|reporter| (reporter, reporter.on_download_start(dist.name(), size)));
let reader = response
.bytes_stream()
.map_err(|err| self.handle_response_errors(err))
.into_async_read();
// Create a hasher for each hash algorithm.
let algorithms = hashes.algorithms();
let mut hashers = algorithms.into_iter().map(Hasher::from).collect::<Vec<_>>();
let mut hasher = uv_extract::hash::HashReader::new(reader.compat(), &mut hashers);
// Download and unzip the wheel to a temporary directory.
let temp_dir = tempfile::tempdir_in(self.build_context.cache().root())
.map_err(Error::CacheWrite)?;
match progress {
Some((reporter, progress)) => {
let mut reader = ProgressReader::new(&mut hasher, progress, &**reporter);
match extension {
WheelExtension::Whl => {
uv_extract::stream::unzip(&mut reader, temp_dir.path())
.await
.map_err(|err| Error::Extract(filename.to_string(), err))?;
}
WheelExtension::WhlZst => {
uv_extract::stream::untar_zst(&mut reader, temp_dir.path())
.await
.map_err(|err| Error::Extract(filename.to_string(), err))?;
}
}
}
None => match extension {
WheelExtension::Whl => {
uv_extract::stream::unzip(&mut hasher, temp_dir.path())
.await
.map_err(|err| Error::Extract(filename.to_string(), err))?;
}
WheelExtension::WhlZst => {
uv_extract::stream::untar_zst(&mut hasher, temp_dir.path())
.await
.map_err(|err| Error::Extract(filename.to_string(), err))?;
}
},
}
// If necessary, exhaust the reader to compute the hash.
if !hashes.is_none() {
hasher.finish().await.map_err(Error::HashExhaustion)?;
}
// Persist the temporary directory to the directory store.
let id = self
.build_context
.cache()
.persist(temp_dir.keep(), wheel_entry.path())
.await
.map_err(Error::CacheWrite)?;
if let Some((reporter, progress)) = progress {
reporter.on_download_complete(dist.name(), progress);
}
Ok(Archive::new(
id,
hashers.into_iter().map(HashDigest::from).collect(),
filename.clone(),
))
}
.instrument(info_span!("wheel", wheel = %dist))
};
// Fetch the archive from the cache, or download it if necessary.
let req = self.request(url.clone())?;
// Determine the cache control policy for the URL.
let cache_control = match self.client.unmanaged.connectivity() {
Connectivity::Online => {
if let Some(header) = index.and_then(|index| {
self.build_context
.locations()
.artifact_cache_control_for(index)
}) {
CacheControl::Override(header)
} else {
CacheControl::from(
self.build_context
.cache()
.freshness(&http_entry, Some(&filename.name), None)
.map_err(Error::CacheRead)?,
)
}
}
Connectivity::Offline => CacheControl::AllowStale,
};
let archive = self
.client
.managed(|client| {
client.cached_client().get_serde_with_retry(
req,
&http_entry,
cache_control,
download,
)
})
.await
.map_err(|err| match err {
CachedClientError::Callback { err, .. } => err,
CachedClientError::Client { err, .. } => Error::Client(err),
})?;
// If the archive is missing the required hashes, or has since been removed, force a refresh.
let archive = Some(archive)
.filter(|archive| archive.has_digests(hashes))
.filter(|archive| archive.exists(self.build_context.cache()));
let archive = if let Some(archive) = archive {
archive
} else {
self.client
.managed(async |client| {
client
.cached_client()
.skip_cache_with_retry(
self.request(url)?,
&http_entry,
cache_control,
download,
)
.await
.map_err(|err| match err {
CachedClientError::Callback { err, .. } => err,
CachedClientError::Client { err, .. } => Error::Client(err),
})
})
.await?
};
Ok(archive)
}
/// Download a wheel from a URL, then unzip it into the cache.
async fn download_wheel(
&self,
url: DisplaySafeUrl,
index: Option<&IndexUrl>,
filename: &WheelFilename,
extension: WheelExtension,
size: Option<u64>,
wheel_entry: &CacheEntry,
dist: &BuiltDist,
hashes: HashPolicy<'_>,
) -> Result<Archive, Error> {
// Acquire an advisory lock, to guard against concurrent writes.
#[cfg(windows)]
let _lock = {
let lock_entry = wheel_entry.with_file(format!("{}.lock", filename.stem()));
lock_entry.lock().await.map_err(Error::CacheWrite)?
};
// Create an entry for the HTTP cache.
let http_entry = wheel_entry.with_file(format!("{}.http", filename.cache_key()));
let download = |response: reqwest::Response| {
async {
let size = size.or_else(|| content_length(&response));
let progress = self
.reporter
.as_ref()
.map(|reporter| (reporter, reporter.on_download_start(dist.name(), size)));
let reader = response
.bytes_stream()
.map_err(|err| self.handle_response_errors(err))
.into_async_read();
// Download the wheel to a temporary file.
let temp_file = tempfile::tempfile_in(self.build_context.cache().root())
.map_err(Error::CacheWrite)?;
let mut writer = tokio::io::BufWriter::new(fs_err::tokio::File::from_std(
// The temporary file is unnamed (on Linux), so the cache root is the best approximation.
fs_err::File::from_parts(temp_file, self.build_context.cache().root()),
));
match progress {
Some((reporter, progress)) => {
// Wrap the reader in a progress reporter. This will report 100% progress
// after the download is complete, even if we still have to unzip and hash
// part of the file.
let mut reader =
ProgressReader::new(reader.compat(), progress, &**reporter);
tokio::io::copy(&mut reader, &mut writer)
.await
.map_err(Error::CacheWrite)?;
}
None => {
tokio::io::copy(&mut reader.compat(), &mut writer)
.await
.map_err(Error::CacheWrite)?;
}
}
// Unzip the wheel to a temporary directory.
let temp_dir = tempfile::tempdir_in(self.build_context.cache().root())
.map_err(Error::CacheWrite)?;
let mut file = writer.into_inner();
file.seek(io::SeekFrom::Start(0))
.await
.map_err(Error::CacheWrite)?;
// If no hashes are required, parallelize the unzip operation.
let hashes = if hashes.is_none() {
let file = file.into_std().await;
tokio::task::spawn_blocking({
let target = temp_dir.path().to_owned();
move || -> Result<(), uv_extract::Error> {
// Unzip the wheel into a temporary directory.
match extension {
WheelExtension::Whl => {
uv_extract::unzip(file, &target)?;
}
WheelExtension::WhlZst => {
uv_extract::stream::untar_zst_file(file, &target)?;
}
}
Ok(())
}
})
.await?
.map_err(|err| Error::Extract(filename.to_string(), err))?;
HashDigests::empty()
} else {
// Create a hasher for each hash algorithm.
let algorithms = hashes.algorithms();
let mut hashers = algorithms.into_iter().map(Hasher::from).collect::<Vec<_>>();
let mut hasher = uv_extract::hash::HashReader::new(file, &mut hashers);
match extension {
WheelExtension::Whl => {
uv_extract::stream::unzip(&mut hasher, temp_dir.path())
.await
.map_err(|err| Error::Extract(filename.to_string(), err))?;
}
WheelExtension::WhlZst => {
uv_extract::stream::untar_zst(&mut hasher, temp_dir.path())
.await
.map_err(|err| Error::Extract(filename.to_string(), err))?;
}
}
// If necessary, exhaust the reader to compute the hash.
hasher.finish().await.map_err(Error::HashExhaustion)?;
hashers.into_iter().map(HashDigest::from).collect()
};
// Persist the temporary directory to the directory store.
let id = self
.build_context
.cache()
.persist(temp_dir.keep(), wheel_entry.path())
.await
.map_err(Error::CacheWrite)?;
if let Some((reporter, progress)) = progress {
reporter.on_download_complete(dist.name(), progress);
}
Ok(Archive::new(id, hashes, filename.clone()))
}
.instrument(info_span!("wheel", wheel = %dist))
};
// Fetch the archive from the cache, or download it if necessary.
let req = self.request(url.clone())?;
// Determine the cache control policy for the URL.
let cache_control = match self.client.unmanaged.connectivity() {
Connectivity::Online => {
if let Some(header) = index.and_then(|index| {
self.build_context
.locations()
.artifact_cache_control_for(index)
}) {
CacheControl::Override(header)
} else {
CacheControl::from(
self.build_context
.cache()
.freshness(&http_entry, Some(&filename.name), None)
.map_err(Error::CacheRead)?,
)
}
}
Connectivity::Offline => CacheControl::AllowStale,
};
let archive = self
.client
.managed(|client| {
client.cached_client().get_serde_with_retry(
req,
&http_entry,
cache_control,
download,
)
})
.await
.map_err(|err| match err {
CachedClientError::Callback { err, .. } => err,
CachedClientError::Client { err, .. } => Error::Client(err),
})?;
// If the archive is missing the required hashes, or has since been removed, force a refresh.
let archive = Some(archive)
.filter(|archive| archive.has_digests(hashes))
.filter(|archive| archive.exists(self.build_context.cache()));
let archive = if let Some(archive) = archive {
archive
} else {
self.client
.managed(async |client| {
client
.cached_client()
.skip_cache_with_retry(
self.request(url)?,
&http_entry,
cache_control,
download,
)
.await
.map_err(|err| match err {
CachedClientError::Callback { err, .. } => err,
CachedClientError::Client { err, .. } => Error::Client(err),
})
})
.await?
};
Ok(archive)
}
/// Load a wheel from a local path.
async fn load_wheel(
&self,
path: &Path,
filename: &WheelFilename,
extension: WheelExtension,
wheel_entry: CacheEntry,
dist: &BuiltDist,
hashes: HashPolicy<'_>,
) -> Result<LocalWheel, Error> {
#[cfg(windows)]
let _lock = {
let lock_entry = wheel_entry.with_file(format!("{}.lock", filename.stem()));
lock_entry.lock().await.map_err(Error::CacheWrite)?
};
// Determine the last-modified time of the wheel.
let modified = Timestamp::from_path(path).map_err(Error::CacheRead)?;
// Attempt to read the archive pointer from the cache.
let pointer_entry = wheel_entry.with_file(format!("{}.rev", filename.cache_key()));
let pointer = LocalArchivePointer::read_from(&pointer_entry)?;
// Extract the archive from the pointer.
let archive = pointer
.filter(|pointer| pointer.is_up_to_date(modified))
.map(LocalArchivePointer::into_archive)
.filter(|archive| archive.has_digests(hashes));
// If the file is already unzipped, and the cache is up-to-date, return it.
if let Some(archive) = archive {
Ok(LocalWheel {
dist: Dist::Built(dist.clone()),
archive: self
.build_context
.cache()
.archive(&archive.id)
.into_boxed_path(),
hashes: archive.hashes,
filename: filename.clone(),
cache: CacheInfo::from_timestamp(modified),
build: None,
})
} else if hashes.is_none() {
// Otherwise, unzip the wheel.
let archive = Archive::new(
self.unzip_wheel(path, wheel_entry.path()).await?,
HashDigests::empty(),
filename.clone(),
);
// Write the archive pointer to the cache.
let pointer = LocalArchivePointer {
timestamp: modified,
archive: archive.clone(),
};
pointer.write_to(&pointer_entry).await?;
Ok(LocalWheel {
dist: Dist::Built(dist.clone()),
archive: self
.build_context
.cache()
.archive(&archive.id)
.into_boxed_path(),
hashes: archive.hashes,
filename: filename.clone(),
cache: CacheInfo::from_timestamp(modified),
build: None,
})
} else {
// If necessary, compute the hashes of the wheel.
let file = fs_err::tokio::File::open(path)
.await
.map_err(Error::CacheRead)?;
let temp_dir = tempfile::tempdir_in(self.build_context.cache().root())
.map_err(Error::CacheWrite)?;
// Create a hasher for each hash algorithm.
let algorithms = hashes.algorithms();
let mut hashers = algorithms.into_iter().map(Hasher::from).collect::<Vec<_>>();
let mut hasher = uv_extract::hash::HashReader::new(file, &mut hashers);
// Unzip the wheel to a temporary directory.
match extension {
WheelExtension::Whl => {
uv_extract::stream::unzip(&mut hasher, temp_dir.path())
.await
.map_err(|err| Error::Extract(filename.to_string(), err))?;
}
WheelExtension::WhlZst => {
uv_extract::stream::untar_zst(&mut hasher, temp_dir.path())
.await
.map_err(|err| Error::Extract(filename.to_string(), err))?;
}
}
// Exhaust the reader to compute the hash.
hasher.finish().await.map_err(Error::HashExhaustion)?;
let hashes = hashers.into_iter().map(HashDigest::from).collect();
// Persist the temporary directory to the directory store.
let id = self
.build_context
.cache()
.persist(temp_dir.keep(), wheel_entry.path())
.await
.map_err(Error::CacheWrite)?;
// Create an archive.
let archive = Archive::new(id, hashes, filename.clone());
// Write the archive pointer to the cache.
let pointer = LocalArchivePointer {
timestamp: modified,
archive: archive.clone(),
};
pointer.write_to(&pointer_entry).await?;
Ok(LocalWheel {
dist: Dist::Built(dist.clone()),
archive: self
.build_context
.cache()
.archive(&archive.id)
.into_boxed_path(),
hashes: archive.hashes,
filename: filename.clone(),
cache: CacheInfo::from_timestamp(modified),
build: None,
})
}
}
/// Unzip a wheel into the cache, returning the ID of the unzipped archive.
async fn unzip_wheel(&self, path: &Path, target: &Path) -> Result<ArchiveId, Error> {
let temp_dir = tokio::task::spawn_blocking({
let path = path.to_owned();
let root = self.build_context.cache().root().to_path_buf();
move || -> Result<TempDir, Error> {
// Unzip the wheel into a temporary directory.
let temp_dir = tempfile::tempdir_in(root).map_err(Error::CacheWrite)?;
let reader = fs_err::File::open(&path).map_err(Error::CacheRead)?;
uv_extract::unzip(reader, temp_dir.path())
.map_err(|err| Error::Extract(path.to_string_lossy().into_owned(), err))?;
Ok(temp_dir)
}
})
.await??;
// Persist the temporary directory to the directory store.
let id = self
.build_context
.cache()
.persist(temp_dir.keep(), target)
.await
.map_err(Error::CacheWrite)?;
Ok(id)
}
/// Returns a GET [`reqwest::Request`] for the given URL.
fn request(&self, url: DisplaySafeUrl) -> Result<reqwest::Request, reqwest::Error> {
self.client
.unmanaged
.uncached_client(&url)
.get(Url::from(url))
.header(
// `reqwest` defaults to accepting compressed responses.
// Specify identity encoding to get consistent .whl downloading
// behavior from servers. ref: https://github.com/pypa/pip/pull/1688
"accept-encoding",
reqwest::header::HeaderValue::from_static("identity"),
)
.build()
}
/// Return the [`ManagedClient`] used by this resolver.
pub fn client(&self) -> &ManagedClient<'a> {
&self.client
}
}
/// A wrapper around `RegistryClient` that manages a concurrency limit.
pub struct ManagedClient<'a> {
pub unmanaged: &'a RegistryClient,
control: Semaphore,
}
impl<'a> ManagedClient<'a> {
/// Create a new `ManagedClient` using the given client and concurrency limit.
fn new(client: &'a RegistryClient, concurrency: usize) -> Self {
ManagedClient {
unmanaged: client,
control: Semaphore::new(concurrency),
}
}
/// Perform a request using the client, respecting the concurrency limit.
///
/// If the concurrency limit has been reached, this method will wait until a pending
/// operation completes before executing the closure.
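///
/// # Example
///
/// A hedged sketch (assumes a `ManagedClient` named `client` and a caller-defined
/// `fetch_metadata` helper; neither is part of this API):
///
/// ```ignore
/// let metadata = client
///     .managed(|registry| async move { fetch_metadata(registry).await })
///     .await;
/// ```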
pub async fn managed<F, T>(&self, f: impl FnOnce(&'a RegistryClient) -> F) -> T
where
F: Future<Output = T>,
{
let _permit = self.control.acquire().await.unwrap();
f(self.unmanaged).await
}
/// Perform a request using a client that internally manages the concurrency limit.
///
/// The callback is passed the client and a semaphore. It must acquire the semaphore before
/// any request through the client and drop it after.
///
/// This method serves as an escape hatch for functions that may want to send multiple requests
/// in parallel.
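///
/// # Example
///
/// A sketch of the permit-per-request discipline (`fetch_one` is a hypothetical
/// caller-defined helper):
///
/// ```ignore
/// let (a, b) = client
///     .manual(|registry, semaphore| async move {
///         let one = async {
///             // Acquire a permit before the request; it's dropped when the future completes.
///             let _permit = semaphore.acquire().await.unwrap();
///             fetch_one(registry, "a").await
///         };
///         let two = async {
///             let _permit = semaphore.acquire().await.unwrap();
///             fetch_one(registry, "b").await
///         };
///         futures::join!(one, two)
///     })
///     .await;
/// ```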
pub async fn manual<F, T>(&'a self, f: impl FnOnce(&'a RegistryClient, &'a Semaphore) -> F) -> T
where
F: Future<Output = T>,
{
f(self.unmanaged, &self.control).await
}
}
/// Returns the value of the `Content-Length` header from the [`reqwest::Response`], if present.
fn content_length(response: &reqwest::Response) -> Option<u64> {
response
.headers()
.get(reqwest::header::CONTENT_LENGTH)
.and_then(|val| val.to_str().ok())
.and_then(|val| val.parse::<u64>().ok())
}
/// An asynchronous reader that reports progress as bytes are read.
struct ProgressReader<'a, R> {
reader: R,
index: usize,
reporter: &'a dyn Reporter,
}
impl<'a, R> ProgressReader<'a, R> {
/// Create a new [`ProgressReader`] that wraps another reader.
fn new(reader: R, index: usize, reporter: &'a dyn Reporter) -> Self {
Self {
reader,
index,
reporter,
}
}
}
impl<R> AsyncRead for ProgressReader<'_, R>
where
R: AsyncRead + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.as_mut().reader)
.poll_read(cx, buf)
.map_ok(|()| {
self.reporter
.on_download_progress(self.index, buf.filled().len() as u64);
})
}
}
/// A pointer to an archive in the cache, fetched from a remote archive over HTTP.
///
/// Encoded with `MsgPack`, and represented on disk by a `.http` file.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct HttpArchivePointer {
archive: Archive,
}
impl HttpArchivePointer {
/// Read an [`HttpArchivePointer`] from the cache.
pub fn read_from(path: impl AsRef<Path>) -> Result<Option<Self>, Error> {
match fs_err::File::open(path.as_ref()) {
Ok(file) => {
let data = DataWithCachePolicy::from_reader(file)?.data;
let archive = rmp_serde::from_slice::<Archive>(&data)?;
Ok(Some(Self { archive }))
}
Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(None),
Err(err) => Err(Error::CacheRead(err)),
}
}
/// Return the [`Archive`] from the pointer.
pub fn into_archive(self) -> Archive {
self.archive
}
/// Return the [`CacheInfo`] from the pointer.
pub fn to_cache_info(&self) -> CacheInfo {
CacheInfo::default()
}
/// Return the [`BuildInfo`] from the pointer.
pub fn to_build_info(&self) -> Option<BuildInfo> {
None
}
}
/// A pointer to an archive in the cache, fetched from a local path.
///
/// Encoded with `MsgPack`, and represented on disk by a `.rev` file.
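///
/// A sketch of the read-and-validate flow used by `load_wheel` (assumes a
/// `pointer_entry: CacheEntry` and the wheel's `path`, as in that method):
///
/// ```ignore
/// let modified = Timestamp::from_path(path).map_err(Error::CacheRead)?;
/// if let Some(pointer) = LocalArchivePointer::read_from(&pointer_entry)? {
///     if pointer.is_up_to_date(modified) {
///         let archive = pointer.into_archive(); // reuse the cached unzip
///     }
/// }
/// ```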
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LocalArchivePointer {
timestamp: Timestamp,
archive: Archive,
}
impl LocalArchivePointer {
/// Read a [`LocalArchivePointer`] from the cache.
pub fn read_from(path: impl AsRef<Path>) -> Result<Option<Self>, Error> {
match fs_err::read(path) {
Ok(cached) => Ok(Some(rmp_serde::from_slice::<Self>(&cached)?)),
Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(None),
Err(err) => Err(Error::CacheRead(err)),
}
}
/// Write a [`LocalArchivePointer`] to the cache.
pub async fn write_to(&self, entry: &CacheEntry) -> Result<(), Error> {
write_atomic(entry.path(), rmp_serde::to_vec(&self)?)
.await
.map_err(Error::CacheWrite)
}
/// Returns `true` if the archive is up-to-date with the given modified timestamp.
pub fn is_up_to_date(&self, modified: Timestamp) -> bool {
self.timestamp == modified
}
/// Return the [`Archive`] from the pointer.
pub fn into_archive(self) -> Archive {
self.archive
}
/// Return the [`CacheInfo`] from the pointer.
pub fn to_cache_info(&self) -> CacheInfo {
CacheInfo::from_timestamp(self.timestamp)
}
/// Return the [`BuildInfo`] from the pointer.
pub fn to_build_info(&self) -> Option<BuildInfo> {
None
}
}
#[derive(Debug, Clone)]
struct WheelTarget {
/// The URL from which the wheel can be downloaded.
url: DisplaySafeUrl,
/// The expected extension of the wheel file.
extension: WheelExtension,
/// The expected size of the wheel file, if known.
size: Option<u64>,
}
impl TryFrom<&File> for WheelTarget {
type Error = ToUrlError;
/// Determine the [`WheelTarget`] from a [`File`].
fn try_from(file: &File) -> Result<Self, Self::Error> {
let url = file.url.to_url()?;
if let Some(zstd) = file.zstd.as_ref() {
Ok(Self {
url: add_tar_zst_extension(url),
extension: WheelExtension::WhlZst,
size: zstd.size,
})
} else {
Ok(Self {
url,
extension: WheelExtension::Whl,
size: file.size,
})
}
}
}
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum WheelExtension {
/// A `.whl` file.
Whl,
/// A `.whl.tar.zst` file.
WhlZst,
}
/// Add `.tar.zst` to the end of the URL path, if it isn't already present.
#[must_use]
fn add_tar_zst_extension(mut url: DisplaySafeUrl) -> DisplaySafeUrl {
let mut path = url.path().to_string();
if !path.ends_with(".tar.zst") {
path.push_str(".tar.zst");
}
url.set_path(&path);
url
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_add_tar_zst_extension() {
let url =
DisplaySafeUrl::parse("https://files.pythonhosted.org/flask-3.1.0-py3-none-any.whl")
.unwrap();
assert_eq!(
add_tar_zst_extension(url).as_str(),
"https://files.pythonhosted.org/flask-3.1.0-py3-none-any.whl.tar.zst"
);
let url = DisplaySafeUrl::parse(
"https://files.pythonhosted.org/flask-3.1.0-py3-none-any.whl.tar.zst",
)
.unwrap();
assert_eq!(
add_tar_zst_extension(url).as_str(),
"https://files.pythonhosted.org/flask-3.1.0-py3-none-any.whl.tar.zst"
);
let url = DisplaySafeUrl::parse(
"https://files.pythonhosted.org/flask-3.1.0%2Bcu124-py3-none-any.whl",
)
.unwrap();
assert_eq!(
add_tar_zst_extension(url).as_str(),
"https://files.pythonhosted.org/flask-3.1.0%2Bcu124-py3-none-any.whl.tar.zst"
);
}
}