diff --git a/crates/puffin-client/src/cached_client.rs b/crates/puffin-client/src/cached_client.rs index 4c82ae31b..c853f52ac 100644 --- a/crates/puffin-client/src/cached_client.rs +++ b/crates/puffin-client/src/cached_client.rs @@ -1,3 +1,6 @@ +#![allow(warnings)] + +use std::fmt::Debug; use std::future::Future; use std::time::SystemTime; @@ -14,7 +17,62 @@ use url::Url; use puffin_cache::{CacheEntry, Freshness}; use puffin_fs::write_atomic; -use crate::{cache_headers::CacheHeaders, Error, ErrorKind}; +use crate::{cache_headers::CacheHeaders, rkyvutil::OwnedArchive, Error, ErrorKind}; + +pub trait Cacheable: Sized + Send { + type Target; + + fn from_bytes(bytes: Vec) -> Result; + fn to_bytes(&self) -> Result, crate::Error>; + fn into_target(self) -> Self::Target; +} + +/// A wrapper type that makes anything with Serde support automatically +/// implement Cacheable. +#[derive(Debug, Deserialize, Serialize)] +#[serde(transparent)] +pub struct SerdeCacheable { + inner: T, +} + +impl Cacheable for SerdeCacheable { + type Target = T; + + fn from_bytes(bytes: Vec) -> Result { + Ok(rmp_serde::from_slice::(&bytes).map_err(ErrorKind::Decode)?) + } + + fn to_bytes(&self) -> Result, Error> { + Ok(rmp_serde::to_vec(&self.inner).map_err(ErrorKind::Encode)?) + } + + fn into_target(self) -> Self::Target { + self.inner + } +} + +impl Cacheable for OwnedArchive +where + A: rkyv::Archive + rkyv::Serialize> + Send, + A::Archived: for<'a> rkyv::CheckBytes> + + rkyv::Deserialize, +{ + type Target = OwnedArchive; + + fn from_bytes(bytes: Vec) -> Result, Error> { + let mut aligned = rkyv::util::AlignedVec::new(); + aligned.extend_from_slice(&bytes); + OwnedArchive::new(aligned) + } + + fn to_bytes(&self) -> Result, Error> { + Ok(OwnedArchive::as_bytes(self).to_vec()) + } + + fn into_target(self) -> Self::Target { + self + } +} /// Either a cached client error or a (user specified) error from the callback #[derive(Debug)] @@ -45,27 +103,126 @@ impl> From> for Error { } #[derive(Debug)] -enum CachedResponse { +enum CachedResponse { /// The cached response is fresh without an HTTP request (e.g. immutable) - FreshCache(Payload), + FreshCache(Vec), /// The cached response is fresh after an HTTP request (e.g. 304 not modified) - NotModified(DataWithCachePolicy), + NotModified(DataWithCachePolicy), /// There was no prior cached response or the cache was outdated /// /// The cache policy is `None` if it isn't storable - ModifiedOrNew(Response, Option>), + // ModifiedOrNew(Response, Option>), + ModifiedOrNew(Response, Option>), } /// Serialize the actual payload together with its caching information. #[derive(Debug, Deserialize, Serialize)] -pub struct DataWithCachePolicy { - pub data: Payload, +pub struct DataWithCachePolicy { + pub data: Vec, /// Whether the response should be considered immutable. immutable: bool, /// The [`CachePolicy`] is used to determine if the response is fresh or stale. /// The policy is large (448 bytes at time of writing), so we reduce the stack size by /// boxing it. - cache_policy: Box, + cache_policy: Box, +} + +#[derive(Debug)] +struct CachePolicyStub(Option); + +impl CachePolicyStub { + fn is_stale(&self, time: SystemTime) -> bool { + self.0.as_ref().map_or(false, |p| p.is_stale(time)) + } + + fn is_storable(&self) -> bool { + self.0.as_ref().map_or(false, |p| p.is_storable()) + } + + fn before_request( + &self, + req: &Req, + now: SystemTime, + ) -> BeforeRequestStub { + match self.0.as_ref() { + None => { + let dummy = http::Response::new(()).into_parts().0; + BeforeRequestStub::Fresh(dummy) + } + Some(p) => p.before_request(req, now).into(), + } + } + + fn after_response< + Req: http_cache_semantics::RequestLike, + Resp: http_cache_semantics::ResponseLike, + >( + &self, + req: &Req, + resp: &Resp, + time: SystemTime, + ) -> AfterResponseStub { + match self.0.as_ref() { + None => unreachable!("oops"), + Some(p) => p.after_response(req, resp, time).into(), + } + } +} + +impl<'de> Deserialize<'de> for CachePolicyStub { + fn deserialize(deserializer: D) -> Result + where + D: serde::de::Deserializer<'de>, + { + let p = CachePolicy::deserialize(deserializer)?; + Ok(CachePolicyStub(Some(p))) + } +} + +impl Serialize for CachePolicyStub { + fn serialize(&self, serializer: S) -> Result + where + S: serde::ser::Serializer, + { + self.0.as_ref().unwrap().serialize(serializer) + } +} + +enum BeforeRequestStub { + Fresh(http::response::Parts), + Stale { + request: http::request::Parts, + matches: bool, + }, +} + +impl From for BeforeRequestStub { + fn from(br: BeforeRequest) -> BeforeRequestStub { + match br { + BeforeRequest::Fresh(parts) => BeforeRequestStub::Fresh(parts), + BeforeRequest::Stale { request, matches } => { + BeforeRequestStub::Stale { request, matches } + } + } + } +} + +enum AfterResponseStub { + NotModified(CachePolicyStub, http::response::Parts), + Modified(CachePolicyStub, http::response::Parts), +} + +impl From for AfterResponseStub { + fn from(ar: AfterResponse) -> AfterResponseStub { + match ar { + AfterResponse::NotModified(p, parts) => { + AfterResponseStub::NotModified(CachePolicyStub(Some(p)), parts) + } + AfterResponse::Modified(p, parts) => { + AfterResponseStub::Modified(CachePolicyStub(Some(p)), parts) + } + } + } } /// Custom caching layer over [`reqwest::Client`] using `http-cache-semantics`. @@ -115,25 +272,57 @@ impl CachedClient { cache_control: CacheControl, response_callback: Callback, ) -> Result> + where + Callback: FnOnce(Response) -> CallbackReturn + Send, + CallbackReturn: Future> + Send, + { + let payload = self + .get_cached_with_callback2(req, cache_entry, cache_control, move |resp| async { + let payload = response_callback(resp).await?; + Ok(SerdeCacheable { inner: payload }) + }) + .await?; + Ok(payload) + } + + #[instrument(skip_all)] + pub async fn get_cached_with_callback2< + Payload: Cacheable, + CallBackError, + Callback, + CallbackReturn, + >( + &self, + req: Request, + cache_entry: &CacheEntry, + cache_control: CacheControl, + response_callback: Callback, + ) -> Result> where Callback: FnOnce(Response) -> CallbackReturn, CallbackReturn: Future> + Send, { - let cached = Self::read_cache(cache_entry).await; + let cached = Self::read_cache(&req, cache_entry).await; let cached_response = self.send_cached(req, cache_control, cached).boxed().await?; let write_cache = info_span!("write_cache", file = %cache_entry.path().display()); match cached_response { - CachedResponse::FreshCache(data) => Ok(data), + CachedResponse::FreshCache(data) => Ok(Payload::from_bytes(data)?), CachedResponse::NotModified(data_with_cache_policy) => { async { - let data = - rmp_serde::to_vec(&data_with_cache_policy).map_err(ErrorKind::Encode)?; - write_atomic(cache_entry.path(), data) - .await - .map_err(ErrorKind::CacheWrite)?; - Ok(data_with_cache_policy.data) + if std::env::var("PUFFIN_STUB_CACHE_POLICY").map_or(false, |v| v == "1") { + write_atomic(cache_entry.path(), &data_with_cache_policy.data) + .await + .map_err(ErrorKind::CacheWrite)?; + } else { + let data = rmp_serde::to_vec(&data_with_cache_policy) + .map_err(ErrorKind::Encode)?; + write_atomic(cache_entry.path(), &data) + .await + .map_err(ErrorKind::CacheWrite)?; + } + Ok(Payload::from_bytes(data_with_cache_policy.data)?) } .instrument(write_cache) .await @@ -148,7 +337,7 @@ impl CachedClient { .map_err(|err| CachedClientError::Callback(err))?; if let Some(cache_policy) = cache_policy { let data_with_cache_policy = DataWithCachePolicy { - data, + data: data.to_bytes()?, immutable, cache_policy, }; @@ -156,25 +345,29 @@ impl CachedClient { fs_err::tokio::create_dir_all(cache_entry.dir()) .await .map_err(ErrorKind::CacheWrite)?; - let data = rmp_serde::to_vec(&data_with_cache_policy) - .map_err(ErrorKind::Encode)?; - write_atomic(cache_entry.path(), data) - .await - .map_err(ErrorKind::CacheWrite)?; - Ok(data_with_cache_policy.data) + if std::env::var("PUFFIN_STUB_CACHE_POLICY").map_or(false, |v| v == "1") { + write_atomic(cache_entry.path(), &data_with_cache_policy.data) + .await + .map_err(ErrorKind::CacheWrite)?; + } else { + let envelope = rmp_serde::to_vec(&data_with_cache_policy) + .map_err(ErrorKind::Encode)?; + write_atomic(cache_entry.path(), envelope) + .await + .map_err(ErrorKind::CacheWrite)?; + } + Ok(data.into_target()) } .instrument(write_cache) .await } else { - Ok(data) + Ok(data.into_target()) } } } } - async fn read_cache( - cache_entry: &CacheEntry, - ) -> Option> { + async fn read_cache(req: &Request, cache_entry: &CacheEntry) -> Option { let read_span = info_span!("read_cache", file = %cache_entry.path().display()); let read_result = fs_err::tokio::read(cache_entry.path()) .instrument(read_span) @@ -185,21 +378,28 @@ impl CachedClient { "parse_cache", path = %cache_entry.path().display() ); - let parse_result = tokio::task::spawn_blocking(move || { - parse_span - .in_scope(|| rmp_serde::from_slice::>(&cached)) - }) - .await - .expect("Tokio executor failed, was there a panic?"); - match parse_result { - Ok(data) => Some(data), - Err(err) => { - warn!( - "Broken cache entry at {}, removing: {err}", - cache_entry.path().display() - ); - let _ = fs_err::tokio::remove_file(&cache_entry.path()).await; - None + if std::env::var("PUFFIN_STUB_CACHE_POLICY").map_or(false, |v| v == "1") { + Some(DataWithCachePolicy { + data: cached, + immutable: req.url().as_str().contains("pypi.org"), + cache_policy: Box::new(CachePolicyStub(None)), + }) + } else { + let parse_result = tokio::task::spawn_blocking(move || { + parse_span.in_scope(|| rmp_serde::from_slice::(&cached)) + }) + .await + .expect("Tokio executor failed, was there a panic?"); + match parse_result { + Ok(data) => Some(data), + Err(err) => { + warn!( + "Broken cache entry at {}, removing: {err}", + cache_entry.path().display() + ); + let _ = fs_err::tokio::remove_file(&cache_entry.path()).await; + None + } } } } else { @@ -208,12 +408,17 @@ impl CachedClient { } /// `http-cache-semantics` to `reqwest` wrapper - async fn send_cached( + async fn send_cached( &self, mut req: Request, cache_control: CacheControl, - cached: Option>, - ) -> Result, Error> { + cached: Option, + ) -> Result { + if std::env::var("PUFFIN_STUB_CACHE_POLICY").map_or(false, |v| v == "1") && cached.is_some() + { + return Ok(CachedResponse::FreshCache(cached.expect("wat").data)); + } + let url = req.url().clone(); let cached_response = if let Some(cached) = cached { // Avoid sending revalidation requests for immutable responses. @@ -237,11 +442,11 @@ impl CachedClient { .cache_policy .before_request(&RequestLikeReqwest(&req), SystemTime::now()) { - BeforeRequest::Fresh(_) => { + BeforeRequestStub::Fresh(_) => { debug!("Found fresh response for: {url}"); CachedResponse::FreshCache(cached.data) } - BeforeRequest::Stale { request, matches } => { + BeforeRequestStub::Stale { request, matches } => { self.send_cached_handle_stale(req, url, cached, &request, matches) .await? } @@ -253,14 +458,14 @@ impl CachedClient { Ok(cached_response) } - async fn send_cached_handle_stale( + async fn send_cached_handle_stale( &self, mut req: Request, url: Url, - cached: DataWithCachePolicy, + cached: DataWithCachePolicy, request: &Parts, matches: bool, - ) -> Result, Error> { + ) -> Result { if !matches { // This shouldn't happen; if it does, we'll override the cache. warn!("Cached request doesn't match current request for: {url}"); @@ -285,7 +490,7 @@ impl CachedClient { SystemTime::now(), ); match after_response { - AfterResponse::NotModified(new_policy, _parts) => { + AfterResponseStub::NotModified(new_policy, _parts) => { debug!("Found not-modified response for: {url}"); let headers = CacheHeaders::from_response(res.headers().get_all("cache-control")); let immutable = headers.is_immutable(); @@ -295,7 +500,7 @@ impl CachedClient { cache_policy: Box::new(new_policy), })) } - AfterResponse::Modified(new_policy, _parts) => { + AfterResponseStub::Modified(new_policy, _parts) => { debug!("Found modified response for: {url}"); Ok(CachedResponse::ModifiedOrNew( res, @@ -306,7 +511,7 @@ impl CachedClient { } #[instrument(skip_all, fields(url = req.url().as_str()))] - async fn fresh_request(&self, req: Request) -> Result, Error> { + async fn fresh_request(&self, req: Request) -> Result { trace!("{} {}", req.method(), req.url()); let res = self .0 @@ -315,7 +520,10 @@ impl CachedClient { .map_err(ErrorKind::RequestMiddlewareError)? .error_for_status() .map_err(ErrorKind::RequestError)?; - let cache_policy = CachePolicy::new(&RequestLikeReqwest(&req), &ResponseLikeReqwest(&res)); + let cache_policy = CachePolicyStub(Some(CachePolicy::new( + &RequestLikeReqwest(&req), + &ResponseLikeReqwest(&res), + ))); Ok(CachedResponse::ModifiedOrNew( res, cache_policy.is_storable().then(|| Box::new(cache_policy)), diff --git a/crates/puffin-client/src/lib.rs b/crates/puffin-client/src/lib.rs index ebf20d587..8a014c298 100644 --- a/crates/puffin-client/src/lib.rs +++ b/crates/puffin-client/src/lib.rs @@ -1,4 +1,6 @@ -pub use cached_client::{CacheControl, CachedClient, CachedClientError, DataWithCachePolicy}; +pub use cached_client::{ + CacheControl, Cacheable, CachedClient, CachedClientError, DataWithCachePolicy, SerdeCacheable, +}; pub use error::{Error, ErrorKind}; pub use flat_index::{FlatDistributions, FlatIndex, FlatIndexClient, FlatIndexError}; pub use registry_client::{ diff --git a/crates/puffin-client/src/registry_client.rs b/crates/puffin-client/src/registry_client.rs index 6c77a96ed..1c56ef2d3 100644 --- a/crates/puffin-client/src/registry_client.rs +++ b/crates/puffin-client/src/registry_client.rs @@ -209,7 +209,7 @@ impl RegistryClient { .map_err(|err| Error::from_json_err(err, url.clone()))?; let metadata = SimpleMetadata::from_files(data.files, package_name, url.as_str()); - Ok(metadata) + Ok(metadata.to_archive().unwrap()) } MediaType::Html => { let text = response.text().await.map_err(ErrorKind::RequestError)?; @@ -217,7 +217,7 @@ impl RegistryClient { .map_err(|err| Error::from_html_err(err, url.clone()))?; let metadata = SimpleMetadata::from_files(files, package_name, base.as_url().as_str()); - Ok(metadata) + Ok(metadata.to_archive().unwrap()) } } } @@ -226,14 +226,14 @@ impl RegistryClient { }; let result = self .client - .get_cached_with_callback( + .get_cached_with_callback2( simple_request, &cache_entry, cache_control, parse_simple_response, ) .await; - Ok(result.map(|simple| simple.to_archive().unwrap())) + Ok(result) } /// Fetch the metadata for a remote wheel file. diff --git a/crates/puffin-distribution/src/source/mod.rs b/crates/puffin-distribution/src/source/mod.rs index d1ad59103..8fefefb86 100644 --- a/crates/puffin-distribution/src/source/mod.rs +++ b/crates/puffin-distribution/src/source/mod.rs @@ -942,9 +942,11 @@ impl<'a, T: BuildContext> SourceDistCachedBuilder<'a, T> { /// Read an existing HTTP-cached [`Manifest`], if it exists. pub(crate) fn read_http_manifest(cache_entry: &CacheEntry) -> Result, Error> { match std::fs::read(cache_entry.path()) { - Ok(cached) => Ok(Some( - rmp_serde::from_slice::>(&cached)?.data, - )), + Ok(cached) => { + let raw = rmp_serde::from_slice::(&cached)?; + let manifest = rmp_serde::from_slice::(&raw.data)?; + Ok(Some(manifest)) + } Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(None), Err(err) => Err(Error::CacheRead(err)), }