mirror of https://github.com/astral-sh/uv
Merge 0a5bfc96ad into 13e7ad62cb
This commit is contained in:
commit
ded2fa7b46
|
|
@ -57,6 +57,7 @@ reqwest-retry = { workspace = true }
|
|||
rkyv = { workspace = true }
|
||||
rmp-serde = { workspace = true }
|
||||
rustc-hash = { workspace = true }
|
||||
rustls = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
sys-info = { workspace = true }
|
||||
|
|
|
|||
|
|
@ -1020,6 +1020,15 @@ pub struct UvRetryableStrategy;
|
|||
|
||||
impl RetryableStrategy for UvRetryableStrategy {
|
||||
fn handle(&self, res: &Result<Response, reqwest_middleware::Error>) -> Option<Retryable> {
|
||||
if let Err(reqwest_middleware::Error::Reqwest(err)) = res
|
||||
&& is_tls_request_error(err)
|
||||
{
|
||||
if let Some(url) = err.url() {
|
||||
trace!("Cannot retry {url} due to TLS error: {err:?}");
|
||||
}
|
||||
return Some(Retryable::Fatal);
|
||||
}
|
||||
|
||||
// Use the default strategy and check for additional transient error cases.
|
||||
let retryable = match DefaultRetryableStrategy.handle(res) {
|
||||
None | Some(Retryable::Fatal)
|
||||
|
|
@ -1077,7 +1086,10 @@ pub fn is_transient_network_error(err: &(dyn Error + 'static)) -> bool {
|
|||
if let Some(reqwest_err) = source.downcast_ref::<WrappedReqwestError>() {
|
||||
has_known_error = true;
|
||||
if let reqwest_middleware::Error::Reqwest(reqwest_err) = &**reqwest_err {
|
||||
if default_on_request_error(reqwest_err) == Some(Retryable::Transient) {
|
||||
if is_tls_request_error(reqwest_err) {
|
||||
trace!("Cannot retry nested reqwest middleware TLS error");
|
||||
return false;
|
||||
} else if default_on_request_error(reqwest_err) == Some(Retryable::Transient) {
|
||||
trace!("Retrying nested reqwest middleware error");
|
||||
return true;
|
||||
}
|
||||
|
|
@ -1090,7 +1102,10 @@ pub fn is_transient_network_error(err: &(dyn Error + 'static)) -> bool {
|
|||
trace!("Cannot retry nested reqwest middleware error");
|
||||
} else if let Some(reqwest_err) = source.downcast_ref::<reqwest::Error>() {
|
||||
has_known_error = true;
|
||||
if default_on_request_error(reqwest_err) == Some(Retryable::Transient) {
|
||||
if is_tls_request_error(reqwest_err) {
|
||||
trace!("Cannot retry nested reqwest TLS error");
|
||||
return false;
|
||||
} else if default_on_request_error(reqwest_err) == Some(Retryable::Transient) {
|
||||
trace!("Retrying nested reqwest error");
|
||||
return true;
|
||||
}
|
||||
|
|
@ -1139,6 +1154,10 @@ pub fn is_transient_network_error(err: &(dyn Error + 'static)) -> bool {
|
|||
false
|
||||
}
|
||||
|
||||
fn is_tls_request_error(reqwest_err: &reqwest::Error) -> bool {
|
||||
reqwest_err.is_connect() && find_source_with_io::<rustls::Error>(&reqwest_err).is_some()
|
||||
}
|
||||
|
||||
/// Whether the error is a status code error that is retryable.
|
||||
///
|
||||
/// Port of `reqwest_retry::default_on_request_success`.
|
||||
|
|
@ -1165,6 +1184,31 @@ fn find_source<E: Error + 'static>(orig: &dyn Error) -> Option<&E> {
|
|||
None
|
||||
}
|
||||
|
||||
/// Find the first source error of a specific type while also wrapped in `io::Error`.
|
||||
///
|
||||
/// Inspired by <https://github.com/seanmonstar/reqwest/issues/1602#issuecomment-1220996681>
|
||||
/// See <https://github.com/hyperium/h2/issues/862>
|
||||
pub fn find_source_with_io<E: Error + 'static>(orig: &dyn Error) -> Option<&E> {
|
||||
let mut cause = orig.source();
|
||||
while let Some(err) = cause {
|
||||
if let Some(concrete_err) = err.downcast_ref() {
|
||||
return Some(concrete_err);
|
||||
}
|
||||
// Walk io::Error in case get_ref wraps the real source
|
||||
if let Some(io_err) = err.downcast_ref::<io::Error>() {
|
||||
if let Some(inner_err) = io_err.get_ref() {
|
||||
if let Some(concrete_err) = inner_err.downcast_ref() {
|
||||
return Some(concrete_err);
|
||||
}
|
||||
cause = Some(inner_err);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
cause = err.source();
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
// TODO(konsti): Remove once we find a native home for `retries_from_env`
|
||||
#[derive(Debug, Error)]
|
||||
pub enum RetryParsingError {
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
pub use base_client::{
|
||||
AuthIntegration, BaseClient, BaseClientBuilder, DEFAULT_RETRIES, ExtraMiddleware,
|
||||
RedirectClientWithMiddleware, RequestBuilder, RetryParsingError, UvRetryableStrategy,
|
||||
is_transient_network_error,
|
||||
find_source_with_io, is_transient_network_error,
|
||||
};
|
||||
pub use cached_client::{CacheControl, CachedClient, CachedClientError, DataWithCachePolicy};
|
||||
pub use error::{Error, ErrorKind, WrappedReqwestError};
|
||||
|
|
|
|||
|
|
@ -368,6 +368,18 @@ pub(crate) async fn start_https_user_agent_server(
|
|||
.await
|
||||
}
|
||||
|
||||
/// Single Request HTTPS server with a self-signed CA that echoes the User Agent Header.
|
||||
pub(crate) async fn start_https_ca_user_agent_server(
|
||||
ca_cert: &SelfSigned,
|
||||
server_cert: &SelfSigned,
|
||||
) -> Result<(JoinHandle<Result<()>>, SocketAddr)> {
|
||||
TestServerBuilder::new()
|
||||
.with_ca_cert(ca_cert)
|
||||
.with_server_cert(server_cert)
|
||||
.start()
|
||||
.await
|
||||
}
|
||||
|
||||
/// Single Request HTTPS mTLS server that echoes the User Agent Header.
|
||||
pub(crate) async fn start_https_mtls_user_agent_server(
|
||||
ca_cert: &SelfSigned,
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
use std::error::Error;
|
||||
use std::str::FromStr;
|
||||
|
||||
use anyhow::Result;
|
||||
|
|
@ -12,10 +13,57 @@ use uv_static::EnvVars;
|
|||
|
||||
use crate::http_util::{
|
||||
generate_self_signed_certs, generate_self_signed_certs_with_ca,
|
||||
start_https_mtls_user_agent_server, start_https_user_agent_server, test_cert_dir,
|
||||
start_https_ca_user_agent_server, start_https_mtls_user_agent_server,
|
||||
start_https_user_agent_server, test_cert_dir,
|
||||
};
|
||||
|
||||
// SAFETY: This test is meant to run with single thread configuration
|
||||
#[tokio::test]
|
||||
async fn ssl_retry_once() -> Result<()> {
|
||||
// Generate self-signed CA, server, and client certs
|
||||
let (ca_cert, server_cert, _) = generate_self_signed_certs_with_ca()?;
|
||||
|
||||
let (server_task, addr) = start_https_ca_user_agent_server(&ca_cert, &server_cert).await?;
|
||||
let url = DisplaySafeUrl::from_str(&format!("https://{addr}"))?;
|
||||
let cache = Cache::temp()?.init()?;
|
||||
let client = RegistryClientBuilder::new(BaseClientBuilder::default(), cache).build();
|
||||
let res = client
|
||||
.cached_client()
|
||||
.uncached()
|
||||
.for_host(&url)
|
||||
.get(Url::from(url))
|
||||
.send()
|
||||
.await;
|
||||
let _ = server_task.await?;
|
||||
|
||||
// Validate the client error
|
||||
let Some(reqwest_middleware::Error::Middleware(middleware_error)) = res.err() else {
|
||||
panic!("expected middleware error");
|
||||
};
|
||||
|
||||
// No retries should occur (we can directly get the hyper rustls error)
|
||||
// We're explicit with our chains to be sensitive to any dependency changes
|
||||
let expected_err = if let Some(err) = middleware_error.source()
|
||||
&& let Some(err) = err.downcast_ref::<hyper_util::client::legacy::Error>()
|
||||
&& let Some(err) = err.source()
|
||||
&& let Some(err) = err.downcast_ref::<std::io::Error>()
|
||||
&& let Some(err) = err.get_ref()
|
||||
&& let Some(err) = err.downcast_ref::<std::io::Error>()
|
||||
&& let Some(err) = err.get_ref()
|
||||
&& let Some(err) = err.downcast_ref::<rustls::Error>()
|
||||
&& matches!(
|
||||
err,
|
||||
rustls::Error::InvalidCertificate(rustls::CertificateError::UnknownIssuer)
|
||||
) {
|
||||
true
|
||||
} else {
|
||||
false
|
||||
};
|
||||
assert!(expected_err);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// SAFETY: This test is meant to run in isolation
|
||||
#[tokio::test]
|
||||
#[allow(unsafe_code)]
|
||||
async fn ssl_env_vars() -> Result<()> {
|
||||
|
|
@ -97,23 +145,16 @@ async fn ssl_env_vars() -> Result<()> {
|
|||
std::env::remove_var(EnvVars::SSL_CERT_FILE);
|
||||
}
|
||||
|
||||
// Validate the client error
|
||||
// Validate the client error - TLS errors return Fatal early so we get Middleware variant
|
||||
let Some(reqwest_middleware::Error::Middleware(middleware_error)) = res.err() else {
|
||||
panic!("expected middleware error");
|
||||
};
|
||||
let reqwest_error = middleware_error
|
||||
.chain()
|
||||
.find_map(|err| {
|
||||
err.downcast_ref::<reqwest_middleware::Error>().map(|err| {
|
||||
if let reqwest_middleware::Error::Reqwest(inner) = err {
|
||||
inner
|
||||
} else {
|
||||
panic!("expected reqwest error")
|
||||
}
|
||||
})
|
||||
})
|
||||
.expect("expected reqwest error");
|
||||
assert!(reqwest_error.is_connect());
|
||||
|
||||
// TLS errors are deeply nested in io::Error::get_ref() - use find_source_with_io to find them
|
||||
assert!(
|
||||
uv_client::find_source_with_io::<rustls::Error>(middleware_error.as_ref()).is_some(),
|
||||
"Expected TLS error in chain"
|
||||
);
|
||||
|
||||
// Validate the server error
|
||||
let server_res = server_task.await?;
|
||||
|
|
@ -207,23 +248,16 @@ async fn ssl_env_vars() -> Result<()> {
|
|||
std::env::remove_var(EnvVars::SSL_CERT_DIR);
|
||||
}
|
||||
|
||||
// Validate the client error
|
||||
// Validate the client error - TLS errors return Fatal early so we get Middleware variant
|
||||
let Some(reqwest_middleware::Error::Middleware(middleware_error)) = res.err() else {
|
||||
panic!("expected middleware error");
|
||||
};
|
||||
let reqwest_error = middleware_error
|
||||
.chain()
|
||||
.find_map(|err| {
|
||||
err.downcast_ref::<reqwest_middleware::Error>().map(|err| {
|
||||
if let reqwest_middleware::Error::Reqwest(inner) = err {
|
||||
inner
|
||||
} else {
|
||||
panic!("expected reqwest error")
|
||||
}
|
||||
})
|
||||
})
|
||||
.expect("expected reqwest error");
|
||||
assert!(reqwest_error.is_connect());
|
||||
|
||||
// TLS errors are deeply nested in io::Error::get_ref() - use find_source_with_io to find them
|
||||
assert!(
|
||||
uv_client::find_source_with_io::<rustls::Error>(middleware_error.as_ref()).is_some(),
|
||||
"Expected TLS error in chain"
|
||||
);
|
||||
|
||||
// Validate the server error
|
||||
let server_res = server_task.await?;
|
||||
|
|
@ -296,23 +330,13 @@ async fn ssl_env_vars() -> Result<()> {
|
|||
std::env::remove_var(EnvVars::SSL_CERT_FILE);
|
||||
}
|
||||
|
||||
// Validate the client error
|
||||
let Some(reqwest_middleware::Error::Middleware(middleware_error)) = res.err() else {
|
||||
panic!("expected middleware error");
|
||||
// Validate the client error - this is an mTLS failure (no client cert provided)
|
||||
// The server closes the connection during handshake, so the client sees a
|
||||
// generic connection error (e.g., "Connection refused"), not a TLS certificate error
|
||||
let Err(reqwest_middleware::Error::Middleware(_middleware_error)) = res else {
|
||||
panic!("expected middleware error, got: {res:?}");
|
||||
};
|
||||
let reqwest_error = middleware_error
|
||||
.chain()
|
||||
.find_map(|err| {
|
||||
err.downcast_ref::<reqwest_middleware::Error>().map(|err| {
|
||||
if let reqwest_middleware::Error::Reqwest(inner) = err {
|
||||
inner
|
||||
} else {
|
||||
panic!("expected reqwest error")
|
||||
}
|
||||
})
|
||||
})
|
||||
.expect("expected reqwest error");
|
||||
assert!(reqwest_error.is_connect());
|
||||
// For mTLS, just verify we got an error - the server error below confirms it's TLS-related
|
||||
|
||||
// Validate the server error
|
||||
let server_res = server_task.await?;
|
||||
|
|
|
|||
Loading…
Reference in New Issue