diff --git a/crates/puffin-client/src/cached_client.rs b/crates/puffin-client/src/cached_client.rs index 5c06c4dc0..f98e3fdbb 100644 --- a/crates/puffin-client/src/cached_client.rs +++ b/crates/puffin-client/src/cached_client.rs @@ -2,12 +2,14 @@ use std::future::Future; use std::time::SystemTime; use futures::FutureExt; +use http::request::Parts; use http_cache_semantics::{AfterResponse, BeforeRequest, CachePolicy}; -use reqwest::{Request, Response}; +use reqwest::{Body, Request, Response}; use reqwest_middleware::ClientWithMiddleware; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use tracing::{debug, info_span, instrument, trace, warn, Instrument}; +use url::Url; use puffin_cache::{CacheEntry, Freshness}; use puffin_fs::write_atomic; @@ -115,33 +117,9 @@ impl CachedClient { ) -> Result> where Callback: FnOnce(Response) -> CallbackReturn, - CallbackReturn: Future>, + CallbackReturn: Future> + Send, { - let read_span = info_span!("read_cache", file = %cache_entry.path().display()); - let read_result = fs_err::tokio::read(cache_entry.path()) - .instrument(read_span) - .await; - let cached = if let Ok(cached) = read_result { - let parse_span = info_span!( - "parse_cache", - path = %cache_entry.path().display() - ); - let parse_result = parse_span - .in_scope(|| rmp_serde::from_slice::>(&cached)); - match parse_result { - Ok(data) => Some(data), - Err(err) => { - warn!( - "Broken cache entry at {}, removing: {err}", - cache_entry.path().display() - ); - let _ = fs_err::tokio::remove_file(&cache_entry.path()).await; - None - } - } - } else { - None - }; + let cached = Self::read_cache(cache_entry).await; let cached_response = self.send_cached(req, cache_control, cached).boxed().await?; @@ -165,6 +143,7 @@ impl CachedClient { let immutable = headers.is_immutable(); let data = response_callback(res) + .boxed() .await .map_err(|err| CachedClientError::Callback(err))?; if let Some(cache_policy) = cache_policy { @@ -193,10 +172,41 @@ impl CachedClient { } } + async fn read_cache( + cache_entry: &CacheEntry, + ) -> Option> { + let read_span = info_span!("read_cache", file = %cache_entry.path().display()); + let read_result = fs_err::tokio::read(cache_entry.path()) + .instrument(read_span) + .await; + + if let Ok(cached) = read_result { + let parse_span = info_span!( + "parse_cache", + path = %cache_entry.path().display() + ); + let parse_result = parse_span + .in_scope(|| rmp_serde::from_slice::>(&cached)); + match parse_result { + Ok(data) => Some(data), + Err(err) => { + warn!( + "Broken cache entry at {}, removing: {err}", + cache_entry.path().display() + ); + let _ = fs_err::tokio::remove_file(&cache_entry.path()).await; + None + } + } + } else { + None + } + } + /// `http-cache-semantics` to `reqwest` wrapper async fn send_cached( &self, - mut req: Request, + req: Request, cache_control: CacheControl, cached: Option>, ) -> Result, Error> { @@ -236,60 +246,15 @@ impl CachedClient { CachedResponse::FreshCache(cached.data) } BeforeRequest::Stale { request, matches } => { - if !matches { - // This shouldn't happen; if it does, we'll override the cache. - warn!("Cached request doesn't match current request for: {url}"); - return self.fresh_request(req, converted_req).await; - } - - debug!("Sending revalidation request for: {url}"); - for header in &request.headers { - req.headers_mut().insert(header.0.clone(), header.1.clone()); - converted_req - .headers_mut() - .insert(header.0.clone(), header.1.clone()); - } - let res = self - .0 - .execute(req) - .instrument(info_span!("revalidation_request", url = url.as_str())) - .await - .map_err(ErrorKind::RequestMiddlewareError)? - .error_for_status() - .map_err(ErrorKind::RequestError)?; - let mut converted_res = http::Response::new(()); - *converted_res.status_mut() = res.status(); - for header in res.headers() { - converted_res.headers_mut().insert( - http::HeaderName::from(header.0), - http::HeaderValue::from(header.1), - ); - } - let after_response = cached.cache_policy.after_response( - &converted_req, - &converted_res, - SystemTime::now(), - ); - match after_response { - AfterResponse::NotModified(new_policy, _parts) => { - debug!("Found not-modified response for: {url}"); - let headers = - CacheHeaders::from_response(res.headers().get_all("cache-control")); - let immutable = headers.is_immutable(); - CachedResponse::NotModified(DataWithCachePolicy { - data: cached.data, - immutable, - cache_policy: Box::new(new_policy), - }) - } - AfterResponse::Modified(new_policy, _parts) => { - debug!("Found modified response for: {url}"); - CachedResponse::ModifiedOrNew( - res, - new_policy.is_storable().then(|| Box::new(new_policy)), - ) - } - } + self.send_cached_handle_stale( + req, + converted_req, + url, + cached, + &request, + matches, + ) + .await? } } } else { @@ -299,6 +264,69 @@ impl CachedClient { Ok(cached_response) } + async fn send_cached_handle_stale( + &self, + mut req: Request, + mut converted_req: http::Request, + url: Url, + cached: DataWithCachePolicy, + request: &Parts, + matches: bool, + ) -> Result, Error> { + if !matches { + // This shouldn't happen; if it does, we'll override the cache. + warn!("Cached request doesn't match current request for: {url}"); + return self.fresh_request(req, converted_req).await; + } + + debug!("Sending revalidation request for: {url}"); + for header in &request.headers { + req.headers_mut().insert(header.0.clone(), header.1.clone()); + converted_req + .headers_mut() + .insert(header.0.clone(), header.1.clone()); + } + let res = self + .0 + .execute(req) + .instrument(info_span!("revalidation_request", url = url.as_str())) + .await + .map_err(ErrorKind::RequestMiddlewareError)? + .error_for_status() + .map_err(ErrorKind::RequestError)?; + let mut converted_res = http::Response::new(()); + *converted_res.status_mut() = res.status(); + for header in res.headers() { + converted_res.headers_mut().insert( + http::HeaderName::from(header.0), + http::HeaderValue::from(header.1), + ); + } + let after_response = + cached + .cache_policy + .after_response(&converted_req, &converted_res, SystemTime::now()); + match after_response { + AfterResponse::NotModified(new_policy, _parts) => { + debug!("Found not-modified response for: {url}"); + let headers = CacheHeaders::from_response(res.headers().get_all("cache-control")); + let immutable = headers.is_immutable(); + Ok(CachedResponse::NotModified(DataWithCachePolicy { + data: cached.data, + immutable, + cache_policy: Box::new(new_policy), + })) + } + AfterResponse::Modified(new_policy, _parts) => { + debug!("Found modified response for: {url}"); + Ok(CachedResponse::ModifiedOrNew( + res, + new_policy.is_storable().then(|| Box::new(new_policy)), + )) + } + } + } + #[instrument(skip_all, fields(url = req.url().as_str()))] async fn fresh_request( &self,