This commit is contained in:
Assad Yousuf 2025-12-16 09:14:37 +01:00 committed by GitHub
commit ded2fa7b46
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 130 additions and 49 deletions

View File

@ -57,6 +57,7 @@ reqwest-retry = { workspace = true }
rkyv = { workspace = true }
rmp-serde = { workspace = true }
rustc-hash = { workspace = true }
rustls = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
sys-info = { workspace = true }

View File

@ -1020,6 +1020,15 @@ pub struct UvRetryableStrategy;
impl RetryableStrategy for UvRetryableStrategy {
fn handle(&self, res: &Result<Response, reqwest_middleware::Error>) -> Option<Retryable> {
if let Err(reqwest_middleware::Error::Reqwest(err)) = res
&& is_tls_request_error(err)
{
if let Some(url) = err.url() {
trace!("Cannot retry {url} due to TLS error: {err:?}");
}
return Some(Retryable::Fatal);
}
// Use the default strategy and check for additional transient error cases.
let retryable = match DefaultRetryableStrategy.handle(res) {
None | Some(Retryable::Fatal)
@ -1077,7 +1086,10 @@ pub fn is_transient_network_error(err: &(dyn Error + 'static)) -> bool {
if let Some(reqwest_err) = source.downcast_ref::<WrappedReqwestError>() {
has_known_error = true;
if let reqwest_middleware::Error::Reqwest(reqwest_err) = &**reqwest_err {
if default_on_request_error(reqwest_err) == Some(Retryable::Transient) {
if is_tls_request_error(reqwest_err) {
trace!("Cannot retry nested reqwest middleware TLS error");
return false;
} else if default_on_request_error(reqwest_err) == Some(Retryable::Transient) {
trace!("Retrying nested reqwest middleware error");
return true;
}
@ -1090,7 +1102,10 @@ pub fn is_transient_network_error(err: &(dyn Error + 'static)) -> bool {
trace!("Cannot retry nested reqwest middleware error");
} else if let Some(reqwest_err) = source.downcast_ref::<reqwest::Error>() {
has_known_error = true;
if default_on_request_error(reqwest_err) == Some(Retryable::Transient) {
if is_tls_request_error(reqwest_err) {
trace!("Cannot retry nested reqwest TLS error");
return false;
} else if default_on_request_error(reqwest_err) == Some(Retryable::Transient) {
trace!("Retrying nested reqwest error");
return true;
}
@ -1139,6 +1154,10 @@ pub fn is_transient_network_error(err: &(dyn Error + 'static)) -> bool {
false
}
fn is_tls_request_error(reqwest_err: &reqwest::Error) -> bool {
reqwest_err.is_connect() && find_source_with_io::<rustls::Error>(&reqwest_err).is_some()
}
/// Whether the error is a status code error that is retryable.
///
/// Port of `reqwest_retry::default_on_request_success`.
@ -1165,6 +1184,31 @@ fn find_source<E: Error + 'static>(orig: &dyn Error) -> Option<&E> {
None
}
/// Find the first source error of a specific type while also wrapped in `io::Error`.
///
/// Inspired by <https://github.com/seanmonstar/reqwest/issues/1602#issuecomment-1220996681>
/// See <https://github.com/hyperium/h2/issues/862>
pub fn find_source_with_io<E: Error + 'static>(orig: &dyn Error) -> Option<&E> {
let mut cause = orig.source();
while let Some(err) = cause {
if let Some(concrete_err) = err.downcast_ref() {
return Some(concrete_err);
}
// Walk io::Error in case get_ref wraps the real source
if let Some(io_err) = err.downcast_ref::<io::Error>() {
if let Some(inner_err) = io_err.get_ref() {
if let Some(concrete_err) = inner_err.downcast_ref() {
return Some(concrete_err);
}
cause = Some(inner_err);
continue;
}
}
cause = err.source();
}
None
}
// TODO(konsti): Remove once we find a native home for `retries_from_env`
#[derive(Debug, Error)]
pub enum RetryParsingError {

View File

@ -1,7 +1,7 @@
pub use base_client::{
AuthIntegration, BaseClient, BaseClientBuilder, DEFAULT_RETRIES, ExtraMiddleware,
RedirectClientWithMiddleware, RequestBuilder, RetryParsingError, UvRetryableStrategy,
is_transient_network_error,
find_source_with_io, is_transient_network_error,
};
pub use cached_client::{CacheControl, CachedClient, CachedClientError, DataWithCachePolicy};
pub use error::{Error, ErrorKind, WrappedReqwestError};

View File

@ -368,6 +368,18 @@ pub(crate) async fn start_https_user_agent_server(
.await
}
/// Single Request HTTPS server with a self-signed CA that echoes the User Agent Header.
pub(crate) async fn start_https_ca_user_agent_server(
ca_cert: &SelfSigned,
server_cert: &SelfSigned,
) -> Result<(JoinHandle<Result<()>>, SocketAddr)> {
TestServerBuilder::new()
.with_ca_cert(ca_cert)
.with_server_cert(server_cert)
.start()
.await
}
/// Single Request HTTPS mTLS server that echoes the User Agent Header.
pub(crate) async fn start_https_mtls_user_agent_server(
ca_cert: &SelfSigned,

View File

@ -1,3 +1,4 @@
use std::error::Error;
use std::str::FromStr;
use anyhow::Result;
@ -12,10 +13,57 @@ use uv_static::EnvVars;
use crate::http_util::{
generate_self_signed_certs, generate_self_signed_certs_with_ca,
start_https_mtls_user_agent_server, start_https_user_agent_server, test_cert_dir,
start_https_ca_user_agent_server, start_https_mtls_user_agent_server,
start_https_user_agent_server, test_cert_dir,
};
// SAFETY: This test is meant to run with single thread configuration
#[tokio::test]
async fn ssl_retry_once() -> Result<()> {
// Generate self-signed CA, server, and client certs
let (ca_cert, server_cert, _) = generate_self_signed_certs_with_ca()?;
let (server_task, addr) = start_https_ca_user_agent_server(&ca_cert, &server_cert).await?;
let url = DisplaySafeUrl::from_str(&format!("https://{addr}"))?;
let cache = Cache::temp()?.init()?;
let client = RegistryClientBuilder::new(BaseClientBuilder::default(), cache).build();
let res = client
.cached_client()
.uncached()
.for_host(&url)
.get(Url::from(url))
.send()
.await;
let _ = server_task.await?;
// Validate the client error
let Some(reqwest_middleware::Error::Middleware(middleware_error)) = res.err() else {
panic!("expected middleware error");
};
// No retries should occur (we can directly get the hyper rustls error)
// We're explicit with our chains to be sensitive to any dependency changes
let expected_err = if let Some(err) = middleware_error.source()
&& let Some(err) = err.downcast_ref::<hyper_util::client::legacy::Error>()
&& let Some(err) = err.source()
&& let Some(err) = err.downcast_ref::<std::io::Error>()
&& let Some(err) = err.get_ref()
&& let Some(err) = err.downcast_ref::<std::io::Error>()
&& let Some(err) = err.get_ref()
&& let Some(err) = err.downcast_ref::<rustls::Error>()
&& matches!(
err,
rustls::Error::InvalidCertificate(rustls::CertificateError::UnknownIssuer)
) {
true
} else {
false
};
assert!(expected_err);
Ok(())
}
// SAFETY: This test is meant to run in isolation
#[tokio::test]
#[allow(unsafe_code)]
async fn ssl_env_vars() -> Result<()> {
@ -97,23 +145,16 @@ async fn ssl_env_vars() -> Result<()> {
std::env::remove_var(EnvVars::SSL_CERT_FILE);
}
// Validate the client error
// Validate the client error - TLS errors return Fatal early so we get Middleware variant
let Some(reqwest_middleware::Error::Middleware(middleware_error)) = res.err() else {
panic!("expected middleware error");
};
let reqwest_error = middleware_error
.chain()
.find_map(|err| {
err.downcast_ref::<reqwest_middleware::Error>().map(|err| {
if let reqwest_middleware::Error::Reqwest(inner) = err {
inner
} else {
panic!("expected reqwest error")
}
})
})
.expect("expected reqwest error");
assert!(reqwest_error.is_connect());
// TLS errors are deeply nested in io::Error::get_ref() - use find_source_with_io to find them
assert!(
uv_client::find_source_with_io::<rustls::Error>(middleware_error.as_ref()).is_some(),
"Expected TLS error in chain"
);
// Validate the server error
let server_res = server_task.await?;
@ -207,23 +248,16 @@ async fn ssl_env_vars() -> Result<()> {
std::env::remove_var(EnvVars::SSL_CERT_DIR);
}
// Validate the client error
// Validate the client error - TLS errors return Fatal early so we get Middleware variant
let Some(reqwest_middleware::Error::Middleware(middleware_error)) = res.err() else {
panic!("expected middleware error");
};
let reqwest_error = middleware_error
.chain()
.find_map(|err| {
err.downcast_ref::<reqwest_middleware::Error>().map(|err| {
if let reqwest_middleware::Error::Reqwest(inner) = err {
inner
} else {
panic!("expected reqwest error")
}
})
})
.expect("expected reqwest error");
assert!(reqwest_error.is_connect());
// TLS errors are deeply nested in io::Error::get_ref() - use find_source_with_io to find them
assert!(
uv_client::find_source_with_io::<rustls::Error>(middleware_error.as_ref()).is_some(),
"Expected TLS error in chain"
);
// Validate the server error
let server_res = server_task.await?;
@ -296,23 +330,13 @@ async fn ssl_env_vars() -> Result<()> {
std::env::remove_var(EnvVars::SSL_CERT_FILE);
}
// Validate the client error
let Some(reqwest_middleware::Error::Middleware(middleware_error)) = res.err() else {
panic!("expected middleware error");
// Validate the client error - this is an mTLS failure (no client cert provided)
// The server closes the connection during handshake, so the client sees a
// generic connection error (e.g., "Connection refused"), not a TLS certificate error
let Err(reqwest_middleware::Error::Middleware(_middleware_error)) = res else {
panic!("expected middleware error, got: {res:?}");
};
let reqwest_error = middleware_error
.chain()
.find_map(|err| {
err.downcast_ref::<reqwest_middleware::Error>().map(|err| {
if let reqwest_middleware::Error::Reqwest(inner) = err {
inner
} else {
panic!("expected reqwest error")
}
})
})
.expect("expected reqwest error");
assert!(reqwest_error.is_connect());
// For mTLS, just verify we got an error - the server error below confirms it's TLS-related
// Validate the server error
let server_res = server_task.await?;