uv/crates/uv-client/src/cached_client.rs

985 lines
38 KiB
Rust

use std::time::{Duration, SystemTime};
use std::{borrow::Cow, path::Path};
use futures::FutureExt;
use reqwest::{Request, Response};
use reqwest_retry::RetryPolicy;
use rkyv::util::AlignedVec;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use tracing::{Instrument, debug, info_span, instrument, trace, warn};
use uv_cache::{CacheEntry, Freshness};
use uv_fs::write_atomic;
use uv_redacted::DisplaySafeUrl;
use crate::BaseClient;
use crate::base_client::is_extended_transient_error;
use crate::{
Error, ErrorKind,
httpcache::{AfterResponse, BeforeRequest, CachePolicy, CachePolicyBuilder},
rkyvutil::OwnedArchive,
};
/// A trait that generalizes (de)serialization at a high level.
///
/// The main purpose of this trait is to make the `CachedClient` work for
/// either serde or other mechanisms of serialization such as `rkyv`.
///
/// If you're using Serde, then unless you want to control the format, callers
/// should just use `CachedClient::get_serde`. This will use a default
/// implementation of `Cacheable` internally.
///
/// Alternatively, callers using `rkyv` should use
/// `CachedClient::get_cacheable`. If your types fit into the
/// `rkyvutil::OwnedArchive` mold, then an implementation of `Cacheable` is
/// already provided for that type.
pub trait Cacheable: Sized {
    /// This associated type permits customizing what the "output" type of
    /// deserialization is. It can be identical to `Self`.
    ///
    /// Typical use of this is for wrapper types used to provide blanket trait
    /// impls without hitting overlapping impl problems.
    type Target: Send + 'static;

    /// Deserialize a value from bytes aligned to a 16-byte boundary.
    ///
    /// The bytes are the data portion of a cache entry, i.e. whatever
    /// [`Cacheable::to_bytes`] produced when the entry was written.
    ///
    /// # Errors
    ///
    /// Returns an error if `bytes` cannot be decoded into `Self::Target`.
    fn from_aligned_bytes(bytes: AlignedVec) -> Result<Self::Target, Error>;

    /// Serialize this value to a possibly owned byte buffer.
    ///
    /// The result is what gets stored in the cache entry and later handed
    /// back to [`Cacheable::from_aligned_bytes`].
    fn to_bytes(&self) -> Result<Cow<'_, [u8]>, Error>;

    /// Convert this type into its final form.
    fn into_target(self) -> Self::Target;
}
/// Adapter that lets any Serde-compatible `T` flow through the `Cacheable`
/// machinery.
///
/// Cached bytes are encoded with MessagePack (`rmp_serde`). Because of
/// `#[serde(transparent)]`, the wrapper (de)serializes exactly like `T`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(transparent)]
pub(crate) struct SerdeCacheable<T> {
    inner: T,
}

impl<T: Serialize + DeserializeOwned + Send + 'static> Cacheable for SerdeCacheable<T> {
    type Target = T;

    fn from_aligned_bytes(bytes: AlignedVec) -> Result<T, Error> {
        let decoded = rmp_serde::from_slice::<T>(&bytes).map_err(ErrorKind::Decode)?;
        Ok(decoded)
    }

    fn to_bytes(&self) -> Result<Cow<'_, [u8]>, Error> {
        let encoded = rmp_serde::to_vec(&self.inner).map_err(ErrorKind::Encode)?;
        Ok(Cow::Owned(encoded))
    }

    fn into_target(self) -> Self::Target {
        let Self { inner } = self;
        inner
    }
}
/// All `OwnedArchive` values are cacheable.
impl<A> Cacheable for OwnedArchive<A>
where
A: rkyv::Archive + for<'a> rkyv::Serialize<crate::rkyvutil::Serializer<'a>> + Send + 'static,
A::Archived: rkyv::Portable
+ rkyv::Deserialize<A, crate::rkyvutil::Deserializer>
+ for<'a> rkyv::bytecheck::CheckBytes<crate::rkyvutil::Validator<'a>>,
{
type Target = Self;
fn from_aligned_bytes(bytes: AlignedVec) -> Result<Self, Error> {
Self::new(bytes)
}
fn to_bytes(&self) -> Result<Cow<'_, [u8]>, Error> {
Ok(Cow::from(Self::as_bytes(self)))
}
fn into_target(self) -> Self::Target {
self
}
}
/// Dispatch type: Either a cached client error or a (user specified) error from the callback
///
/// Both variants carry the number of retries performed so far, when known, so
/// the final [`Error`] can report it.
pub enum CachedClientError<CallbackError: std::error::Error + 'static> {
    /// The failure originated in the client itself (request, cache I/O, ...).
    Client {
        /// Retries performed before giving up, if any were recorded.
        retries: Option<u32>,
        /// The underlying client error.
        err: Error,
    },
    /// The failure was returned by the caller-supplied response callback.
    Callback {
        /// Retries performed before giving up, if any were recorded.
        retries: Option<u32>,
        /// The error produced by the callback.
        err: CallbackError,
    },
}
impl<CallbackError: std::error::Error + 'static> CachedClientError<CallbackError> {
/// Attach the number of retries to the error context.
///
/// Adds to existing errors if any, in case different layers retried.
fn with_retries(self, retries: u32) -> Self {
match self {
CachedClientError::Client {
retries: existing_retries,
err,
} => CachedClientError::Client {
retries: Some(existing_retries.unwrap_or_default() + retries),
err,
},
CachedClientError::Callback {
retries: existing_retries,
err,
} => CachedClientError::Callback {
retries: Some(existing_retries.unwrap_or_default() + retries),
err,
},
}
}
fn retries(&self) -> Option<u32> {
match self {
CachedClientError::Client { retries, .. } => *retries,
CachedClientError::Callback { retries, .. } => *retries,
}
}
fn error(&self) -> &dyn std::error::Error {
match self {
CachedClientError::Client { err, .. } => err,
CachedClientError::Callback { err, .. } => err,
}
}
}
/// Lift a plain client [`Error`] into the dispatch type, with no retry count.
impl<CallbackError: std::error::Error + 'static> From<Error> for CachedClientError<CallbackError> {
    fn from(error: Error) -> Self {
        Self::Client {
            err: error,
            retries: None,
        }
    }
}
impl<CallbackError: std::error::Error + 'static> From<ErrorKind>
for CachedClientError<CallbackError>
{
fn from(error: ErrorKind) -> Self {
Self::Client {
retries: None,
err: error.into(),
}
}
}
impl<E: Into<Self> + std::error::Error + 'static> From<CachedClientError<E>> for Error {
/// Attach retry error context, if there were retries.
fn from(error: CachedClientError<E>) -> Self {
match error {
CachedClientError::Client {
retries: Some(retries),
err,
} => ErrorKind::RequestWithRetries {
source: Box::new(err.into_kind()),
retries,
}
.into(),
CachedClientError::Client { retries: None, err } => err,
CachedClientError::Callback {
retries: Some(retries),
err,
} => ErrorKind::RequestWithRetries {
source: Box::new(err.into().into_kind()),
retries,
}
.into(),
CachedClientError::Callback { retries: None, err } => err.into(),
}
}
}
/// How a cached request should treat an existing cache entry.
#[derive(Debug, Clone, Copy)]
pub enum CacheControl {
    /// Respect the `cache-control` header from the response.
    None,
    /// Send `Cache-Control: no-cache` with the request, so that a cached entry
    /// is revalidated with the server rather than served as fresh.
    MustRevalidate,
    /// Allow the client to return stale responses.
    ///
    /// A cached entry that the cache policy considers stale is returned as-is,
    /// without sending a revalidation request.
    AllowStale,
}
/// Translate a cache bucket's freshness into request-level cache control.
///
/// Only a stale entry forces revalidation; fresh and missing entries defer to
/// the response's own `cache-control` header.
impl From<Freshness> for CacheControl {
    fn from(value: Freshness) -> Self {
        match value {
            Freshness::Stale => Self::MustRevalidate,
            Freshness::Fresh | Freshness::Missing => Self::None,
        }
    }
}
/// A caching layer on top of [`reqwest::Client`].
///
/// Modeled after the `http-cache` crate, with one key addition: an async
/// callback can post-process each response before it is cached. We use this,
/// for example, to store a parsed version of the wheel metadata, and for our
/// remote zip reader. The remote zip reader reads a single file out of a remote
/// zip via range requests, so the entire file doesn't have to be downloaded.
/// The caching layer sends a HEAD request to check whether the remote file has
/// changed (and whether range requests are supported). The callback then
/// performs the actual range requests if required.
///
/// Also unlike `http-cache`:
/// - every output must be serializable/deserializable in some way, by
///   implementing the `Cacheable` trait;
/// - the caller has full control over the cache key, which is assumed to be
///   a file.
#[derive(Debug, Clone)]
pub struct CachedClient(BaseClient);
impl CachedClient {
    /// Wrap a [`BaseClient`] in the caching layer.
    pub fn new(client: BaseClient) -> Self {
        Self(client)
    }

    /// The underlying [`BaseClient`] without caching.
    pub fn uncached(&self) -> &BaseClient {
        &self.0
    }

    /// Make a cached request with a custom response transformation
    /// while using serde to (de)serialize cached responses.
    ///
    /// If a new response was received (no prior cached response or modified
    /// on the remote), the response is passed through `response_callback` and
    /// only the result is cached and returned. The `response_callback` is
    /// allowed to make subsequent requests, e.g. through the uncached client.
    ///
    /// Implemented on top of [`CachedClient::get_cacheable`] by wrapping the
    /// callback's output in `SerdeCacheable`.
    #[instrument(skip_all)]
    pub async fn get_serde<
        Payload: Serialize + DeserializeOwned + Send + 'static,
        CallBackError: std::error::Error + 'static,
        Callback: AsyncFn(Response) -> Result<Payload, CallBackError>,
    >(
        &self,
        req: Request,
        cache_entry: &CacheEntry,
        cache_control: CacheControl,
        response_callback: Callback,
    ) -> Result<Payload, CachedClientError<CallBackError>> {
        let payload = self
            .get_cacheable(req, cache_entry, cache_control, async |resp| {
                let payload = response_callback(resp).await?;
                Ok(SerdeCacheable { inner: payload })
            })
            .await?;
        Ok(payload)
    }

    /// Make a cached request with a custom response transformation while using
    /// the `Cacheable` trait to (de)serialize cached responses.
    ///
    /// The purpose of this routine is the use of `Cacheable`. Namely, it
    /// generalizes over (de)serialization such that mechanisms other than
    /// serde (such as rkyv) can be used to manage (de)serialization of cached
    /// data.
    ///
    /// If a new response was received (no prior cached response or modified
    /// on the remote), the response is passed through `response_callback` and
    /// only the result is cached and returned. The `response_callback` is
    /// allowed to make subsequent requests, e.g. through the uncached client.
    ///
    /// If the cached payload cannot be decoded, the entry is deleted and the
    /// request is re-sent from scratch (see `resend_and_heal_cache`).
    #[instrument(skip_all)]
    pub async fn get_cacheable<
        Payload: Cacheable,
        CallBackError: std::error::Error + 'static,
        Callback: AsyncFn(Response) -> Result<Payload, CallBackError>,
    >(
        &self,
        req: Request,
        cache_entry: &CacheEntry,
        cache_control: CacheControl,
        response_callback: Callback,
    ) -> Result<Payload::Target, CachedClientError<CallBackError>> {
        // Keep an untouched copy of the request: `send_cached` may add headers
        // to `req`, and a broken cache entry requires re-sending from scratch.
        let fresh_req = req.try_clone().expect("HTTP request must be cloneable");
        let cached_response = if let Some(cached) = Self::read_cache(cache_entry).await {
            self.send_cached(req, cache_control, cached)
                .boxed_local()
                .await?
        } else {
            debug!("No cache entry for: {}", req.url());
            let (response, cache_policy) = self.fresh_request(req).await?;
            CachedResponse::ModifiedOrNew {
                response,
                cache_policy,
            }
        };
        match cached_response {
            // Also reached for stale entries when `CacheControl::AllowStale` is set.
            CachedResponse::FreshCache(cached) => match Payload::from_aligned_bytes(cached.data) {
                Ok(payload) => Ok(payload),
                Err(err) => {
                    warn!(
                        "Broken fresh cache entry (for payload) at {}, removing: {err}",
                        cache_entry.path().display()
                    );
                    self.resend_and_heal_cache(fresh_req, cache_entry, response_callback)
                        .await
                }
            },
            CachedResponse::NotModified { cached, new_policy } => {
                let refresh_cache =
                    info_span!("refresh_cache", file = %cache_entry.path().display());
                async {
                    // Persist the updated policy. The whole entry (unchanged
                    // data + new policy) is rewritten atomically.
                    let data_with_cache_policy_bytes =
                        DataWithCachePolicy::serialize(&new_policy, &cached.data)?;
                    write_atomic(cache_entry.path(), data_with_cache_policy_bytes)
                        .await
                        .map_err(ErrorKind::CacheWrite)?;
                    match Payload::from_aligned_bytes(cached.data) {
                        Ok(payload) => Ok(payload),
                        Err(err) => {
                            warn!(
                                "Broken fresh cache entry after revalidation \
(for payload) at {}, removing: {err}",
                                cache_entry.path().display()
                            );
                            self.resend_and_heal_cache(fresh_req, cache_entry, response_callback)
                                .await
                        }
                    }
                }
                .instrument(refresh_cache)
                .await
            }
            CachedResponse::ModifiedOrNew {
                response,
                cache_policy,
            } => {
                // If we got a modified response, but it's a 304, then a validator failed (e.g., the
                // ETag didn't match). We need to make a fresh request.
                if response.status() == http::StatusCode::NOT_MODIFIED {
                    warn!("Server returned unusable 304 for: {}", fresh_req.url());
                    self.resend_and_heal_cache(fresh_req, cache_entry, response_callback)
                        .await
                } else {
                    self.run_response_callback(
                        cache_entry,
                        cache_policy,
                        response,
                        response_callback,
                    )
                    .await
                }
            }
        }
    }

    /// Make a request without checking whether the cache is fresh.
    ///
    /// The existing entry is never read. The response is passed through
    /// `response_callback`. The result is written to `cache_entry` only when
    /// `fresh_request` reports a storable cache policy.
    pub async fn skip_cache<
        Payload: Serialize + DeserializeOwned + Send + 'static,
        CallBackError: std::error::Error + 'static,
        Callback: AsyncFnOnce(Response) -> Result<Payload, CallBackError>,
    >(
        &self,
        req: Request,
        cache_entry: &CacheEntry,
        response_callback: Callback,
    ) -> Result<Payload, CachedClientError<CallBackError>> {
        let (response, cache_policy) = self.fresh_request(req).await?;
        let payload = self
            .run_response_callback(cache_entry, cache_policy, response, async |resp| {
                let payload = response_callback(resp).await?;
                Ok(SerdeCacheable { inner: payload })
            })
            .await?;
        Ok(payload)
    }

    /// Delete a broken cache entry, then re-send `req` and process the response
    /// as if nothing had been cached.
    ///
    /// Removing the old file is best-effort; its error is deliberately ignored.
    async fn resend_and_heal_cache<
        Payload: Cacheable,
        CallBackError: std::error::Error + 'static,
        Callback: AsyncFnOnce(Response) -> Result<Payload, CallBackError>,
    >(
        &self,
        req: Request,
        cache_entry: &CacheEntry,
        response_callback: Callback,
    ) -> Result<Payload::Target, CachedClientError<CallBackError>> {
        let _ = fs_err::tokio::remove_file(&cache_entry.path()).await;
        let (response, cache_policy) = self.fresh_request(req).await?;
        self.run_response_callback(cache_entry, cache_policy, response, response_callback)
            .await
    }

    /// Run `response_callback` on a newly received response.
    ///
    /// When `cache_policy` is `Some`, the callback's output is serialized
    /// together with the policy and written atomically to `cache_entry`. The
    /// cache directory is created if needed. When it is `None`, nothing is
    /// written.
    ///
    /// Callback failures are returned as `CachedClientError::Callback` without
    /// a retry count.
    async fn run_response_callback<
        Payload: Cacheable,
        CallBackError: std::error::Error + 'static,
        Callback: AsyncFnOnce(Response) -> Result<Payload, CallBackError>,
    >(
        &self,
        cache_entry: &CacheEntry,
        cache_policy: Option<Box<CachePolicy>>,
        response: Response,
        response_callback: Callback,
    ) -> Result<Payload::Target, CachedClientError<CallBackError>> {
        let new_cache = info_span!("new_cache", file = %cache_entry.path().display());
        let data = response_callback(response)
            .boxed_local()
            .await
            .map_err(|err| CachedClientError::Callback { retries: None, err })?;
        // Not storable: hand the result back without touching the cache.
        let Some(cache_policy) = cache_policy else {
            return Ok(data.into_target());
        };
        async {
            fs_err::tokio::create_dir_all(cache_entry.dir())
                .await
                .map_err(ErrorKind::CacheWrite)?;
            let data_with_cache_policy_bytes =
                DataWithCachePolicy::serialize(&cache_policy, &data.to_bytes()?)?;
            write_atomic(cache_entry.path(), data_with_cache_policy_bytes)
                .await
                .map_err(ErrorKind::CacheWrite)?;
            Ok(data.into_target())
        }
        .instrument(new_cache)
        .await
    }

    /// Read and parse the cache entry.
    ///
    /// Returns `None` if the entry is missing or cannot be parsed. A missing
    /// file is only traced. Any other failure is logged as a warning and the
    /// file is removed (best-effort), so the caller falls back to a fresh
    /// request.
    #[instrument(name="read_and_parse_cache", skip_all, fields(file = %cache_entry.path().display()))]
    async fn read_cache(cache_entry: &CacheEntry) -> Option<DataWithCachePolicy> {
        match DataWithCachePolicy::from_path_async(cache_entry.path()).await {
            Ok(data) => Some(data),
            Err(err) => {
                // When we know the cache entry doesn't exist, then things are
                // normal and we shouldn't emit a WARN.
                if err.is_file_not_exists() {
                    trace!("No cache entry exists for {}", cache_entry.path().display());
                } else {
                    warn!(
                        "Broken cache policy entry at {}, removing: {err}",
                        cache_entry.path().display()
                    );
                    let _ = fs_err::tokio::remove_file(&cache_entry.path()).await;
                }
                None
            }
        }
    }

    /// Send a request given that we have a (possibly) stale cached response.
    ///
    /// If the cached response is valid but stale, then this will attempt a
    /// revalidation request.
    ///
    /// `CacheControl::AllowStale` returns a stale entry without revalidating.
    /// If the stored policy does not match the request, a fresh request is
    /// made instead.
    async fn send_cached(
        &self,
        mut req: Request,
        cache_control: CacheControl,
        cached: DataWithCachePolicy,
    ) -> Result<CachedResponse, Error> {
        // Apply the cache control header, if necessary.
        match cache_control {
            CacheControl::None | CacheControl::AllowStale => {}
            CacheControl::MustRevalidate => {
                // `MustRevalidate` is expressed on the wire as `no-cache`.
                req.headers_mut().insert(
                    http::header::CACHE_CONTROL,
                    http::HeaderValue::from_static("no-cache"),
                );
            }
        }
        Ok(match cached.cache_policy.before_request(&mut req) {
            BeforeRequest::Fresh => {
                debug!("Found fresh response for: {}", req.url());
                CachedResponse::FreshCache(cached)
            }
            BeforeRequest::Stale(new_cache_policy_builder) => match cache_control {
                CacheControl::None | CacheControl::MustRevalidate => {
                    debug!("Found stale response for: {}", req.url());
                    self.send_cached_handle_stale(req, cached, new_cache_policy_builder)
                        .await?
                }
                CacheControl::AllowStale => {
                    debug!("Found stale (but allowed) response for: {}", req.url());
                    CachedResponse::FreshCache(cached)
                }
            },
            BeforeRequest::NoMatch => {
                // This shouldn't happen; if it does, we'll override the cache.
                warn!(
                    "Cached request doesn't match current request for: {}",
                    req.url()
                );
                let (response, cache_policy) = self.fresh_request(req).await?;
                CachedResponse::ModifiedOrNew {
                    response,
                    cache_policy,
                }
            }
        })
    }

    /// Send a revalidation request for a stale cached response.
    ///
    /// The result is either `NotModified`, meaning the cached data is reused
    /// with an updated policy, or `ModifiedOrNew`, which carries the new
    /// response plus its policy. That policy is `None` if it is not storable.
    ///
    /// NOTE(review): unlike `fresh_request`, HTTP error statuses here are
    /// not annotated with the middleware's retry count.
    async fn send_cached_handle_stale(
        &self,
        req: Request,
        cached: DataWithCachePolicy,
        new_cache_policy_builder: CachePolicyBuilder,
    ) -> Result<CachedResponse, Error> {
        let url = DisplaySafeUrl::from(req.url().clone());
        debug!("Sending revalidation request for: {url}");
        let response = self
            .0
            .execute(req)
            .instrument(info_span!("revalidation_request", url = url.as_str()))
            .await
            .map_err(|err| ErrorKind::from_reqwest_middleware(url.clone(), err))?
            .error_for_status()
            .map_err(|err| ErrorKind::from_reqwest(url.clone(), err))?;
        match cached
            .cache_policy
            .after_response(new_cache_policy_builder, &response)
        {
            AfterResponse::NotModified(new_policy) => {
                debug!("Found not-modified response for: {url}");
                Ok(CachedResponse::NotModified {
                    cached,
                    new_policy: Box::new(new_policy),
                })
            }
            AfterResponse::Modified(new_policy) => {
                debug!("Found modified response for: {url}");
                Ok(CachedResponse::ModifiedOrNew {
                    response,
                    cache_policy: new_policy
                        .to_archived()
                        .is_storable()
                        .then(|| Box::new(new_policy)),
                })
            }
        }
    }

    /// Send `req` without consulting the cache.
    ///
    /// Returns the response together with its cache policy, which is `None`
    /// when the response is not storable. HTTP error statuses are turned into
    /// errors. When the retry middleware recorded a `RetryCount`, it is
    /// attached (via the `CachedClientError` → `Error` conversion).
    #[instrument(skip_all, fields(url = req.url().as_str()))]
    async fn fresh_request(
        &self,
        req: Request,
    ) -> Result<(Response, Option<Box<CachePolicy>>), Error> {
        let url = DisplaySafeUrl::from(req.url().clone());
        trace!("Sending fresh {} request for {}", req.method(), url);
        // The policy builder needs the request, so capture it before sending.
        let cache_policy_builder = CachePolicyBuilder::new(&req);
        let response = self
            .0
            .execute(req)
            .await
            .map_err(|err| ErrorKind::from_reqwest_middleware(url.clone(), err))?;
        let retry_count = response
            .extensions()
            .get::<reqwest_retry::RetryCount>()
            .map(|retries| retries.value());
        if let Err(status_error) = response.error_for_status_ref() {
            return Err(CachedClientError::<Error>::Client {
                retries: retry_count,
                err: ErrorKind::from_reqwest(url, status_error).into(),
            }
            .into());
        }
        let cache_policy = cache_policy_builder.build(&response);
        let cache_policy = if cache_policy.to_archived().is_storable() {
            Some(Box::new(cache_policy))
        } else {
            None
        };
        Ok((response, cache_policy))
    }

    /// Perform a [`CachedClient::get_serde`] request with a default retry strategy.
    #[instrument(skip_all)]
    pub async fn get_serde_with_retry<
        Payload: Serialize + DeserializeOwned + Send + 'static,
        CallBackError: std::error::Error + 'static,
        Callback: AsyncFn(Response) -> Result<Payload, CallBackError>,
    >(
        &self,
        req: Request,
        cache_entry: &CacheEntry,
        cache_control: CacheControl,
        response_callback: Callback,
    ) -> Result<Payload, CachedClientError<CallBackError>> {
        let payload = self
            .get_cacheable_with_retry(req, cache_entry, cache_control, async |resp| {
                let payload = response_callback(resp).await?;
                Ok(SerdeCacheable { inner: payload })
            })
            .await?;
        Ok(payload)
    }

    /// Perform a [`CachedClient::get_cacheable`] request with a default retry strategy.
    ///
    /// Only errors classified by `is_extended_transient_error` are retried, as
    /// long as the client's retry policy allows it. On failure after at least
    /// one retry here, the number of retries is added to the error.
    ///
    /// See: <https://github.com/TrueLayer/reqwest-middleware/blob/8a494c165734e24c62823714843e1c9347027e8a/reqwest-retry/src/middleware.rs#L137>
    #[instrument(skip_all)]
    pub async fn get_cacheable_with_retry<
        Payload: Cacheable,
        CallBackError: std::error::Error + 'static,
        Callback: AsyncFn(Response) -> Result<Payload, CallBackError>,
    >(
        &self,
        req: Request,
        cache_entry: &CacheEntry,
        cache_control: CacheControl,
        response_callback: Callback,
    ) -> Result<Payload::Target, CachedClientError<CallBackError>> {
        let mut past_retries = 0;
        let start_time = SystemTime::now();
        let retry_policy = self.uncached().retry_policy();
        loop {
            let fresh_req = req.try_clone().expect("HTTP request must be cloneable");
            let result = self
                .get_cacheable(fresh_req, cache_entry, cache_control, &response_callback)
                .await;
            // Check if the middleware already performed retries
            let middleware_retries = match &result {
                Err(err) => err.retries().unwrap_or_default(),
                Ok(_) => 0,
            };
            if result
                .as_ref()
                .is_err_and(|err| is_extended_transient_error(err.error()))
            {
                // If middleware already retried, consider that in our retry budget
                let total_retries = past_retries + middleware_retries;
                let retry_decision = retry_policy.should_retry(start_time, total_retries);
                if let reqwest_retry::RetryDecision::Retry { execute_after } = retry_decision {
                    debug!(
                        "Transient failure while handling response from {}; retrying...",
                        req.url(),
                    );
                    // A deadline already in the past means "retry immediately".
                    let duration = execute_after
                        .duration_since(SystemTime::now())
                        .unwrap_or_else(|_| Duration::default());
                    tokio::time::sleep(duration).await;
                    past_retries += 1;
                    continue;
                }
            }
            if past_retries > 0 {
                return result.map_err(|err| err.with_retries(past_retries));
            }
            return result;
        }
    }

    /// Perform a [`CachedClient::skip_cache`] request with a default retry strategy.
    ///
    /// Uses the same retry loop as [`CachedClient::get_cacheable_with_retry`].
    ///
    /// See: <https://github.com/TrueLayer/reqwest-middleware/blob/8a494c165734e24c62823714843e1c9347027e8a/reqwest-retry/src/middleware.rs#L137>
    pub async fn skip_cache_with_retry<
        Payload: Serialize + DeserializeOwned + Send + 'static,
        CallBackError: std::error::Error + 'static,
        Callback: AsyncFn(Response) -> Result<Payload, CallBackError>,
    >(
        &self,
        req: Request,
        cache_entry: &CacheEntry,
        response_callback: Callback,
    ) -> Result<Payload, CachedClientError<CallBackError>> {
        let mut past_retries = 0;
        let start_time = SystemTime::now();
        let retry_policy = self.uncached().retry_policy();
        loop {
            let fresh_req = req.try_clone().expect("HTTP request must be cloneable");
            let result = self
                .skip_cache(fresh_req, cache_entry, &response_callback)
                .await;
            // Check if the middleware already performed retries
            let middleware_retries = match &result {
                Err(err) => err.retries().unwrap_or_default(),
                _ => 0,
            };
            if result
                .as_ref()
                .err()
                .is_some_and(|err| is_extended_transient_error(err.error()))
            {
                // Middleware retries count against the same retry budget.
                let total_retries = past_retries + middleware_retries;
                let retry_decision = retry_policy.should_retry(start_time, total_retries);
                if let reqwest_retry::RetryDecision::Retry { execute_after } = retry_decision {
                    debug!(
                        "Transient failure while handling response from {}; retrying...",
                        req.url(),
                    );
                    let duration = execute_after
                        .duration_since(SystemTime::now())
                        .unwrap_or_else(|_| Duration::default());
                    tokio::time::sleep(duration).await;
                    past_retries += 1;
                    continue;
                }
            }
            if past_retries > 0 {
                return result.map_err(|err| err.with_retries(past_retries));
            }
            return result;
        }
    }
}
/// What consulting the cache (and possibly the network) produced for a request.
#[derive(Debug)]
enum CachedResponse {
    /// The cached response is usable without an HTTP request (e.g. age <
    /// max-age). Also used for stale entries accepted under
    /// `CacheControl::AllowStale`.
    FreshCache(DataWithCachePolicy),
    /// The cached response is fresh after an HTTP request (e.g. 304 not modified)
    NotModified {
        /// The cached response, still paired with its old cache policy.
        cached: DataWithCachePolicy,
        /// The [`CachePolicy`] derived from the revalidation response.
        ///
        /// It must replace the old policy stored with `cached`, because the
        /// revalidation response may have changed the caching parameters. It
        /// decides whether later requests for the same resource treat the
        /// response as fresh or stale.
        ///
        /// Boxed to keep this enum small on the stack: the policy is large
        /// (352 bytes at time of writing).
        new_policy: Box<CachePolicy>,
    },
    /// Nothing was cached yet, or the cached response was outdated.
    ModifiedOrNew {
        /// The response received from the server.
        response: Response,
        /// The [`CachePolicy`] that decides whether later requests for the same
        /// resource treat this response as fresh or stale. `None` when the
        /// response is not storable.
        ///
        /// Boxed to keep this enum small on the stack: the policy is large
        /// (352 bytes at time of writing).
        cache_policy: Option<Box<CachePolicy>>,
    },
}
/// An arbitrary data blob paired with the HTTP cache policy that governs it.
///
/// The cache policy is what decides whether the data blob is stale.
///
/// # Format
///
/// This type defines how blobs of data are laid out on disk. The layout is
/// deliberately simple:
///
/// 1. the data blob, written as-is;
/// 2. the archived representation of a `CachePolicy`;
/// 3. the length, in bytes, of that archived `CachePolicy`, as a 64-bit
///    little-endian integer.
///
/// Reading goes through an `AlignedVec` so that `rkyv` can correctly read the
/// archived representation of the data blob. The cache policy is split off
/// into an `AlignedVec` allocation of its own.
///
/// # Future ideas
///
/// In theory this layout permits rewriting the cache policy without rewriting
/// the data blob when the blob has not changed. For example, that happens when
/// a revalidation request responds with HTTP 304 NOT MODIFIED. At time of
/// writing this is not implemented, because 1) the synchronization specifics
/// of mutating a cache file have not been worked out and 2) it's not clear it
/// would be a win.
///
/// An alternative is to store the cache policy and the blob in two distinct
/// files. That avoids worrying about synchronization, but every cached
/// response in the fast path would then require reading two files instead of
/// one. Whether that is worth it is unclear; experiments have not yet been
/// done.
///
/// Another approach is to memory map the file and rework `OwnedArchive` (or
/// create a new type) to operate on a memory map instead of an `AlignedVec`.
/// This requires care to ensure alignment is handled correctly, and it has
/// not been litigated yet. It was not the starting point because experiments
/// with ripgrep have tended to show that (on Linux) memory mapping a bunch of
/// small files ends up being quite a bit slower than just reading them onto
/// the heap.
#[derive(Debug)]
pub struct DataWithCachePolicy {
    /// The data blob, with the trailing cache policy and its length removed.
    pub data: AlignedVec,
    cache_policy: OwnedArchive<CachePolicy>,
}
impl DataWithCachePolicy {
    /// Loads cached data and its associated HTTP cache policy from the given
    /// file path in an asynchronous fashion (via `spawn_blocking`).
    ///
    /// # Errors
    ///
    /// If the given byte buffer is not in a valid format or if reading the
    /// file given fails, then this returns an error.
    async fn from_path_async(path: &Path) -> Result<Self, Error> {
        let path = path.to_path_buf();
        tokio::task::spawn_blocking(move || Self::from_path_sync(&path))
            .await
            // This just forwards panics from the closure.
            .unwrap()
    }

    /// Loads cached data and its associated HTTP cache policy from the given
    /// file path in a synchronous fashion.
    ///
    /// # Errors
    ///
    /// If the given byte buffer is not in a valid format or if reading the
    /// file given fails, then this returns an error.
    #[instrument]
    fn from_path_sync(path: &Path) -> Result<Self, Error> {
        let file = fs_err::File::open(path).map_err(ErrorKind::Io)?;
        // Note that we don't wrap our file in a buffer because it will just
        // get passed to AlignedVec::extend_from_reader, which doesn't benefit
        // from an intermediary buffer. In effect, the AlignedVec acts as the
        // buffer.
        Self::from_reader(file)
    }

    /// Loads cached data and its associated HTTP cache policy from the given
    /// reader.
    ///
    /// # Errors
    ///
    /// If the given byte buffer is not in a valid format or if the reader
    /// fails, then this returns an error.
    pub fn from_reader(mut rdr: impl std::io::Read) -> Result<Self, Error> {
        let mut aligned_bytes = AlignedVec::new();
        aligned_bytes
            .extend_from_reader(&mut rdr)
            .map_err(ErrorKind::Io)?;
        Self::from_aligned_bytes(aligned_bytes)
    }

    /// Loads cached data and its associated HTTP cache policy from an
    /// in-memory byte buffer.
    ///
    /// # Errors
    ///
    /// If the given byte buffer is not in a valid format, then this
    /// returns an error.
    fn from_aligned_bytes(mut bytes: AlignedVec) -> Result<Self, Error> {
        let cache_policy = Self::deserialize_cache_policy(&mut bytes)?;
        Ok(Self {
            data: bytes,
            cache_policy,
        })
    }

    /// Serializes the given cache policy and arbitrary data blob to an in
    /// memory byte buffer.
    ///
    /// # Errors
    ///
    /// If there was a problem converting the given cache policy to its
    /// serialized representation, then this routine will return an error.
    fn serialize(cache_policy: &CachePolicy, data: &[u8]) -> Result<Vec<u8>, Error> {
        let mut buf = vec![];
        Self::serialize_to_writer(cache_policy, data, &mut buf)?;
        Ok(buf)
    }

    /// Serializes the given cache policy and arbitrary data blob to the given
    /// writer.
    ///
    /// The layout is: data blob, archived cache policy, then the archived
    /// policy's length as a little-endian `u64`.
    ///
    /// # Errors
    ///
    /// If there was a problem converting the given cache policy to its
    /// serialized representation or if the writer returns an error, then
    /// this routine will return an error.
    fn serialize_to_writer(
        cache_policy: &CachePolicy,
        data: &[u8],
        mut wtr: impl std::io::Write,
    ) -> Result<(), Error> {
        let cache_policy_archived = OwnedArchive::from_unarchived(cache_policy)?;
        let cache_policy_bytes = OwnedArchive::as_bytes(&cache_policy_archived);
        wtr.write_all(data).map_err(ErrorKind::Io)?;
        wtr.write_all(cache_policy_bytes).map_err(ErrorKind::Io)?;
        let len = u64::try_from(cache_policy_bytes.len()).map_err(|_| {
            let msg = format!(
                "failed to represent {} (length of cache policy) in a u64",
                cache_policy_bytes.len()
            );
            ErrorKind::Io(std::io::Error::other(msg))
        })?;
        wtr.write_all(&len.to_le_bytes()).map_err(ErrorKind::Io)?;
        Ok(())
    }

    /// Deserializes a `OwnedArchive<CachePolicy>` off the end of the given
    /// aligned bytes. Upon success, the given bytes will only contain the
    /// data itself. The bytes representing the cached policy will have been
    /// removed.
    ///
    /// # Errors
    ///
    /// This returns an error if the cache policy could not be deserialized
    /// from the end of the given bytes.
    fn deserialize_cache_policy(
        bytes: &mut AlignedVec,
    ) -> Result<OwnedArchive<CachePolicy>, Error> {
        let len = Self::deserialize_cache_policy_len(bytes)?;
        // `deserialize_cache_policy_len` guarantees `len + 8 <= bytes.len()`,
        // so neither this addition/subtraction nor the slicing below can
        // overflow or panic.
        let cache_policy_bytes_start = bytes.len() - (len + 8);
        let cache_policy_bytes = &bytes[cache_policy_bytes_start..][..len];
        let mut cache_policy_bytes_aligned = AlignedVec::with_capacity(len);
        cache_policy_bytes_aligned.extend_from_slice(cache_policy_bytes);
        assert!(
            cache_policy_bytes_start <= bytes.len(),
            "slicing cache policy should result in a truncation"
        );
        // Technically this will keep the extra capacity used to store the
        // cache policy around. But it should be pretty small, and it saves a
        // realloc. (It's unclear whether that matters more or less than the
        // extra memory usage.)
        bytes.resize(cache_policy_bytes_start, 0);
        OwnedArchive::new(cache_policy_bytes_aligned)
    }

    /// Deserializes the length, in bytes, of the cache policy given a complete
    /// serialized byte buffer of a `DataWithCachePolicy`.
    ///
    /// Upon success, callers are guaranteed that `len + 8 <= bytes.len()`, so
    /// `&bytes[bytes.len() - (len + 8)..][..len]` will not panic.
    ///
    /// # Errors
    ///
    /// This returns an error if the length could not be read as a `usize` or is
    /// otherwise known to be invalid. (For example, it is a length that is bigger
    /// than `bytes.len()`.) Corrupt lengths, including ones close to
    /// `usize::MAX`, are reported as errors rather than causing a panic.
    fn deserialize_cache_policy_len(bytes: &[u8]) -> Result<usize, Error> {
        let Some(cache_policy_len_start) = bytes.len().checked_sub(8) else {
            let msg = format!(
                "data-with-cache-policy buffer should be at least 8 bytes \
                 in length, but is {} bytes",
                bytes.len(),
            );
            return Err(ErrorKind::ArchiveRead(msg).into());
        };
        let cache_policy_len_bytes = <[u8; 8]>::try_from(&bytes[cache_policy_len_start..])
            .expect("cache policy length is 8 bytes");
        let len_u64 = u64::from_le_bytes(cache_policy_len_bytes);
        let Ok(len_usize) = usize::try_from(len_u64) else {
            let msg = format!(
                "data-with-cache-policy has cache policy length of {len_u64}, \
                 but overflows usize",
            );
            return Err(ErrorKind::ArchiveRead(msg).into());
        };
        // Equivalent to `bytes.len() < len_usize + 8`, since
        // `cache_policy_len_start == bytes.len() - 8`. Written this way, a
        // corrupt length near `usize::MAX` cannot overflow the addition. (In
        // debug builds that overflow panics. In release builds it wraps, and
        // the later slice would panic.)
        if len_usize > cache_policy_len_start {
            let msg = format!(
                "invalid cache entry: data-with-cache-policy has cache policy length of {}, \
                 but total buffer size is {}",
                len_usize,
                bytes.len(),
            );
            return Err(ErrorKind::ArchiveRead(msg).into());
        }
        Ok(len_usize)
    }
}