mirror of https://github.com/astral-sh/uv
385 lines
14 KiB
Rust
385 lines
14 KiB
Rust
use std::future::Future;
|
|
use std::time::SystemTime;
|
|
|
|
use futures::FutureExt;
|
|
use http::request::Parts;
|
|
use http_cache_semantics::{AfterResponse, BeforeRequest, CachePolicy};
|
|
use reqwest::{Request, Response};
|
|
use reqwest_middleware::ClientWithMiddleware;
|
|
use serde::de::DeserializeOwned;
|
|
use serde::{Deserialize, Serialize};
|
|
use tracing::{debug, info_span, instrument, trace, warn, Instrument};
|
|
use url::Url;
|
|
|
|
use puffin_cache::{CacheEntry, Freshness};
|
|
use puffin_fs::write_atomic;
|
|
|
|
use crate::{cache_headers::CacheHeaders, Error, ErrorKind};
|
|
|
|
/// Either a cached client error or a (user specified) error from the callback
|
|
#[derive(Debug)]
|
|
pub enum CachedClientError<CallbackError> {
|
|
Client(Error),
|
|
Callback(CallbackError),
|
|
}
|
|
|
|
impl<CallbackError> From<Error> for CachedClientError<CallbackError> {
|
|
fn from(error: Error) -> Self {
|
|
CachedClientError::Client(error)
|
|
}
|
|
}
|
|
|
|
impl<CallbackError> From<ErrorKind> for CachedClientError<CallbackError> {
|
|
fn from(error: ErrorKind) -> Self {
|
|
CachedClientError::Client(error.into())
|
|
}
|
|
}
|
|
|
|
impl<E: Into<Error>> From<CachedClientError<E>> for Error {
|
|
fn from(error: CachedClientError<E>) -> Error {
|
|
match error {
|
|
CachedClientError::Client(error) => error,
|
|
CachedClientError::Callback(error) => error.into(),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
enum CachedResponse<Payload: Serialize> {
|
|
/// The cached response is fresh without an HTTP request (e.g. immutable)
|
|
FreshCache(Payload),
|
|
/// The cached response is fresh after an HTTP request (e.g. 304 not modified)
|
|
NotModified(DataWithCachePolicy<Payload>),
|
|
/// There was no prior cached response or the cache was outdated
|
|
///
|
|
/// The cache policy is `None` if it isn't storable
|
|
ModifiedOrNew(Response, Option<Box<CachePolicy>>),
|
|
}
|
|
|
|
/// Serialize the actual payload together with its caching information.
|
|
#[derive(Debug, Deserialize, Serialize)]
|
|
pub struct DataWithCachePolicy<Payload: Serialize> {
|
|
pub data: Payload,
|
|
/// Whether the response should be considered immutable.
|
|
immutable: bool,
|
|
/// The [`CachePolicy`] is used to determine if the response is fresh or stale.
|
|
/// The policy is large (448 bytes at time of writing), so we reduce the stack size by
|
|
/// boxing it.
|
|
cache_policy: Box<CachePolicy>,
|
|
}
|
|
|
|
/// Custom caching layer over [`reqwest::Client`] using `http-cache-semantics`.
|
|
///
|
|
/// The implementation takes inspiration from the `http-cache` crate, but adds support for running
|
|
/// an async callback on the response before caching. We use this to e.g. store a
|
|
/// parsed version of the wheel metadata and for our remote zip reader. In the latter case, we want
|
|
/// to read a single file from a remote zip using range requests (so we don't have to download the
|
|
/// entire file). We send a HEAD request in the caching layer to check if the remote file has
|
|
/// changed (and if range requests are supported), and in the callback we make the actual range
|
|
/// requests if required.
|
|
///
|
|
/// Unlike `http-cache`, all outputs must be serde-able. Currently everything is json, but we can
|
|
/// transparently switch to a faster/smaller format.
|
|
///
|
|
/// Again unlike `http-cache`, the caller gets full control over the cache key with the assumption
|
|
/// that it's a file.
|
|
#[derive(Debug, Clone)]
|
|
pub struct CachedClient(ClientWithMiddleware);
|
|
|
|
impl CachedClient {
|
|
pub fn new(client: ClientWithMiddleware) -> Self {
|
|
Self(client)
|
|
}
|
|
|
|
/// The middleware is the retry strategy
|
|
pub fn uncached(&self) -> ClientWithMiddleware {
|
|
self.0.clone()
|
|
}
|
|
|
|
/// Make a cached request with a custom response transformation
|
|
///
|
|
/// If a new response was received (no prior cached response or modified on the remote), the
|
|
/// response is passed through `response_callback` and only the result is cached and returned.
|
|
/// The `response_callback` is allowed to make subsequent requests, e.g. through the uncached
|
|
/// client.
|
|
#[instrument(skip_all)]
|
|
pub async fn get_cached_with_callback<
|
|
Payload: Serialize + DeserializeOwned + Send + 'static,
|
|
CallBackError,
|
|
Callback,
|
|
CallbackReturn,
|
|
>(
|
|
&self,
|
|
req: Request,
|
|
cache_entry: &CacheEntry,
|
|
cache_control: CacheControl,
|
|
response_callback: Callback,
|
|
) -> Result<Payload, CachedClientError<CallBackError>>
|
|
where
|
|
Callback: FnOnce(Response) -> CallbackReturn,
|
|
CallbackReturn: Future<Output = Result<Payload, CallBackError>> + Send,
|
|
{
|
|
let cached = Self::read_cache(cache_entry).await;
|
|
|
|
let cached_response = self.send_cached(req, cache_control, cached).boxed().await?;
|
|
|
|
let write_cache = info_span!("write_cache", file = %cache_entry.path().display());
|
|
match cached_response {
|
|
CachedResponse::FreshCache(data) => Ok(data),
|
|
CachedResponse::NotModified(data_with_cache_policy) => {
|
|
async {
|
|
let data =
|
|
rmp_serde::to_vec(&data_with_cache_policy).map_err(ErrorKind::Encode)?;
|
|
write_atomic(cache_entry.path(), data)
|
|
.await
|
|
.map_err(ErrorKind::CacheWrite)?;
|
|
Ok(data_with_cache_policy.data)
|
|
}
|
|
.instrument(write_cache)
|
|
.await
|
|
}
|
|
CachedResponse::ModifiedOrNew(res, cache_policy) => {
|
|
let headers = CacheHeaders::from_response(res.headers().get_all("cache-control"));
|
|
let immutable = headers.is_immutable();
|
|
|
|
let data = response_callback(res)
|
|
.boxed()
|
|
.await
|
|
.map_err(|err| CachedClientError::Callback(err))?;
|
|
if let Some(cache_policy) = cache_policy {
|
|
let data_with_cache_policy = DataWithCachePolicy {
|
|
data,
|
|
immutable,
|
|
cache_policy,
|
|
};
|
|
async {
|
|
fs_err::tokio::create_dir_all(cache_entry.dir())
|
|
.await
|
|
.map_err(ErrorKind::CacheWrite)?;
|
|
let data = rmp_serde::to_vec(&data_with_cache_policy)
|
|
.map_err(ErrorKind::Encode)?;
|
|
write_atomic(cache_entry.path(), data)
|
|
.await
|
|
.map_err(ErrorKind::CacheWrite)?;
|
|
Ok(data_with_cache_policy.data)
|
|
}
|
|
.instrument(write_cache)
|
|
.await
|
|
} else {
|
|
Ok(data)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn read_cache<Payload: Serialize + DeserializeOwned + Send + 'static>(
|
|
cache_entry: &CacheEntry,
|
|
) -> Option<DataWithCachePolicy<Payload>> {
|
|
let read_span = info_span!("read_cache", file = %cache_entry.path().display());
|
|
let read_result = fs_err::tokio::read(cache_entry.path())
|
|
.instrument(read_span)
|
|
.await;
|
|
|
|
if let Ok(cached) = read_result {
|
|
let parse_span = info_span!(
|
|
"parse_cache",
|
|
path = %cache_entry.path().display()
|
|
);
|
|
let parse_result = tokio::task::spawn_blocking(move || {
|
|
parse_span
|
|
.in_scope(|| rmp_serde::from_slice::<DataWithCachePolicy<Payload>>(&cached))
|
|
})
|
|
.await
|
|
.expect("Tokio executor failed, was there a panic?");
|
|
match parse_result {
|
|
Ok(data) => Some(data),
|
|
Err(err) => {
|
|
warn!(
|
|
"Broken cache entry at {}, removing: {err}",
|
|
cache_entry.path().display()
|
|
);
|
|
let _ = fs_err::tokio::remove_file(&cache_entry.path()).await;
|
|
None
|
|
}
|
|
}
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
|
|
/// `http-cache-semantics` to `reqwest` wrapper
|
|
async fn send_cached<T: Serialize + DeserializeOwned>(
|
|
&self,
|
|
mut req: Request,
|
|
cache_control: CacheControl,
|
|
cached: Option<DataWithCachePolicy<T>>,
|
|
) -> Result<CachedResponse<T>, Error> {
|
|
let url = req.url().clone();
|
|
let cached_response = if let Some(cached) = cached {
|
|
// Avoid sending revalidation requests for immutable responses.
|
|
if cached.immutable && !cached.cache_policy.is_stale(SystemTime::now()) {
|
|
debug!("Found immutable response for: {url}");
|
|
return Ok(CachedResponse::FreshCache(cached.data));
|
|
}
|
|
|
|
// Apply the cache control header, if necessary.
|
|
match cache_control {
|
|
CacheControl::None => {}
|
|
CacheControl::MustRevalidate => {
|
|
req.headers_mut().insert(
|
|
http::header::CACHE_CONTROL,
|
|
http::HeaderValue::from_static("max-age=0, must-revalidate"),
|
|
);
|
|
}
|
|
}
|
|
|
|
match cached
|
|
.cache_policy
|
|
.before_request(&RequestLikeReqwest(&req), SystemTime::now())
|
|
{
|
|
BeforeRequest::Fresh(_) => {
|
|
debug!("Found fresh response for: {url}");
|
|
CachedResponse::FreshCache(cached.data)
|
|
}
|
|
BeforeRequest::Stale { request, matches } => {
|
|
self.send_cached_handle_stale(req, url, cached, &request, matches)
|
|
.await?
|
|
}
|
|
}
|
|
} else {
|
|
debug!("No cache entry for: {url}");
|
|
self.fresh_request(req).await?
|
|
};
|
|
Ok(cached_response)
|
|
}
|
|
|
|
async fn send_cached_handle_stale<T: Serialize + DeserializeOwned>(
|
|
&self,
|
|
mut req: Request,
|
|
url: Url,
|
|
cached: DataWithCachePolicy<T>,
|
|
request: &Parts,
|
|
matches: bool,
|
|
) -> Result<CachedResponse<T>, Error> {
|
|
if !matches {
|
|
// This shouldn't happen; if it does, we'll override the cache.
|
|
warn!("Cached request doesn't match current request for: {url}");
|
|
return self.fresh_request(req).await;
|
|
}
|
|
|
|
debug!("Sending revalidation request for: {url}");
|
|
for header in &request.headers {
|
|
req.headers_mut().insert(header.0.clone(), header.1.clone());
|
|
}
|
|
let res = self
|
|
.0
|
|
.execute(req.try_clone().expect("streaming requests not supported"))
|
|
.instrument(info_span!("revalidation_request", url = url.as_str()))
|
|
.await
|
|
.map_err(ErrorKind::RequestMiddlewareError)?
|
|
.error_for_status()
|
|
.map_err(ErrorKind::RequestError)?;
|
|
let after_response = cached.cache_policy.after_response(
|
|
&RequestLikeReqwest(&req),
|
|
&ResponseLikeReqwest(&res),
|
|
SystemTime::now(),
|
|
);
|
|
match after_response {
|
|
AfterResponse::NotModified(new_policy, _parts) => {
|
|
debug!("Found not-modified response for: {url}");
|
|
let headers = CacheHeaders::from_response(res.headers().get_all("cache-control"));
|
|
let immutable = headers.is_immutable();
|
|
Ok(CachedResponse::NotModified(DataWithCachePolicy {
|
|
data: cached.data,
|
|
immutable,
|
|
cache_policy: Box::new(new_policy),
|
|
}))
|
|
}
|
|
AfterResponse::Modified(new_policy, _parts) => {
|
|
debug!("Found modified response for: {url}");
|
|
Ok(CachedResponse::ModifiedOrNew(
|
|
res,
|
|
new_policy.is_storable().then(|| Box::new(new_policy)),
|
|
))
|
|
}
|
|
}
|
|
}
|
|
|
|
#[instrument(skip_all, fields(url = req.url().as_str()))]
|
|
async fn fresh_request<T: Serialize>(&self, req: Request) -> Result<CachedResponse<T>, Error> {
|
|
trace!("{} {}", req.method(), req.url());
|
|
let res = self
|
|
.0
|
|
.execute(req.try_clone().expect("streaming requests not supported"))
|
|
.await
|
|
.map_err(ErrorKind::RequestMiddlewareError)?
|
|
.error_for_status()
|
|
.map_err(ErrorKind::RequestError)?;
|
|
let cache_policy = CachePolicy::new(&RequestLikeReqwest(&req), &ResponseLikeReqwest(&res));
|
|
Ok(CachedResponse::ModifiedOrNew(
|
|
res,
|
|
cache_policy.is_storable().then(|| Box::new(cache_policy)),
|
|
))
|
|
}
|
|
}
|
|
|
|
/// How the cached client should treat an existing cache entry when sending a request.
#[derive(Debug, Clone, Copy)]
pub enum CacheControl {
    /// Follow whatever the response's `cache-control` header says.
    None,
    /// Force revalidation by sending `max-age=0, must-revalidate` with the request.
    MustRevalidate,
}
|
|
|
|
impl From<Freshness> for CacheControl {
|
|
fn from(value: Freshness) -> Self {
|
|
match value {
|
|
Freshness::Fresh => CacheControl::None,
|
|
Freshness::Stale => CacheControl::MustRevalidate,
|
|
Freshness::Missing => CacheControl::None,
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
struct RequestLikeReqwest<'a>(&'a Request);
|
|
|
|
impl<'a> http_cache_semantics::RequestLike for RequestLikeReqwest<'a> {
|
|
fn uri(&self) -> http::uri::Uri {
|
|
// This converts from a url::Url (as returned by reqwest::Request::url)
|
|
// to a http::uri::Uri. The conversion requires parsing, but this is
|
|
// only called ~once per HTTP request. We can afford it.
|
|
self.0
|
|
.url()
|
|
.as_str()
|
|
.parse()
|
|
.expect("reqwest::Request::url always returns a valid URL")
|
|
}
|
|
fn is_same_uri(&self, other: &http::uri::Uri) -> bool {
|
|
// At time of writing, I saw no way to cheaply compare a http::uri::Uri
|
|
// with a url::Url. We can at least avoid parsing anything, and
|
|
// Url::as_str() is free. In practice though, this routine is called
|
|
// ~once per HTTP request. We can afford it. (And it looks like
|
|
// http::uri::Uri's PartialEq<str> implementation has been tuned.)
|
|
self.0.url().as_str() == *other
|
|
}
|
|
fn method(&self) -> &http::method::Method {
|
|
self.0.method()
|
|
}
|
|
fn headers(&self) -> &http::header::HeaderMap {
|
|
self.0.headers()
|
|
}
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
struct ResponseLikeReqwest<'a>(&'a Response);
|
|
|
|
impl<'a> http_cache_semantics::ResponseLike for ResponseLikeReqwest<'a> {
|
|
fn status(&self) -> http::status::StatusCode {
|
|
self.0.status()
|
|
}
|
|
fn headers(&self) -> &http::header::HeaderMap {
|
|
self.0.headers()
|
|
}
|
|
}
|