diff --git a/crates/uv-distribution-types/src/index.rs b/crates/uv-distribution-types/src/index.rs index 04614a18e..3853915fb 100644 --- a/crates/uv-distribution-types/src/index.rs +++ b/crates/uv-distribution-types/src/index.rs @@ -3,6 +3,7 @@ use std::str::FromStr; use serde::{Deserialize, Serialize}; use thiserror::Error; +use url::Url; use uv_auth::{AuthPolicy, Credentials}; use uv_redacted::DisplaySafeUrl; @@ -23,6 +24,30 @@ pub struct IndexCacheControl { pub files: Option, } +impl IndexCacheControl { + /// Return the default Simple API cache control headers for the given index URL, if applicable. + pub fn simple_api_cache_control(_url: &Url) -> Option<&'static str> { + None + } + + /// Return the default files cache control headers for the given index URL, if applicable. + pub fn artifact_cache_control(url: &Url) -> Option<&'static str> { + if url + .host_str() + .is_some_and(|host| host.ends_with("pytorch.org")) + { + // Some wheels in the PyTorch registry were accidentally uploaded with `no-cache,no-store,must-revalidate`. + // The PyTorch team plans to correct this in the future, but in the meantime we override + // the cache control headers to allow caching of static files. + // + // See: https://github.com/pytorch/pytorch/pull/149218 + Some("max-age=365000000, immutable, public") + } else { + None + } + } +} + #[derive(Debug, Clone, Hash, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] #[serde(rename_all = "kebab-case")] @@ -264,6 +289,32 @@ impl Index { IndexStatusCodeStrategy::from_index_url(self.url.url()) } } + + /// Return the cache control header for file requests to this index, if any. + pub fn artifact_cache_control(&self) -> Option<&str> { + if let Some(artifact_cache_control) = self + .cache_control + .as_ref() + .and_then(|cache_control| cache_control.files.as_deref()) + { + Some(artifact_cache_control) + } else { + IndexCacheControl::artifact_cache_control(self.url.url()) + } + } + + /// Return the cache control header for API requests to this index, if any. + pub fn simple_api_cache_control(&self) -> Option<&str> { + if let Some(api_cache_control) = self + .cache_control + .as_ref() + .and_then(|cache_control| cache_control.api.as_deref()) + { + Some(api_cache_control) + } else { + IndexCacheControl::simple_api_cache_control(self.url.url()) + } + } } impl From for Index { diff --git a/crates/uv-distribution-types/src/index_url.rs b/crates/uv-distribution-types/src/index_url.rs index a96e00f79..5ccd559fb 100644 --- a/crates/uv-distribution-types/src/index_url.rs +++ b/crates/uv-distribution-types/src/index_url.rs @@ -470,7 +470,7 @@ impl<'a> IndexLocations { pub fn simple_api_cache_control_for(&self, url: &IndexUrl) -> Option<&str> { for index in &self.indexes { if index.url() == url { - return index.cache_control.as_ref()?.api.as_deref(); + return index.simple_api_cache_control(); } } None @@ -480,7 +480,7 @@ impl<'a> IndexLocations { pub fn artifact_cache_control_for(&self, url: &IndexUrl) -> Option<&str> { for index in &self.indexes { if index.url() == url { - return index.cache_control.as_ref()?.files.as_deref(); + return index.artifact_cache_control(); } } None @@ -623,7 +623,7 @@ impl<'a> IndexUrls { pub fn simple_api_cache_control_for(&self, url: &IndexUrl) -> Option<&str> { for index in &self.indexes { if index.url() == url { - return index.cache_control.as_ref()?.api.as_deref(); + return index.simple_api_cache_control(); } } None @@ -633,7 +633,7 @@ impl<'a> IndexUrls { pub fn artifact_cache_control_for(&self, url: &IndexUrl) -> Option<&str> { for index in &self.indexes { if index.url() == url { - return index.cache_control.as_ref()?.files.as_deref(); + return index.artifact_cache_control(); } } None @@ -723,6 +723,8 @@ impl IndexCapabilities { #[cfg(test)] mod tests { use super::*; + use crate::{IndexCacheControl, IndexFormat, IndexName}; + use uv_small_str::SmallString; #[test] fn test_index_url_parse_valid_paths() { @@ -816,4 +818,88 @@ mod tests { assert_eq!(index_urls.simple_api_cache_control_for(&url3), None); assert_eq!(index_urls.artifact_cache_control_for(&url3), None); } + + #[test] + fn test_pytorch_default_cache_control() { + // Test that PyTorch indexes get default cache control from the getter methods + let indexes = vec![Index { + name: Some(IndexName::from_str("pytorch").unwrap()), + url: IndexUrl::from_str("https://download.pytorch.org/whl/cu118").unwrap(), + cache_control: None, // No explicit cache control + explicit: false, + default: false, + origin: None, + format: IndexFormat::Simple, + publish_url: None, + authenticate: uv_auth::AuthPolicy::default(), + ignore_error_codes: None, + }]; + + let index_urls = IndexUrls::from_indexes(indexes.clone()); + let index_locations = IndexLocations::new(indexes, Vec::new(), false); + + let pytorch_url = IndexUrl::from_str("https://download.pytorch.org/whl/cu118").unwrap(); + + // IndexUrls should return the default for PyTorch + assert_eq!(index_urls.simple_api_cache_control_for(&pytorch_url), None); + assert_eq!( + index_urls.artifact_cache_control_for(&pytorch_url), + Some("max-age=365000000, immutable, public") + ); + + // IndexLocations should also return the default for PyTorch + assert_eq!( + index_locations.simple_api_cache_control_for(&pytorch_url), + None + ); + assert_eq!( + index_locations.artifact_cache_control_for(&pytorch_url), + Some("max-age=365000000, immutable, public") + ); + } + + #[test] + fn test_pytorch_user_override_cache_control() { + // Test that user-specified cache control overrides PyTorch defaults + let indexes = vec![Index { + name: Some(IndexName::from_str("pytorch").unwrap()), + url: IndexUrl::from_str("https://download.pytorch.org/whl/cu118").unwrap(), + cache_control: Some(IndexCacheControl { + api: Some(SmallString::from("no-cache")), + files: Some(SmallString::from("max-age=3600")), + }), + explicit: false, + default: false, + origin: None, + format: IndexFormat::Simple, + publish_url: None, + authenticate: uv_auth::AuthPolicy::default(), + ignore_error_codes: None, + }]; + + let index_urls = IndexUrls::from_indexes(indexes.clone()); + let index_locations = IndexLocations::new(indexes, Vec::new(), false); + + let pytorch_url = IndexUrl::from_str("https://download.pytorch.org/whl/cu118").unwrap(); + + // User settings should override defaults + assert_eq!( + index_urls.simple_api_cache_control_for(&pytorch_url), + Some("no-cache") + ); + assert_eq!( + index_urls.artifact_cache_control_for(&pytorch_url), + Some("max-age=3600") + ); + + // Same for IndexLocations + assert_eq!( + index_locations.simple_api_cache_control_for(&pytorch_url), + Some("no-cache") + ); + assert_eq!( + index_locations.artifact_cache_control_for(&pytorch_url), + Some("max-age=3600") + ); + } }