use std::future::Future;
use std::time::SystemTime;

use http_cache_semantics::{AfterResponse, BeforeRequest, CachePolicy};
use reqwest::{Request, Response};
use reqwest_middleware::ClientWithMiddleware;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use tracing::{debug, trace, warn};

use puffin_cache::CacheEntry;
use puffin_fs::write_atomic;

/// Either a cached client error or a (user specified) error from the callback
pub enum CachedClientError<CallbackError> {
    /// An error raised by the HTTP/caching layer itself.
    Client(crate::Error),
    /// An error returned by the user-supplied response callback.
    Callback(CallbackError),
}

impl<CallbackError> From<crate::Error> for CachedClientError<CallbackError> {
    fn from(error: crate::Error) -> Self {
        CachedClientError::Client(error)
    }
}

/// When the callback's error type is itself `crate::Error`, both variants
/// collapse back into a plain `crate::Error`.
impl From<CachedClientError<crate::Error>> for crate::Error {
    fn from(error: CachedClientError<crate::Error>) -> crate::Error {
        match error {
            CachedClientError::Client(error) => error,
            CachedClientError::Callback(error) => error,
        }
    }
}

#[derive(Debug)]
enum CachedResponse<Payload: Serialize> {
    /// The cached response is fresh without an HTTP request (e.g. immutable)
    FreshCache(Payload),
    /// The cached response is fresh after an HTTP request (e.g. 304 not modified)
    NotModified(DataWithCachePolicy<Payload>),
    /// There was no prior cached response or the cache was outdated
    ///
    /// The cache policy is `None` if it isn't storable
    ModifiedOrNew(Response, Option<CachePolicy>),
}

/// Serialize the actual payload together with its caching information
#[derive(Debug, Deserialize, Serialize)]
pub struct DataWithCachePolicy<Payload: Serialize> {
    pub data: Payload,
    cache_policy: CachePolicy,
}

/// Custom caching layer over [`reqwest::Client`] using `http-cache-semantics`.
///
/// The implementation takes inspiration from the `http-cache` crate, but adds support for running
/// an async callback on the response before caching. We use this to e.g. store a
/// parsed version of the wheel metadata and for our remote zip reader. In the latter case, we want
/// to read a single file from a remote zip using range requests (so we don't have to download the
/// entire file).
/// We send a HEAD request in the caching layer to check if the remote file has
/// changed (and if range requests are supported), and in the callback we make the actual range
/// requests if required.
///
/// Unlike `http-cache`, all outputs must be serde-able. Currently everything is json, but we can
/// transparently switch to a faster/smaller format.
///
/// Again unlike `http-cache`, the caller gets full control over the cache key with the assumption
/// that it's a file.
#[derive(Debug, Clone)]
pub struct CachedClient(ClientWithMiddleware);

impl CachedClient {
    pub fn new(client: ClientWithMiddleware) -> Self {
        Self(client)
    }

    /// The middleware is the retry strategy
    pub fn uncached(&self) -> ClientWithMiddleware {
        self.0.clone()
    }

    /// Make a cached request with a custom response transformation
    ///
    /// If a new response was received (no prior cached response or modified on the remote), the
    /// response is passed through `response_callback` and only the result is cached and returned.
    /// The `response_callback` is allowed to make subsequent requests, e.g. through the uncached
    /// client.
pub async fn get_cached_with_callback< Payload: Serialize + DeserializeOwned, CallBackError, Callback, CallbackReturn, >( &self, req: Request, cache_entry: &CacheEntry, response_callback: Callback, ) -> Result> where Callback: FnOnce(Response) -> CallbackReturn, CallbackReturn: Future>, { let cached = if let Ok(cached) = fs_err::tokio::read(cache_entry.path()).await { match serde_json::from_slice::>(&cached) { Ok(data) => Some(data), Err(err) => { warn!( "Broken cache entry at {}, removing: {err}", cache_entry.path().display() ); let _ = fs_err::tokio::remove_file(&cache_entry.path()).await; None } } } else { None }; let cached_response = self.send_cached(req, cached).await?; match cached_response { CachedResponse::FreshCache(data) => Ok(data), CachedResponse::NotModified(data_with_cache_policy) => { write_atomic( cache_entry.path(), serde_json::to_vec(&data_with_cache_policy).map_err(crate::Error::from)?, ) .await .map_err(crate::Error::CacheWrite)?; Ok(data_with_cache_policy.data) } CachedResponse::ModifiedOrNew(res, cache_policy) => { let data = response_callback(res) .await .map_err(|err| CachedClientError::Callback(err))?; if let Some(cache_policy) = cache_policy { let data_with_cache_policy = DataWithCachePolicy { data, cache_policy }; fs_err::tokio::create_dir_all(&cache_entry.dir) .await .map_err(crate::Error::CacheWrite)?; let data = serde_json::to_vec(&data_with_cache_policy).map_err(crate::Error::from)?; write_atomic(cache_entry.path(), data) .await .map_err(crate::Error::CacheWrite)?; Ok(data_with_cache_policy.data) } else { Ok(data) } } } } /// `http-cache-semantics` to `reqwest` wrapper async fn send_cached( &self, mut req: Request, cached: Option>, ) -> Result, crate::Error> { // The converted types are from the specific `reqwest` types to the more generic `http` // types let mut converted_req = http::Request::try_from( req.try_clone() .expect("You can't use streaming request bodies with this function"), )?; let url = req.url().clone(); let 
cached_response = if let Some(cached) = cached { match cached .cache_policy .before_request(&converted_req, SystemTime::now()) { BeforeRequest::Fresh(_) => { debug!("Found fresh response for: {url}"); CachedResponse::FreshCache(cached.data) } BeforeRequest::Stale { request, matches } => { if !matches { // This shouldn't happen; if it does, we'll override the cache. warn!("Cached request doesn't match current request for: {url}"); return self.fresh_request(req, converted_req).await; } debug!("Sending revalidation request for: {url}"); for header in &request.headers { req.headers_mut().insert(header.0.clone(), header.1.clone()); converted_req .headers_mut() .insert(header.0.clone(), header.1.clone()); } let res = self.0.execute(req).await?.error_for_status()?; let mut converted_res = http::Response::new(()); *converted_res.status_mut() = res.status(); for header in res.headers() { converted_res.headers_mut().insert( http::HeaderName::from(header.0), http::HeaderValue::from(header.1), ); } let after_response = cached.cache_policy.after_response( &converted_req, &converted_res, SystemTime::now(), ); match after_response { AfterResponse::NotModified(new_policy, _parts) => { debug!("Found not-modified response for: {url}"); CachedResponse::NotModified(DataWithCachePolicy { data: cached.data, cache_policy: new_policy, }) } AfterResponse::Modified(new_policy, _parts) => { debug!("Found modified response for: {url}"); CachedResponse::ModifiedOrNew( res, new_policy.is_storable().then_some(new_policy), ) } } } } } else { debug!("No cache entry for: {url}"); self.fresh_request(req, converted_req).await? 
}; Ok(cached_response) } async fn fresh_request( &self, req: Request, converted_req: http::Request, ) -> Result, crate::Error> { trace!("{} {}", req.method(), req.url()); let res = self.0.execute(req).await?.error_for_status()?; let mut converted_res = http::Response::new(()); *converted_res.status_mut() = res.status(); for header in res.headers() { converted_res.headers_mut().insert( http::HeaderName::from(header.0), http::HeaderValue::from(header.1), ); } let cache_policy = CachePolicy::new(&converted_req.into_parts().0, &converted_res.into_parts().0); Ok(CachedResponse::ModifiedOrNew( res, cache_policy.is_storable().then_some(cache_policy), )) } }