hack: make rkyv work by ignoring DataWithCachePolicy

The DataWithCachePolicy poses some challenges for use with rkyv.
In the interest of getting some kind of measurement with rkyv,
this commit hacks around DataWithCachePolicy by stubbing it out.

This commit isn't meant to be merged and will likely be
completely thrown out.
This commit is contained in:
Andrew Gallant 2024-01-26 20:35:45 -05:00
parent 8d70055140
commit c7cafff493
No known key found for this signature in database
GPG Key ID: 5518C8B38E0693E0
4 changed files with 274 additions and 62 deletions

View File

@ -1,3 +1,6 @@
#![allow(warnings)]
use std::fmt::Debug;
use std::future::Future; use std::future::Future;
use std::time::SystemTime; use std::time::SystemTime;
@ -14,7 +17,62 @@ use url::Url;
use puffin_cache::{CacheEntry, Freshness}; use puffin_cache::{CacheEntry, Freshness};
use puffin_fs::write_atomic; use puffin_fs::write_atomic;
use crate::{cache_headers::CacheHeaders, Error, ErrorKind}; use crate::{cache_headers::CacheHeaders, rkyvutil::OwnedArchive, Error, ErrorKind};
/// A payload that the cached client can turn into raw bytes for storage and
/// rebuild from those bytes later.
///
/// The value a response callback produces implements this trait; what the
/// caller finally receives is the associated `Target` type.
pub trait Cacheable: Sized + Send {
/// The type handed back to callers, whether it was decoded from cached
/// bytes or unwrapped from a freshly produced value.
type Target;
/// Decode a `Target` from bytes that were read back from the cache.
fn from_bytes(bytes: Vec<u8>) -> Result<Self::Target, crate::Error>;
/// Encode this value as the bytes that get stored in the cache.
fn to_bytes(&self) -> Result<Vec<u8>, crate::Error>;
/// Consume the value and convert it into `Target` without a byte round-trip.
fn into_target(self) -> Self::Target;
}
/// A wrapper type that makes anything with Serde support automatically
/// implement [`Cacheable`].
///
/// Because of `#[serde(transparent)]`, the wrapper itself serializes and
/// deserializes exactly like the wrapped `T`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(transparent)]
pub struct SerdeCacheable<T> {
/// The wrapped value; `Cacheable::into_target` returns it unchanged.
inner: T,
}
impl<T: Send + Serialize + DeserializeOwned> Cacheable for SerdeCacheable<T> {
type Target = T;
fn from_bytes(bytes: Vec<u8>) -> Result<T, Error> {
Ok(rmp_serde::from_slice::<T>(&bytes).map_err(ErrorKind::Decode)?)
}
fn to_bytes(&self) -> Result<Vec<u8>, Error> {
Ok(rmp_serde::to_vec(&self.inner).map_err(ErrorKind::Encode)?)
}
fn into_target(self) -> Self::Target {
self.inner
}
}
/// [`Cacheable`] for rkyv-backed archives: the bytes stored in the cache are
/// the archive's own bytes, and the target is the archive itself.
impl<A> Cacheable for OwnedArchive<A>
where
    A: rkyv::Archive + rkyv::Serialize<crate::rkyvutil::Serializer<4096>> + Send,
    A::Archived: for<'a> rkyv::CheckBytes<rkyv::validation::validators::DefaultValidator<'a>>
        + rkyv::Deserialize<A, rkyv::de::deserializers::SharedDeserializeMap>,
{
    type Target = OwnedArchive<A>;

    /// Build an archive from raw bytes.
    ///
    /// The bytes are first copied into an `AlignedVec` (presumably because
    /// the archive needs an aligned buffer -- confirm against `OwnedArchive`),
    /// and whatever `OwnedArchive::new` returns is passed through as-is.
    fn from_bytes(bytes: Vec<u8>) -> Result<OwnedArchive<A>, Error> {
        let mut buffer = rkyv::util::AlignedVec::new();
        buffer.extend_from_slice(bytes.as_slice());
        let archive = OwnedArchive::new(buffer);
        archive
    }

    /// Copy the archive's bytes into a fresh `Vec<u8>`.
    fn to_bytes(&self) -> Result<Vec<u8>, Error> {
        let raw = OwnedArchive::as_bytes(self);
        Ok(raw.to_vec())
    }

    /// The archive is already the target type, so return it unchanged.
    fn into_target(self) -> Self::Target {
        self
    }
}
/// Either a cached client error or a (user specified) error from the callback /// Either a cached client error or a (user specified) error from the callback
#[derive(Debug)] #[derive(Debug)]
@ -45,27 +103,126 @@ impl<E: Into<Error>> From<CachedClientError<E>> for Error {
} }
#[derive(Debug)] #[derive(Debug)]
enum CachedResponse<Payload: Serialize> { enum CachedResponse {
/// The cached response is fresh without an HTTP request (e.g. immutable) /// The cached response is fresh without an HTTP request (e.g. immutable)
FreshCache(Payload), FreshCache(Vec<u8>),
/// The cached response is fresh after an HTTP request (e.g. 304 not modified) /// The cached response is fresh after an HTTP request (e.g. 304 not modified)
NotModified(DataWithCachePolicy<Payload>), NotModified(DataWithCachePolicy),
/// There was no prior cached response or the cache was outdated /// There was no prior cached response or the cache was outdated
/// ///
/// The cache policy is `None` if it isn't storable /// The cache policy is `None` if it isn't storable
ModifiedOrNew(Response, Option<Box<CachePolicy>>), // ModifiedOrNew(Response, Option<Box<CachePolicy>>),
ModifiedOrNew(Response, Option<Box<CachePolicyStub>>),
} }
/// Serialize the actual payload together with its caching information. /// Serialize the actual payload together with its caching information.
#[derive(Debug, Deserialize, Serialize)] #[derive(Debug, Deserialize, Serialize)]
pub struct DataWithCachePolicy<Payload: Serialize> { pub struct DataWithCachePolicy {
pub data: Payload, pub data: Vec<u8>,
/// Whether the response should be considered immutable. /// Whether the response should be considered immutable.
immutable: bool, immutable: bool,
/// The [`CachePolicy`] is used to determine if the response is fresh or stale. /// The [`CachePolicy`] is used to determine if the response is fresh or stale.
/// The policy is large (448 bytes at time of writing), so we reduce the stack size by /// The policy is large (448 bytes at time of writing), so we reduce the stack size by
/// boxing it. /// boxing it.
cache_policy: Box<CachePolicy>, cache_policy: Box<CachePolicyStub>,
}
/// Stand-in for [`CachePolicy`] that may hold no policy at all.
///
/// `Some(policy)` forwards to the real policy. `None` is the stubbed-out
/// state: the methods on this type then report "not stale", "not storable"
/// and a fresh `before_request` outcome without consulting any policy.
#[derive(Debug)]
struct CachePolicyStub(Option<CachePolicy>);
impl CachePolicyStub {
    /// Whether the wrapped policy considers the response stale at `time`.
    ///
    /// Without a wrapped policy this is always `false`.
    fn is_stale(&self, time: SystemTime) -> bool {
        match &self.0 {
            Some(policy) => policy.is_stale(time),
            None => false,
        }
    }

    /// Whether the wrapped policy says the response may be stored.
    ///
    /// Without a wrapped policy this is always `false`.
    fn is_storable(&self) -> bool {
        match &self.0 {
            Some(policy) => policy.is_storable(),
            None => false,
        }
    }

    /// Ask the wrapped policy what to do before sending `req` at time `now`.
    ///
    /// Without a wrapped policy the answer is always `Fresh`, carrying the
    /// parts of an empty placeholder response.
    fn before_request<Req: http_cache_semantics::RequestLike>(
        &self,
        req: &Req,
        now: SystemTime,
    ) -> BeforeRequestStub {
        if let Some(policy) = &self.0 {
            return policy.before_request(req, now).into();
        }
        // No policy to consult: fabricate response parts so `Fresh` can be built.
        let (placeholder, _body) = http::Response::new(()).into_parts();
        BeforeRequestStub::Fresh(placeholder)
    }

    /// Ask the wrapped policy to evaluate a response received at `time`.
    ///
    /// # Panics
    ///
    /// Panics if there is no wrapped policy; this path is treated as
    /// unreachable.
    fn after_response<
        Req: http_cache_semantics::RequestLike,
        Resp: http_cache_semantics::ResponseLike,
    >(
        &self,
        req: &Req,
        resp: &Resp,
        time: SystemTime,
    ) -> AfterResponseStub {
        if let Some(policy) = &self.0 {
            policy.after_response(req, resp, time).into()
        } else {
            unreachable!("oops")
        }
    }
}
/// Deserializes exactly like a plain [`CachePolicy`]; the result always wraps
/// a real policy (`Some`).
impl<'de> Deserialize<'de> for CachePolicyStub {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::de::Deserializer<'de>,
    {
        CachePolicy::deserialize(deserializer).map(|policy| CachePolicyStub(Some(policy)))
    }
}
impl Serialize for CachePolicyStub {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::ser::Serializer,
{
self.0.as_ref().unwrap().serialize(serializer)
}
}
/// Local mirror of `http_cache_semantics::BeforeRequest`, returned by
/// `CachePolicyStub::before_request`.
enum BeforeRequestStub {
/// The cached response can be used as-is; carries response parts.
Fresh(http::response::Parts),
/// The cached response is not fresh and a request is needed.
Stale {
/// Request parts supplied by the policy for the follow-up request.
request: http::request::Parts,
/// Whether the cached request matches the current one, as reported by
/// the policy.
matches: bool,
},
}
impl From<BeforeRequest> for BeforeRequestStub {
fn from(br: BeforeRequest) -> BeforeRequestStub {
match br {
BeforeRequest::Fresh(parts) => BeforeRequestStub::Fresh(parts),
BeforeRequest::Stale { request, matches } => {
BeforeRequestStub::Stale { request, matches }
}
}
}
}
/// Local mirror of `http_cache_semantics::AfterResponse`, returned by
/// `CachePolicyStub::after_response`; the policy it carries is wrapped in a
/// `CachePolicyStub`.
enum AfterResponseStub {
/// The response was not modified; carries the new policy and response parts.
NotModified(CachePolicyStub, http::response::Parts),
/// The response was modified; carries the new policy and response parts.
Modified(CachePolicyStub, http::response::Parts),
}
// Variant-for-variant conversion from the library's `AfterResponse`: the
// policy is wrapped as `CachePolicyStub(Some(..))` and the response parts are
// passed through unchanged.
impl From<AfterResponse> for AfterResponseStub {
fn from(ar: AfterResponse) -> AfterResponseStub {
match ar {
AfterResponse::NotModified(p, parts) => {
AfterResponseStub::NotModified(CachePolicyStub(Some(p)), parts)
}
AfterResponse::Modified(p, parts) => {
AfterResponseStub::Modified(CachePolicyStub(Some(p)), parts)
}
}
}
} }
/// Custom caching layer over [`reqwest::Client`] using `http-cache-semantics`. /// Custom caching layer over [`reqwest::Client`] using `http-cache-semantics`.
@ -115,25 +272,57 @@ impl CachedClient {
cache_control: CacheControl, cache_control: CacheControl,
response_callback: Callback, response_callback: Callback,
) -> Result<Payload, CachedClientError<CallBackError>> ) -> Result<Payload, CachedClientError<CallBackError>>
where
Callback: FnOnce(Response) -> CallbackReturn + Send,
CallbackReturn: Future<Output = Result<Payload, CallBackError>> + Send,
{
let payload = self
.get_cached_with_callback2(req, cache_entry, cache_control, move |resp| async {
let payload = response_callback(resp).await?;
Ok(SerdeCacheable { inner: payload })
})
.await?;
Ok(payload)
}
#[instrument(skip_all)]
pub async fn get_cached_with_callback2<
Payload: Cacheable,
CallBackError,
Callback,
CallbackReturn,
>(
&self,
req: Request,
cache_entry: &CacheEntry,
cache_control: CacheControl,
response_callback: Callback,
) -> Result<Payload::Target, CachedClientError<CallBackError>>
where where
Callback: FnOnce(Response) -> CallbackReturn, Callback: FnOnce(Response) -> CallbackReturn,
CallbackReturn: Future<Output = Result<Payload, CallBackError>> + Send, CallbackReturn: Future<Output = Result<Payload, CallBackError>> + Send,
{ {
let cached = Self::read_cache(cache_entry).await; let cached = Self::read_cache(&req, cache_entry).await;
let cached_response = self.send_cached(req, cache_control, cached).boxed().await?; let cached_response = self.send_cached(req, cache_control, cached).boxed().await?;
let write_cache = info_span!("write_cache", file = %cache_entry.path().display()); let write_cache = info_span!("write_cache", file = %cache_entry.path().display());
match cached_response { match cached_response {
CachedResponse::FreshCache(data) => Ok(data), CachedResponse::FreshCache(data) => Ok(Payload::from_bytes(data)?),
CachedResponse::NotModified(data_with_cache_policy) => { CachedResponse::NotModified(data_with_cache_policy) => {
async { async {
let data = if std::env::var("PUFFIN_STUB_CACHE_POLICY").map_or(false, |v| v == "1") {
rmp_serde::to_vec(&data_with_cache_policy).map_err(ErrorKind::Encode)?; write_atomic(cache_entry.path(), &data_with_cache_policy.data)
write_atomic(cache_entry.path(), data) .await
.await .map_err(ErrorKind::CacheWrite)?;
.map_err(ErrorKind::CacheWrite)?; } else {
Ok(data_with_cache_policy.data) let data = rmp_serde::to_vec(&data_with_cache_policy)
.map_err(ErrorKind::Encode)?;
write_atomic(cache_entry.path(), &data)
.await
.map_err(ErrorKind::CacheWrite)?;
}
Ok(Payload::from_bytes(data_with_cache_policy.data)?)
} }
.instrument(write_cache) .instrument(write_cache)
.await .await
@ -148,7 +337,7 @@ impl CachedClient {
.map_err(|err| CachedClientError::Callback(err))?; .map_err(|err| CachedClientError::Callback(err))?;
if let Some(cache_policy) = cache_policy { if let Some(cache_policy) = cache_policy {
let data_with_cache_policy = DataWithCachePolicy { let data_with_cache_policy = DataWithCachePolicy {
data, data: data.to_bytes()?,
immutable, immutable,
cache_policy, cache_policy,
}; };
@ -156,25 +345,29 @@ impl CachedClient {
fs_err::tokio::create_dir_all(cache_entry.dir()) fs_err::tokio::create_dir_all(cache_entry.dir())
.await .await
.map_err(ErrorKind::CacheWrite)?; .map_err(ErrorKind::CacheWrite)?;
let data = rmp_serde::to_vec(&data_with_cache_policy) if std::env::var("PUFFIN_STUB_CACHE_POLICY").map_or(false, |v| v == "1") {
.map_err(ErrorKind::Encode)?; write_atomic(cache_entry.path(), &data_with_cache_policy.data)
write_atomic(cache_entry.path(), data) .await
.await .map_err(ErrorKind::CacheWrite)?;
.map_err(ErrorKind::CacheWrite)?; } else {
Ok(data_with_cache_policy.data) let envelope = rmp_serde::to_vec(&data_with_cache_policy)
.map_err(ErrorKind::Encode)?;
write_atomic(cache_entry.path(), envelope)
.await
.map_err(ErrorKind::CacheWrite)?;
}
Ok(data.into_target())
} }
.instrument(write_cache) .instrument(write_cache)
.await .await
} else { } else {
Ok(data) Ok(data.into_target())
} }
} }
} }
} }
async fn read_cache<Payload: Serialize + DeserializeOwned + Send + 'static>( async fn read_cache(req: &Request, cache_entry: &CacheEntry) -> Option<DataWithCachePolicy> {
cache_entry: &CacheEntry,
) -> Option<DataWithCachePolicy<Payload>> {
let read_span = info_span!("read_cache", file = %cache_entry.path().display()); let read_span = info_span!("read_cache", file = %cache_entry.path().display());
let read_result = fs_err::tokio::read(cache_entry.path()) let read_result = fs_err::tokio::read(cache_entry.path())
.instrument(read_span) .instrument(read_span)
@ -185,21 +378,28 @@ impl CachedClient {
"parse_cache", "parse_cache",
path = %cache_entry.path().display() path = %cache_entry.path().display()
); );
let parse_result = tokio::task::spawn_blocking(move || { if std::env::var("PUFFIN_STUB_CACHE_POLICY").map_or(false, |v| v == "1") {
parse_span Some(DataWithCachePolicy {
.in_scope(|| rmp_serde::from_slice::<DataWithCachePolicy<Payload>>(&cached)) data: cached,
}) immutable: req.url().as_str().contains("pypi.org"),
.await cache_policy: Box::new(CachePolicyStub(None)),
.expect("Tokio executor failed, was there a panic?"); })
match parse_result { } else {
Ok(data) => Some(data), let parse_result = tokio::task::spawn_blocking(move || {
Err(err) => { parse_span.in_scope(|| rmp_serde::from_slice::<DataWithCachePolicy>(&cached))
warn!( })
"Broken cache entry at {}, removing: {err}", .await
cache_entry.path().display() .expect("Tokio executor failed, was there a panic?");
); match parse_result {
let _ = fs_err::tokio::remove_file(&cache_entry.path()).await; Ok(data) => Some(data),
None Err(err) => {
warn!(
"Broken cache entry at {}, removing: {err}",
cache_entry.path().display()
);
let _ = fs_err::tokio::remove_file(&cache_entry.path()).await;
None
}
} }
} }
} else { } else {
@ -208,12 +408,17 @@ impl CachedClient {
} }
/// `http-cache-semantics` to `reqwest` wrapper /// `http-cache-semantics` to `reqwest` wrapper
async fn send_cached<T: Serialize + DeserializeOwned>( async fn send_cached(
&self, &self,
mut req: Request, mut req: Request,
cache_control: CacheControl, cache_control: CacheControl,
cached: Option<DataWithCachePolicy<T>>, cached: Option<DataWithCachePolicy>,
) -> Result<CachedResponse<T>, Error> { ) -> Result<CachedResponse, Error> {
if std::env::var("PUFFIN_STUB_CACHE_POLICY").map_or(false, |v| v == "1") && cached.is_some()
{
return Ok(CachedResponse::FreshCache(cached.expect("wat").data));
}
let url = req.url().clone(); let url = req.url().clone();
let cached_response = if let Some(cached) = cached { let cached_response = if let Some(cached) = cached {
// Avoid sending revalidation requests for immutable responses. // Avoid sending revalidation requests for immutable responses.
@ -237,11 +442,11 @@ impl CachedClient {
.cache_policy .cache_policy
.before_request(&RequestLikeReqwest(&req), SystemTime::now()) .before_request(&RequestLikeReqwest(&req), SystemTime::now())
{ {
BeforeRequest::Fresh(_) => { BeforeRequestStub::Fresh(_) => {
debug!("Found fresh response for: {url}"); debug!("Found fresh response for: {url}");
CachedResponse::FreshCache(cached.data) CachedResponse::FreshCache(cached.data)
} }
BeforeRequest::Stale { request, matches } => { BeforeRequestStub::Stale { request, matches } => {
self.send_cached_handle_stale(req, url, cached, &request, matches) self.send_cached_handle_stale(req, url, cached, &request, matches)
.await? .await?
} }
@ -253,14 +458,14 @@ impl CachedClient {
Ok(cached_response) Ok(cached_response)
} }
async fn send_cached_handle_stale<T: Serialize + DeserializeOwned>( async fn send_cached_handle_stale(
&self, &self,
mut req: Request, mut req: Request,
url: Url, url: Url,
cached: DataWithCachePolicy<T>, cached: DataWithCachePolicy,
request: &Parts, request: &Parts,
matches: bool, matches: bool,
) -> Result<CachedResponse<T>, Error> { ) -> Result<CachedResponse, Error> {
if !matches { if !matches {
// This shouldn't happen; if it does, we'll override the cache. // This shouldn't happen; if it does, we'll override the cache.
warn!("Cached request doesn't match current request for: {url}"); warn!("Cached request doesn't match current request for: {url}");
@ -285,7 +490,7 @@ impl CachedClient {
SystemTime::now(), SystemTime::now(),
); );
match after_response { match after_response {
AfterResponse::NotModified(new_policy, _parts) => { AfterResponseStub::NotModified(new_policy, _parts) => {
debug!("Found not-modified response for: {url}"); debug!("Found not-modified response for: {url}");
let headers = CacheHeaders::from_response(res.headers().get_all("cache-control")); let headers = CacheHeaders::from_response(res.headers().get_all("cache-control"));
let immutable = headers.is_immutable(); let immutable = headers.is_immutable();
@ -295,7 +500,7 @@ impl CachedClient {
cache_policy: Box::new(new_policy), cache_policy: Box::new(new_policy),
})) }))
} }
AfterResponse::Modified(new_policy, _parts) => { AfterResponseStub::Modified(new_policy, _parts) => {
debug!("Found modified response for: {url}"); debug!("Found modified response for: {url}");
Ok(CachedResponse::ModifiedOrNew( Ok(CachedResponse::ModifiedOrNew(
res, res,
@ -306,7 +511,7 @@ impl CachedClient {
} }
#[instrument(skip_all, fields(url = req.url().as_str()))] #[instrument(skip_all, fields(url = req.url().as_str()))]
async fn fresh_request<T: Serialize>(&self, req: Request) -> Result<CachedResponse<T>, Error> { async fn fresh_request(&self, req: Request) -> Result<CachedResponse, Error> {
trace!("{} {}", req.method(), req.url()); trace!("{} {}", req.method(), req.url());
let res = self let res = self
.0 .0
@ -315,7 +520,10 @@ impl CachedClient {
.map_err(ErrorKind::RequestMiddlewareError)? .map_err(ErrorKind::RequestMiddlewareError)?
.error_for_status() .error_for_status()
.map_err(ErrorKind::RequestError)?; .map_err(ErrorKind::RequestError)?;
let cache_policy = CachePolicy::new(&RequestLikeReqwest(&req), &ResponseLikeReqwest(&res)); let cache_policy = CachePolicyStub(Some(CachePolicy::new(
&RequestLikeReqwest(&req),
&ResponseLikeReqwest(&res),
)));
Ok(CachedResponse::ModifiedOrNew( Ok(CachedResponse::ModifiedOrNew(
res, res,
cache_policy.is_storable().then(|| Box::new(cache_policy)), cache_policy.is_storable().then(|| Box::new(cache_policy)),

View File

@ -1,4 +1,6 @@
pub use cached_client::{CacheControl, CachedClient, CachedClientError, DataWithCachePolicy}; pub use cached_client::{
CacheControl, Cacheable, CachedClient, CachedClientError, DataWithCachePolicy, SerdeCacheable,
};
pub use error::{Error, ErrorKind}; pub use error::{Error, ErrorKind};
pub use flat_index::{FlatDistributions, FlatIndex, FlatIndexClient, FlatIndexError}; pub use flat_index::{FlatDistributions, FlatIndex, FlatIndexClient, FlatIndexError};
pub use registry_client::{ pub use registry_client::{

View File

@ -209,7 +209,7 @@ impl RegistryClient {
.map_err(|err| Error::from_json_err(err, url.clone()))?; .map_err(|err| Error::from_json_err(err, url.clone()))?;
let metadata = let metadata =
SimpleMetadata::from_files(data.files, package_name, url.as_str()); SimpleMetadata::from_files(data.files, package_name, url.as_str());
Ok(metadata) Ok(metadata.to_archive().unwrap())
} }
MediaType::Html => { MediaType::Html => {
let text = response.text().await.map_err(ErrorKind::RequestError)?; let text = response.text().await.map_err(ErrorKind::RequestError)?;
@ -217,7 +217,7 @@ impl RegistryClient {
.map_err(|err| Error::from_html_err(err, url.clone()))?; .map_err(|err| Error::from_html_err(err, url.clone()))?;
let metadata = let metadata =
SimpleMetadata::from_files(files, package_name, base.as_url().as_str()); SimpleMetadata::from_files(files, package_name, base.as_url().as_str());
Ok(metadata) Ok(metadata.to_archive().unwrap())
} }
} }
} }
@ -226,14 +226,14 @@ impl RegistryClient {
}; };
let result = self let result = self
.client .client
.get_cached_with_callback( .get_cached_with_callback2(
simple_request, simple_request,
&cache_entry, &cache_entry,
cache_control, cache_control,
parse_simple_response, parse_simple_response,
) )
.await; .await;
Ok(result.map(|simple| simple.to_archive().unwrap())) Ok(result)
} }
/// Fetch the metadata for a remote wheel file. /// Fetch the metadata for a remote wheel file.

View File

@ -942,9 +942,11 @@ impl<'a, T: BuildContext> SourceDistCachedBuilder<'a, T> {
/// Read an existing HTTP-cached [`Manifest`], if it exists. /// Read an existing HTTP-cached [`Manifest`], if it exists.
pub(crate) fn read_http_manifest(cache_entry: &CacheEntry) -> Result<Option<Manifest>, Error> { pub(crate) fn read_http_manifest(cache_entry: &CacheEntry) -> Result<Option<Manifest>, Error> {
match std::fs::read(cache_entry.path()) { match std::fs::read(cache_entry.path()) {
Ok(cached) => Ok(Some( Ok(cached) => {
rmp_serde::from_slice::<DataWithCachePolicy<Manifest>>(&cached)?.data, let raw = rmp_serde::from_slice::<DataWithCachePolicy>(&cached)?;
)), let manifest = rmp_serde::from_slice::<Manifest>(&raw.data)?;
Ok(Some(manifest))
}
Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(None), Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(None),
Err(err) => Err(Error::CacheRead(err)), Err(err) => Err(Error::CacheRead(err)),
} }