diff --git a/crates/uv-client/Cargo.toml b/crates/uv-client/Cargo.toml index e54f89860..53519b2c5 100644 --- a/crates/uv-client/Cargo.toml +++ b/crates/uv-client/Cargo.toml @@ -57,6 +57,7 @@ reqwest-retry = { workspace = true } rkyv = { workspace = true } rmp-serde = { workspace = true } rustc-hash = { workspace = true } +rustls = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } sys-info = { workspace = true } diff --git a/crates/uv-client/src/base_client.rs b/crates/uv-client/src/base_client.rs index 33cd9e0d6..68a45a4f2 100644 --- a/crates/uv-client/src/base_client.rs +++ b/crates/uv-client/src/base_client.rs @@ -1020,6 +1020,15 @@ pub struct UvRetryableStrategy; impl RetryableStrategy for UvRetryableStrategy { fn handle(&self, res: &Result) -> Option { + if let Err(reqwest_middleware::Error::Reqwest(err)) = res + && is_tls_request_error(err) + { + if let Some(url) = err.url() { + trace!("Cannot retry {url} due to TLS error: {err:?}"); + } + return Some(Retryable::Fatal); + } + // Use the default strategy and check for additional transient error cases. let retryable = match DefaultRetryableStrategy.handle(res) { None | Some(Retryable::Fatal) @@ -1077,7 +1086,10 @@ pub fn is_transient_network_error(err: &(dyn Error + 'static)) -> bool { if let Some(reqwest_err) = source.downcast_ref::() { has_known_error = true; if let reqwest_middleware::Error::Reqwest(reqwest_err) = &**reqwest_err { - if default_on_request_error(reqwest_err) == Some(Retryable::Transient) { + if is_tls_request_error(reqwest_err) { + trace!("Cannot retry nested reqwest middleware TLS error"); + return false; + } else if default_on_request_error(reqwest_err) == Some(Retryable::Transient) { trace!("Retrying nested reqwest middleware error"); return true; } @@ -1090,7 +1102,10 @@ pub fn is_transient_network_error(err: &(dyn Error + 'static)) -> bool { trace!("Cannot retry nested reqwest middleware error"); } else if let Some(reqwest_err) = source.downcast_ref::() { has_known_error = true; - if default_on_request_error(reqwest_err) == Some(Retryable::Transient) { + if is_tls_request_error(reqwest_err) { + trace!("Cannot retry nested reqwest TLS error"); + return false; + } else if default_on_request_error(reqwest_err) == Some(Retryable::Transient) { trace!("Retrying nested reqwest error"); return true; } @@ -1139,6 +1154,10 @@ pub fn is_transient_network_error(err: &(dyn Error + 'static)) -> bool { false } +fn is_tls_request_error(reqwest_err: &reqwest::Error) -> bool { + reqwest_err.is_connect() && find_source_with_io::(&reqwest_err).is_some() +} + /// Whether the error is a status code error that is retryable. /// /// Port of `reqwest_retry::default_on_request_success`. @@ -1165,6 +1184,31 @@ fn find_source(orig: &dyn Error) -> Option<&E> { None } +/// Find the first source error of a specific type while also wrapped in `io::Error`. +/// +/// Inspired by +/// See +pub fn find_source_with_io(orig: &dyn Error) -> Option<&E> { + let mut cause = orig.source(); + while let Some(err) = cause { + if let Some(concrete_err) = err.downcast_ref() { + return Some(concrete_err); + } + // Walk io::Error in case get_ref wraps the real source + if let Some(io_err) = err.downcast_ref::() { + if let Some(inner_err) = io_err.get_ref() { + if let Some(concrete_err) = inner_err.downcast_ref() { + return Some(concrete_err); + } + cause = Some(inner_err); + continue; + } + } + cause = err.source(); + } + None +} + // TODO(konsti): Remove once we find a native home for `retries_from_env` #[derive(Debug, Error)] pub enum RetryParsingError { diff --git a/crates/uv-client/src/lib.rs b/crates/uv-client/src/lib.rs index 2862fc6d1..c0043980d 100644 --- a/crates/uv-client/src/lib.rs +++ b/crates/uv-client/src/lib.rs @@ -1,7 +1,7 @@ pub use base_client::{ AuthIntegration, BaseClient, BaseClientBuilder, DEFAULT_RETRIES, ExtraMiddleware, RedirectClientWithMiddleware, RequestBuilder, RetryParsingError, UvRetryableStrategy, - is_transient_network_error, + find_source_with_io, is_transient_network_error, }; pub use cached_client::{CacheControl, CachedClient, CachedClientError, DataWithCachePolicy}; pub use error::{Error, ErrorKind, WrappedReqwestError}; diff --git a/crates/uv-client/tests/it/http_util.rs b/crates/uv-client/tests/it/http_util.rs index b879699b3..855809b29 100644 --- a/crates/uv-client/tests/it/http_util.rs +++ b/crates/uv-client/tests/it/http_util.rs @@ -368,6 +368,18 @@ pub(crate) async fn start_https_user_agent_server( .await } +/// Single Request HTTPS server with a self-signed CA that echoes the User Agent Header. +pub(crate) async fn start_https_ca_user_agent_server( + ca_cert: &SelfSigned, + server_cert: &SelfSigned, +) -> Result<(JoinHandle>, SocketAddr)> { + TestServerBuilder::new() + .with_ca_cert(ca_cert) + .with_server_cert(server_cert) + .start() + .await +} + /// Single Request HTTPS mTLS server that echoes the User Agent Header. pub(crate) async fn start_https_mtls_user_agent_server( ca_cert: &SelfSigned, diff --git a/crates/uv-client/tests/it/ssl_certs.rs b/crates/uv-client/tests/it/ssl_certs.rs index 6c88cd7d5..dede8c62e 100644 --- a/crates/uv-client/tests/it/ssl_certs.rs +++ b/crates/uv-client/tests/it/ssl_certs.rs @@ -1,3 +1,4 @@ +use std::error::Error; use std::str::FromStr; use anyhow::Result; @@ -12,10 +13,57 @@ use uv_static::EnvVars; use crate::http_util::{ generate_self_signed_certs, generate_self_signed_certs_with_ca, - start_https_mtls_user_agent_server, start_https_user_agent_server, test_cert_dir, + start_https_ca_user_agent_server, start_https_mtls_user_agent_server, + start_https_user_agent_server, test_cert_dir, }; -// SAFETY: This test is meant to run with single thread configuration +#[tokio::test] +async fn ssl_retry_once() -> Result<()> { + // Generate self-signed CA, server, and client certs + let (ca_cert, server_cert, _) = generate_self_signed_certs_with_ca()?; + + let (server_task, addr) = start_https_ca_user_agent_server(&ca_cert, &server_cert).await?; + let url = DisplaySafeUrl::from_str(&format!("https://{addr}"))?; + let cache = Cache::temp()?.init()?; + let client = RegistryClientBuilder::new(BaseClientBuilder::default(), cache).build(); + let res = client + .cached_client() + .uncached() + .for_host(&url) + .get(Url::from(url)) + .send() + .await; + let _ = server_task.await?; + + // Validate the client error + let Some(reqwest_middleware::Error::Middleware(middleware_error)) = res.err() else { + panic!("expected middleware error"); + }; + + // No retries should occur (we can directly get the hyper rustls error) + // We're explicit with our chains to be sensitive to any dependency changes + let expected_err = if let Some(err) = middleware_error.source() + && let Some(err) = err.downcast_ref::() + && let Some(err) = err.source() + && let Some(err) = err.downcast_ref::() + && let Some(err) = err.get_ref() + && let Some(err) = err.downcast_ref::() + && let Some(err) = err.get_ref() + && let Some(err) = err.downcast_ref::() + && matches!( + err, + rustls::Error::InvalidCertificate(rustls::CertificateError::UnknownIssuer) + ) { + true + } else { + false + }; + assert!(expected_err); + + Ok(()) +} + +// SAFETY: This test is meant to run in isolation #[tokio::test] #[allow(unsafe_code)] async fn ssl_env_vars() -> Result<()> { @@ -97,23 +145,16 @@ async fn ssl_env_vars() -> Result<()> { std::env::remove_var(EnvVars::SSL_CERT_FILE); } - // Validate the client error + // Validate the client error - TLS errors return Fatal early so we get Middleware variant let Some(reqwest_middleware::Error::Middleware(middleware_error)) = res.err() else { panic!("expected middleware error"); }; - let reqwest_error = middleware_error - .chain() - .find_map(|err| { - err.downcast_ref::().map(|err| { - if let reqwest_middleware::Error::Reqwest(inner) = err { - inner - } else { - panic!("expected reqwest error") - } - }) - }) - .expect("expected reqwest error"); - assert!(reqwest_error.is_connect()); + + // TLS errors are deeply nested in io::Error::get_ref() - use find_source_with_io to find them + assert!( + uv_client::find_source_with_io::(middleware_error.as_ref()).is_some(), + "Expected TLS error in chain" + ); // Validate the server error let server_res = server_task.await?; @@ -207,23 +248,16 @@ async fn ssl_env_vars() -> Result<()> { std::env::remove_var(EnvVars::SSL_CERT_DIR); } - // Validate the client error + // Validate the client error - TLS errors return Fatal early so we get Middleware variant let Some(reqwest_middleware::Error::Middleware(middleware_error)) = res.err() else { panic!("expected middleware error"); }; - let reqwest_error = middleware_error - .chain() - .find_map(|err| { - err.downcast_ref::().map(|err| { - if let reqwest_middleware::Error::Reqwest(inner) = err { - inner - } else { - panic!("expected reqwest error") - } - }) - }) - .expect("expected reqwest error"); - assert!(reqwest_error.is_connect()); + + // TLS errors are deeply nested in io::Error::get_ref() - use find_source_with_io to find them + assert!( + uv_client::find_source_with_io::(middleware_error.as_ref()).is_some(), + "Expected TLS error in chain" + ); // Validate the server error let server_res = server_task.await?; @@ -296,23 +330,13 @@ async fn ssl_env_vars() -> Result<()> { std::env::remove_var(EnvVars::SSL_CERT_FILE); } - // Validate the client error - let Some(reqwest_middleware::Error::Middleware(middleware_error)) = res.err() else { - panic!("expected middleware error"); + // Validate the client error - this is an mTLS failure (no client cert provided) + // The server closes the connection during handshake, so the client sees a + // generic connection error (e.g., "Connection refused"), not a TLS certificate error + let Err(reqwest_middleware::Error::Middleware(_middleware_error)) = res else { + panic!("expected middleware error, got: {res:?}"); }; - let reqwest_error = middleware_error - .chain() - .find_map(|err| { - err.downcast_ref::().map(|err| { - if let reqwest_middleware::Error::Reqwest(inner) = err { - inner - } else { - panic!("expected reqwest error") - } - }) - }) - .expect("expected reqwest error"); - assert!(reqwest_error.is_connect()); + // For mTLS, just verify we got an error - the server error below confirms it's TLS-related // Validate the server error let server_res = server_task.await?;