From c7cafff4933e2337ffbc3ab1f55d5d2826707f80 Mon Sep 17 00:00:00 2001 From: Andrew Gallant Date: Fri, 26 Jan 2024 20:35:45 -0500 Subject: [PATCH] hack: make rkyv work by ignoring DataWithCachePolicy The DataWithCachePolicy poses some challenges for use with rkyv. In the interest of getting some kind of measurement with rkyv, this commit hacks around DataWithCachePolicy by stubbing it out. This commit isn't meant to be merged and will likely be completely thrown out. --- crates/puffin-client/src/cached_client.rs | 316 +++++++++++++++---- crates/puffin-client/src/lib.rs | 4 +- crates/puffin-client/src/registry_client.rs | 8 +- crates/puffin-distribution/src/source/mod.rs | 8 +- 4 files changed, 274 insertions(+), 62 deletions(-) diff --git a/crates/puffin-client/src/cached_client.rs b/crates/puffin-client/src/cached_client.rs index 4c82ae31b..c853f52ac 100644 --- a/crates/puffin-client/src/cached_client.rs +++ b/crates/puffin-client/src/cached_client.rs @@ -1,3 +1,6 @@ +#![allow(warnings)] + +use std::fmt::Debug; use std::future::Future; use std::time::SystemTime; @@ -14,7 +17,62 @@ use url::Url; use puffin_cache::{CacheEntry, Freshness}; use puffin_fs::write_atomic; -use crate::{cache_headers::CacheHeaders, Error, ErrorKind}; +use crate::{cache_headers::CacheHeaders, rkyvutil::OwnedArchive, Error, ErrorKind}; + +pub trait Cacheable: Sized + Send { + type Target; + + fn from_bytes(bytes: Vec) -> Result; + fn to_bytes(&self) -> Result, crate::Error>; + fn into_target(self) -> Self::Target; +} + +/// A wrapper type that makes anything with Serde support automatically +/// implement Cacheable. +#[derive(Debug, Deserialize, Serialize)] +#[serde(transparent)] +pub struct SerdeCacheable { + inner: T, +} + +impl Cacheable for SerdeCacheable { + type Target = T; + + fn from_bytes(bytes: Vec) -> Result { + Ok(rmp_serde::from_slice::(&bytes).map_err(ErrorKind::Decode)?) + } + + fn to_bytes(&self) -> Result, Error> { + Ok(rmp_serde::to_vec(&self.inner).map_err(ErrorKind::Encode)?) + } + + fn into_target(self) -> Self::Target { + self.inner + } +} + +impl Cacheable for OwnedArchive +where + A: rkyv::Archive + rkyv::Serialize> + Send, + A::Archived: for<'a> rkyv::CheckBytes> + + rkyv::Deserialize, +{ + type Target = OwnedArchive; + + fn from_bytes(bytes: Vec) -> Result, Error> { + let mut aligned = rkyv::util::AlignedVec::new(); + aligned.extend_from_slice(&bytes); + OwnedArchive::new(aligned) + } + + fn to_bytes(&self) -> Result, Error> { + Ok(OwnedArchive::as_bytes(self).to_vec()) + } + + fn into_target(self) -> Self::Target { + self + } +} /// Either a cached client error or a (user specified) error from the callback #[derive(Debug)] @@ -45,27 +103,126 @@ impl> From> for Error { } #[derive(Debug)] -enum CachedResponse { +enum CachedResponse { /// The cached response is fresh without an HTTP request (e.g. immutable) - FreshCache(Payload), + FreshCache(Vec), /// The cached response is fresh after an HTTP request (e.g. 304 not modified) - NotModified(DataWithCachePolicy), + NotModified(DataWithCachePolicy), /// There was no prior cached response or the cache was outdated /// /// The cache policy is `None` if it isn't storable - ModifiedOrNew(Response, Option>), + // ModifiedOrNew(Response, Option>), + ModifiedOrNew(Response, Option>), } /// Serialize the actual payload together with its caching information. #[derive(Debug, Deserialize, Serialize)] -pub struct DataWithCachePolicy { - pub data: Payload, +pub struct DataWithCachePolicy { + pub data: Vec, /// Whether the response should be considered immutable. immutable: bool, /// The [`CachePolicy`] is used to determine if the response is fresh or stale. /// The policy is large (448 bytes at time of writing), so we reduce the stack size by /// boxing it. - cache_policy: Box, + cache_policy: Box, +} + +#[derive(Debug)] +struct CachePolicyStub(Option); + +impl CachePolicyStub { + fn is_stale(&self, time: SystemTime) -> bool { + self.0.as_ref().map_or(false, |p| p.is_stale(time)) + } + + fn is_storable(&self) -> bool { + self.0.as_ref().map_or(false, |p| p.is_storable()) + } + + fn before_request( + &self, + req: &Req, + now: SystemTime, + ) -> BeforeRequestStub { + match self.0.as_ref() { + None => { + let dummy = http::Response::new(()).into_parts().0; + BeforeRequestStub::Fresh(dummy) + } + Some(p) => p.before_request(req, now).into(), + } + } + + fn after_response< + Req: http_cache_semantics::RequestLike, + Resp: http_cache_semantics::ResponseLike, + >( + &self, + req: &Req, + resp: &Resp, + time: SystemTime, + ) -> AfterResponseStub { + match self.0.as_ref() { + None => unreachable!("oops"), + Some(p) => p.after_response(req, resp, time).into(), + } + } +} + +impl<'de> Deserialize<'de> for CachePolicyStub { + fn deserialize(deserializer: D) -> Result + where + D: serde::de::Deserializer<'de>, + { + let p = CachePolicy::deserialize(deserializer)?; + Ok(CachePolicyStub(Some(p))) + } +} + +impl Serialize for CachePolicyStub { + fn serialize(&self, serializer: S) -> Result + where + S: serde::ser::Serializer, + { + self.0.as_ref().unwrap().serialize(serializer) + } +} + +enum BeforeRequestStub { + Fresh(http::response::Parts), + Stale { + request: http::request::Parts, + matches: bool, + }, +} + +impl From for BeforeRequestStub { + fn from(br: BeforeRequest) -> BeforeRequestStub { + match br { + BeforeRequest::Fresh(parts) => BeforeRequestStub::Fresh(parts), + BeforeRequest::Stale { request, matches } => { + BeforeRequestStub::Stale { request, matches } + } + } + } +} + +enum AfterResponseStub { + NotModified(CachePolicyStub, http::response::Parts), + Modified(CachePolicyStub, http::response::Parts), +} + +impl From for AfterResponseStub { + fn from(ar: AfterResponse) -> AfterResponseStub { + match ar { + AfterResponse::NotModified(p, parts) => { + AfterResponseStub::NotModified(CachePolicyStub(Some(p)), parts) + } + AfterResponse::Modified(p, parts) => { + AfterResponseStub::Modified(CachePolicyStub(Some(p)), parts) + } + } + } } /// Custom caching layer over [`reqwest::Client`] using `http-cache-semantics`. @@ -115,25 +272,57 @@ impl CachedClient { cache_control: CacheControl, response_callback: Callback, ) -> Result> + where + Callback: FnOnce(Response) -> CallbackReturn + Send, + CallbackReturn: Future> + Send, + { + let payload = self + .get_cached_with_callback2(req, cache_entry, cache_control, move |resp| async { + let payload = response_callback(resp).await?; + Ok(SerdeCacheable { inner: payload }) + }) + .await?; + Ok(payload) + } + + #[instrument(skip_all)] + pub async fn get_cached_with_callback2< + Payload: Cacheable, + CallBackError, + Callback, + CallbackReturn, + >( + &self, + req: Request, + cache_entry: &CacheEntry, + cache_control: CacheControl, + response_callback: Callback, + ) -> Result> where Callback: FnOnce(Response) -> CallbackReturn, CallbackReturn: Future> + Send, { - let cached = Self::read_cache(cache_entry).await; + let cached = Self::read_cache(&req, cache_entry).await; let cached_response = self.send_cached(req, cache_control, cached).boxed().await?; let write_cache = info_span!("write_cache", file = %cache_entry.path().display()); match cached_response { - CachedResponse::FreshCache(data) => Ok(data), + CachedResponse::FreshCache(data) => Ok(Payload::from_bytes(data)?), CachedResponse::NotModified(data_with_cache_policy) => { async { - let data = - rmp_serde::to_vec(&data_with_cache_policy).map_err(ErrorKind::Encode)?; - write_atomic(cache_entry.path(), data) - .await - .map_err(ErrorKind::CacheWrite)?; - Ok(data_with_cache_policy.data) + if std::env::var("PUFFIN_STUB_CACHE_POLICY").map_or(false, |v| v == "1") { + write_atomic(cache_entry.path(), &data_with_cache_policy.data) + .await + .map_err(ErrorKind::CacheWrite)?; + } else { + let data = rmp_serde::to_vec(&data_with_cache_policy) + .map_err(ErrorKind::Encode)?; + write_atomic(cache_entry.path(), &data) + .await + .map_err(ErrorKind::CacheWrite)?; + } + Ok(Payload::from_bytes(data_with_cache_policy.data)?) } .instrument(write_cache) .await @@ -148,7 +337,7 @@ impl CachedClient { .map_err(|err| CachedClientError::Callback(err))?; if let Some(cache_policy) = cache_policy { let data_with_cache_policy = DataWithCachePolicy { - data, + data: data.to_bytes()?, immutable, cache_policy, }; @@ -156,25 +345,29 @@ impl CachedClient { fs_err::tokio::create_dir_all(cache_entry.dir()) .await .map_err(ErrorKind::CacheWrite)?; - let data = rmp_serde::to_vec(&data_with_cache_policy) - .map_err(ErrorKind::Encode)?; - write_atomic(cache_entry.path(), data) - .await - .map_err(ErrorKind::CacheWrite)?; - Ok(data_with_cache_policy.data) + if std::env::var("PUFFIN_STUB_CACHE_POLICY").map_or(false, |v| v == "1") { + write_atomic(cache_entry.path(), &data_with_cache_policy.data) + .await + .map_err(ErrorKind::CacheWrite)?; + } else { + let envelope = rmp_serde::to_vec(&data_with_cache_policy) + .map_err(ErrorKind::Encode)?; + write_atomic(cache_entry.path(), envelope) + .await + .map_err(ErrorKind::CacheWrite)?; + } + Ok(data.into_target()) } .instrument(write_cache) .await } else { - Ok(data) + Ok(data.into_target()) } } } } - async fn read_cache( - cache_entry: &CacheEntry, - ) -> Option> { + async fn read_cache(req: &Request, cache_entry: &CacheEntry) -> Option { let read_span = info_span!("read_cache", file = %cache_entry.path().display()); let read_result = fs_err::tokio::read(cache_entry.path()) .instrument(read_span) @@ -185,21 +378,28 @@ impl CachedClient { "parse_cache", path = %cache_entry.path().display() ); - let parse_result = tokio::task::spawn_blocking(move || { - parse_span - .in_scope(|| rmp_serde::from_slice::>(&cached)) - }) - .await - .expect("Tokio executor failed, was there a panic?"); - match parse_result { - Ok(data) => Some(data), - Err(err) => { - warn!( - "Broken cache entry at {}, removing: {err}", - cache_entry.path().display() - ); - let _ = fs_err::tokio::remove_file(&cache_entry.path()).await; - None + if std::env::var("PUFFIN_STUB_CACHE_POLICY").map_or(false, |v| v == "1") { + Some(DataWithCachePolicy { + data: cached, + immutable: req.url().as_str().contains("pypi.org"), + cache_policy: Box::new(CachePolicyStub(None)), + }) + } else { + let parse_result = tokio::task::spawn_blocking(move || { + parse_span.in_scope(|| rmp_serde::from_slice::(&cached)) + }) + .await + .expect("Tokio executor failed, was there a panic?"); + match parse_result { + Ok(data) => Some(data), + Err(err) => { + warn!( + "Broken cache entry at {}, removing: {err}", + cache_entry.path().display() + ); + let _ = fs_err::tokio::remove_file(&cache_entry.path()).await; + None + } } } } else { @@ -208,12 +408,17 @@ impl CachedClient { } /// `http-cache-semantics` to `reqwest` wrapper - async fn send_cached( + async fn send_cached( &self, mut req: Request, cache_control: CacheControl, - cached: Option>, - ) -> Result, Error> { + cached: Option, + ) -> Result { + if std::env::var("PUFFIN_STUB_CACHE_POLICY").map_or(false, |v| v == "1") && cached.is_some() + { + return Ok(CachedResponse::FreshCache(cached.expect("wat").data)); + } + let url = req.url().clone(); let cached_response = if let Some(cached) = cached { // Avoid sending revalidation requests for immutable responses. @@ -237,11 +442,11 @@ impl CachedClient { .cache_policy .before_request(&RequestLikeReqwest(&req), SystemTime::now()) { - BeforeRequest::Fresh(_) => { + BeforeRequestStub::Fresh(_) => { debug!("Found fresh response for: {url}"); CachedResponse::FreshCache(cached.data) } - BeforeRequest::Stale { request, matches } => { + BeforeRequestStub::Stale { request, matches } => { self.send_cached_handle_stale(req, url, cached, &request, matches) .await? } @@ -253,14 +458,14 @@ impl CachedClient { Ok(cached_response) } - async fn send_cached_handle_stale( + async fn send_cached_handle_stale( &self, mut req: Request, url: Url, - cached: DataWithCachePolicy, + cached: DataWithCachePolicy, request: &Parts, matches: bool, - ) -> Result, Error> { + ) -> Result { if !matches { // This shouldn't happen; if it does, we'll override the cache. warn!("Cached request doesn't match current request for: {url}"); @@ -285,7 +490,7 @@ impl CachedClient { SystemTime::now(), ); match after_response { - AfterResponse::NotModified(new_policy, _parts) => { + AfterResponseStub::NotModified(new_policy, _parts) => { debug!("Found not-modified response for: {url}"); let headers = CacheHeaders::from_response(res.headers().get_all("cache-control")); let immutable = headers.is_immutable(); @@ -295,7 +500,7 @@ impl CachedClient { cache_policy: Box::new(new_policy), })) } - AfterResponse::Modified(new_policy, _parts) => { + AfterResponseStub::Modified(new_policy, _parts) => { debug!("Found modified response for: {url}"); Ok(CachedResponse::ModifiedOrNew( res, @@ -306,7 +511,7 @@ impl CachedClient { } #[instrument(skip_all, fields(url = req.url().as_str()))] - async fn fresh_request(&self, req: Request) -> Result, Error> { + async fn fresh_request(&self, req: Request) -> Result { trace!("{} {}", req.method(), req.url()); let res = self .0 @@ -315,7 +520,10 @@ impl CachedClient { .map_err(ErrorKind::RequestMiddlewareError)? .error_for_status() .map_err(ErrorKind::RequestError)?; - let cache_policy = CachePolicy::new(&RequestLikeReqwest(&req), &ResponseLikeReqwest(&res)); + let cache_policy = CachePolicyStub(Some(CachePolicy::new( + &RequestLikeReqwest(&req), + &ResponseLikeReqwest(&res), + ))); Ok(CachedResponse::ModifiedOrNew( res, cache_policy.is_storable().then(|| Box::new(cache_policy)), diff --git a/crates/puffin-client/src/lib.rs b/crates/puffin-client/src/lib.rs index ebf20d587..8a014c298 100644 --- a/crates/puffin-client/src/lib.rs +++ b/crates/puffin-client/src/lib.rs @@ -1,4 +1,6 @@ -pub use cached_client::{CacheControl, CachedClient, CachedClientError, DataWithCachePolicy}; +pub use cached_client::{ + CacheControl, Cacheable, CachedClient, CachedClientError, DataWithCachePolicy, SerdeCacheable, +}; pub use error::{Error, ErrorKind}; pub use flat_index::{FlatDistributions, FlatIndex, FlatIndexClient, FlatIndexError}; pub use registry_client::{ diff --git a/crates/puffin-client/src/registry_client.rs b/crates/puffin-client/src/registry_client.rs index 6c77a96ed..1c56ef2d3 100644 --- a/crates/puffin-client/src/registry_client.rs +++ b/crates/puffin-client/src/registry_client.rs @@ -209,7 +209,7 @@ impl RegistryClient { .map_err(|err| Error::from_json_err(err, url.clone()))?; let metadata = SimpleMetadata::from_files(data.files, package_name, url.as_str()); - Ok(metadata) + Ok(metadata.to_archive().unwrap()) } MediaType::Html => { let text = response.text().await.map_err(ErrorKind::RequestError)?; @@ -217,7 +217,7 @@ impl RegistryClient { .map_err(|err| Error::from_html_err(err, url.clone()))?; let metadata = SimpleMetadata::from_files(files, package_name, base.as_url().as_str()); - Ok(metadata) + Ok(metadata.to_archive().unwrap()) } } } @@ -226,14 +226,14 @@ impl RegistryClient { }; let result = self .client - .get_cached_with_callback( + .get_cached_with_callback2( simple_request, &cache_entry, cache_control, parse_simple_response, ) .await; - Ok(result.map(|simple| simple.to_archive().unwrap())) + Ok(result) } /// Fetch the metadata for a remote wheel file. diff --git a/crates/puffin-distribution/src/source/mod.rs b/crates/puffin-distribution/src/source/mod.rs index d1ad59103..8fefefb86 100644 --- a/crates/puffin-distribution/src/source/mod.rs +++ b/crates/puffin-distribution/src/source/mod.rs @@ -942,9 +942,11 @@ impl<'a, T: BuildContext> SourceDistCachedBuilder<'a, T> { /// Read an existing HTTP-cached [`Manifest`], if it exists. pub(crate) fn read_http_manifest(cache_entry: &CacheEntry) -> Result, Error> { match std::fs::read(cache_entry.path()) { - Ok(cached) => Ok(Some( - rmp_serde::from_slice::>(&cached)?.data, - )), + Ok(cached) => { + let raw = rmp_serde::from_slice::(&cached)?; + let manifest = rmp_serde::from_slice::(&raw.data)?; + Ok(Some(manifest)) + } Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(None), Err(err) => Err(Error::CacheRead(err)), }