use std::error::Error; use std::fmt::Debug; use std::fmt::Write; use std::num::ParseIntError; use std::path::Path; use std::sync::Arc; use std::time::Duration; use std::{env, io, iter}; use anyhow::anyhow; use http::{ HeaderMap, HeaderName, HeaderValue, Method, StatusCode, header::{ AUTHORIZATION, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TYPE, COOKIE, LOCATION, PROXY_AUTHORIZATION, REFERER, TRANSFER_ENCODING, WWW_AUTHENTICATE, }, }; use itertools::Itertools; use reqwest::{Client, ClientBuilder, IntoUrl, Proxy, Request, Response, multipart}; use reqwest_middleware::{ClientWithMiddleware, Middleware}; use reqwest_retry::policies::ExponentialBackoff; use reqwest_retry::{ DefaultRetryableStrategy, RetryTransientMiddleware, Retryable, RetryableStrategy, default_on_request_error, }; use thiserror::Error; use tracing::{debug, trace}; use url::ParseError; use url::Url; use uv_auth::{AuthMiddleware, Credentials, Indexes, PyxTokenStore}; use uv_configuration::{KeyringProviderType, TrustedHost}; use uv_fs::Simplified; use uv_pep508::MarkerEnvironment; use uv_platform_tags::Platform; use uv_preview::Preview; use uv_redacted::DisplaySafeUrl; use uv_redacted::DisplaySafeUrlError; use uv_static::EnvVars; use uv_version::version; use uv_warnings::warn_user_once; use crate::linehaul::LineHaul; use crate::middleware::OfflineMiddleware; use crate::tls::read_identity; use crate::{Connectivity, WrappedReqwestError}; pub const DEFAULT_RETRIES: u32 = 3; /// Maximum number of redirects to follow before giving up. /// /// This is the default used by [`reqwest`]. const DEFAULT_MAX_REDIRECTS: u32 = 10; /// Selectively skip parts or the entire auth middleware. #[derive(Debug, Clone, Copy, Default)] pub enum AuthIntegration { /// Run the full auth middleware, including sending an unauthenticated request first. #[default] Default, /// Send only an authenticated request without cloning and sending an unauthenticated request /// first. Errors if no credentials were found. OnlyAuthenticated, /// Skip the auth middleware entirely. The caller is responsible for managing authentication. NoAuthMiddleware, } /// A builder for an [`BaseClient`]. #[derive(Debug, Clone)] pub struct BaseClientBuilder<'a> { keyring: KeyringProviderType, preview: Preview, allow_insecure_host: Vec, native_tls: bool, built_in_root_certs: bool, retries: u32, pub connectivity: Connectivity, markers: Option<&'a MarkerEnvironment>, platform: Option<&'a Platform>, auth_integration: AuthIntegration, indexes: Indexes, timeout: Duration, extra_middleware: Option, proxies: Vec, redirect_policy: RedirectPolicy, /// Whether credentials should be propagated during cross-origin redirects. /// /// A policy allowing propagation is insecure and should only be available for test code. cross_origin_credential_policy: CrossOriginCredentialsPolicy, /// Optional custom reqwest client to use instead of creating a new one. custom_client: Option, /// uv subcommand in which this client is being used subcommand: Option>, } /// The policy for handling HTTP redirects. #[derive(Debug, Default, Clone, Copy)] pub enum RedirectPolicy { /// Use reqwest's built-in redirect handling. This bypasses our custom middleware /// on redirect. #[default] BypassMiddleware, /// Handle redirects manually, re-triggering our custom middleware for each request. RetriggerMiddleware, } impl RedirectPolicy { pub fn reqwest_policy(self) -> reqwest::redirect::Policy { match self { Self::BypassMiddleware => reqwest::redirect::Policy::default(), Self::RetriggerMiddleware => reqwest::redirect::Policy::none(), } } } /// A list of user-defined middlewares to be applied to the client. #[derive(Clone)] pub struct ExtraMiddleware(pub Vec>); impl Debug for ExtraMiddleware { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("ExtraMiddleware") .field("0", &format!("{} middlewares", self.0.len())) .finish() } } impl Default for BaseClientBuilder<'_> { fn default() -> Self { Self { keyring: KeyringProviderType::default(), preview: Preview::default(), allow_insecure_host: vec![], native_tls: false, built_in_root_certs: false, connectivity: Connectivity::Online, retries: DEFAULT_RETRIES, markers: None, platform: None, auth_integration: AuthIntegration::default(), indexes: Indexes::new(), timeout: Duration::from_secs(30), extra_middleware: None, proxies: vec![], redirect_policy: RedirectPolicy::default(), cross_origin_credential_policy: CrossOriginCredentialsPolicy::Secure, custom_client: None, subcommand: None, } } } impl BaseClientBuilder<'_> { pub fn new( connectivity: Connectivity, native_tls: bool, allow_insecure_host: Vec, preview: Preview, timeout: Duration, retries: u32, ) -> Self { Self { preview, allow_insecure_host, native_tls, retries, connectivity, timeout, ..Self::default() } } } impl<'a> BaseClientBuilder<'a> { /// Use a custom reqwest client instead of creating a new one. /// /// This allows you to provide your own reqwest client with custom configuration. /// Note that some configuration options from this builder will still be applied /// to the client via middleware. #[must_use] pub fn custom_client(mut self, client: Client) -> Self { self.custom_client = Some(client); self } #[must_use] pub fn keyring(mut self, keyring_type: KeyringProviderType) -> Self { self.keyring = keyring_type; self } #[must_use] pub fn allow_insecure_host(mut self, allow_insecure_host: Vec) -> Self { self.allow_insecure_host = allow_insecure_host; self } #[must_use] pub fn connectivity(mut self, connectivity: Connectivity) -> Self { self.connectivity = connectivity; self } #[must_use] pub fn retries(mut self, retries: u32) -> Self { self.retries = retries; self } #[must_use] pub fn native_tls(mut self, native_tls: bool) -> Self { self.native_tls = native_tls; self } #[must_use] pub fn built_in_root_certs(mut self, built_in_root_certs: bool) -> Self { self.built_in_root_certs = built_in_root_certs; self } #[must_use] pub fn markers(mut self, markers: &'a MarkerEnvironment) -> Self { self.markers = Some(markers); self } #[must_use] pub fn platform(mut self, platform: &'a Platform) -> Self { self.platform = Some(platform); self } #[must_use] pub fn auth_integration(mut self, auth_integration: AuthIntegration) -> Self { self.auth_integration = auth_integration; self } #[must_use] pub fn indexes(mut self, indexes: Indexes) -> Self { self.indexes = indexes; self } #[must_use] pub fn timeout(mut self, timeout: Duration) -> Self { self.timeout = timeout; self } #[must_use] pub fn extra_middleware(mut self, middleware: ExtraMiddleware) -> Self { self.extra_middleware = Some(middleware); self } #[must_use] pub fn proxy(mut self, proxy: Proxy) -> Self { self.proxies.push(proxy); self } #[must_use] pub fn redirect(mut self, policy: RedirectPolicy) -> Self { self.redirect_policy = policy; self } /// Allows credentials to be propagated on cross-origin redirects. /// /// WARNING: This should only be available for tests. In production code, propagating credentials /// during cross-origin redirects can lead to security vulnerabilities including credential /// leakage to untrusted domains. #[cfg(test)] #[must_use] pub fn allow_cross_origin_credentials(mut self) -> Self { self.cross_origin_credential_policy = CrossOriginCredentialsPolicy::Insecure; self } #[must_use] pub fn subcommand(mut self, subcommand: Vec) -> Self { self.subcommand = Some(subcommand); self } pub fn is_native_tls(&self) -> bool { self.native_tls } pub fn is_offline(&self) -> bool { matches!(self.connectivity, Connectivity::Offline) } /// Create a [`RetryPolicy`] for the client. pub fn retry_policy(&self) -> ExponentialBackoff { let mut builder = ExponentialBackoff::builder(); if env::var_os(EnvVars::UV_TEST_NO_HTTP_RETRY_DELAY).is_some() { builder = builder.retry_bounds(Duration::from_millis(0), Duration::from_millis(0)); } builder.build_with_max_retries(self.retries) } pub fn build(&self) -> BaseClient { let timeout = self.timeout; debug!("Using request timeout of {}s", timeout.as_secs()); // Use the custom client if provided, otherwise create a new one let (raw_client, raw_dangerous_client) = match &self.custom_client { Some(client) => (client.clone(), client.clone()), None => self.create_secure_and_insecure_clients(timeout), }; // Wrap in any relevant middleware and handle connectivity. let client = RedirectClientWithMiddleware { client: self.apply_middleware(raw_client.clone()), redirect_policy: self.redirect_policy, cross_origin_credentials_policy: self.cross_origin_credential_policy, }; let dangerous_client = RedirectClientWithMiddleware { client: self.apply_middleware(raw_dangerous_client.clone()), redirect_policy: self.redirect_policy, cross_origin_credentials_policy: self.cross_origin_credential_policy, }; BaseClient { connectivity: self.connectivity, allow_insecure_host: self.allow_insecure_host.clone(), retries: self.retries, client, raw_client, dangerous_client, raw_dangerous_client, timeout, } } /// Share the underlying client between two different middleware configurations. pub fn wrap_existing(&self, existing: &BaseClient) -> BaseClient { // Wrap in any relevant middleware and handle connectivity. let client = RedirectClientWithMiddleware { client: self.apply_middleware(existing.raw_client.clone()), redirect_policy: self.redirect_policy, cross_origin_credentials_policy: self.cross_origin_credential_policy, }; let dangerous_client = RedirectClientWithMiddleware { client: self.apply_middleware(existing.raw_dangerous_client.clone()), redirect_policy: self.redirect_policy, cross_origin_credentials_policy: self.cross_origin_credential_policy, }; BaseClient { connectivity: self.connectivity, allow_insecure_host: self.allow_insecure_host.clone(), retries: self.retries, client, dangerous_client, raw_client: existing.raw_client.clone(), raw_dangerous_client: existing.raw_dangerous_client.clone(), timeout: existing.timeout, } } fn create_secure_and_insecure_clients(&self, timeout: Duration) -> (Client, Client) { // Create user agent. let mut user_agent_string = format!("uv/{}", version()); // Add linehaul metadata. let linehaul = LineHaul::new(self.markers, self.platform, self.subcommand.clone()); if let Ok(output) = serde_json::to_string(&linehaul) { let _ = write!(user_agent_string, " {output}"); } // Checks for the presence of `SSL_CERT_FILE`. // Certificate loading support is delegated to `rustls-native-certs`. // See https://github.com/rustls/rustls-native-certs/blob/813790a297ad4399efe70a8e5264ca1b420acbec/src/lib.rs#L118-L125 let ssl_cert_file_exists = env::var_os(EnvVars::SSL_CERT_FILE).is_some_and(|path| { let path_exists = Path::new(&path).exists(); if !path_exists { warn_user_once!( "Ignoring invalid `SSL_CERT_FILE`. File does not exist: {}.", path.simplified_display().cyan() ); } path_exists }); // Checks for the presence of `SSL_CERT_DIR`. // Certificate loading support is delegated to `rustls-native-certs`. // See https://github.com/rustls/rustls-native-certs/blob/813790a297ad4399efe70a8e5264ca1b420acbec/src/lib.rs#L118-L125 let ssl_cert_dir_exists = env::var_os(EnvVars::SSL_CERT_DIR) .filter(|v| !v.is_empty()) .is_some_and(|dirs| { // Parse `SSL_CERT_DIR`, with support for multiple entries using // a platform-specific delimiter (`:` on Unix, `;` on Windows) let (existing, missing): (Vec<_>, Vec<_>) = env::split_paths(&dirs).partition(|p| p.exists()); if existing.is_empty() { let end_note = if missing.len() == 1 { "The directory does not exist." } else { "The entries do not exist." }; warn_user_once!( "Ignoring invalid `SSL_CERT_DIR`. {end_note}: {}.", missing .iter() .map(Simplified::simplified_display) .join(", ") .cyan() ); return false; } // Warn on any missing entries if !missing.is_empty() { let end_note = if missing.len() == 1 { "The following directory does not exist:" } else { "The following entries do not exist:" }; warn_user_once!( "Invalid entries in `SSL_CERT_DIR`. {end_note}: {}.", missing .iter() .map(Simplified::simplified_display) .join(", ") .cyan() ); } // Proceed while ignoring missing entries true }); // Create a secure client that validates certificates. let raw_client = self.create_client( &user_agent_string, timeout, ssl_cert_file_exists, ssl_cert_dir_exists, Security::Secure, self.redirect_policy, ); // Create an insecure client that accepts invalid certificates. let raw_dangerous_client = self.create_client( &user_agent_string, timeout, ssl_cert_file_exists, ssl_cert_dir_exists, Security::Insecure, self.redirect_policy, ); (raw_client, raw_dangerous_client) } fn create_client( &self, user_agent: &str, timeout: Duration, ssl_cert_file_exists: bool, ssl_cert_dir_exists: bool, security: Security, redirect_policy: RedirectPolicy, ) -> Client { // Configure the builder. let client_builder = ClientBuilder::new() .http1_title_case_headers() .user_agent(user_agent) .pool_max_idle_per_host(20) .read_timeout(timeout) .tls_built_in_root_certs(self.built_in_root_certs) .redirect(redirect_policy.reqwest_policy()); // If necessary, accept invalid certificates. let client_builder = match security { Security::Secure => client_builder, Security::Insecure => client_builder.danger_accept_invalid_certs(true), }; let client_builder = if self.native_tls || ssl_cert_file_exists || ssl_cert_dir_exists { client_builder.tls_built_in_native_certs(true) } else { client_builder.tls_built_in_webpki_certs(true) }; // Configure mTLS. let client_builder = if let Some(ssl_client_cert) = env::var_os(EnvVars::SSL_CLIENT_CERT) { match read_identity(&ssl_client_cert) { Ok(identity) => client_builder.identity(identity), Err(err) => { warn_user_once!("Ignoring invalid `SSL_CLIENT_CERT`: {err}"); client_builder } } } else { client_builder }; // apply proxies let mut client_builder = client_builder; for p in &self.proxies { client_builder = client_builder.proxy(p.clone()); } let client_builder = client_builder; client_builder .build() .expect("Failed to build HTTP client.") } fn apply_middleware(&self, client: Client) -> ClientWithMiddleware { match self.connectivity { Connectivity::Online => { // Create a base client to using in the authentication middleware. let base_client = { let mut client = reqwest_middleware::ClientBuilder::new(client.clone()); // Avoid uncloneable errors with a streaming body during publish. if self.retries > 0 { // Initialize the retry strategy. let retry_strategy = RetryTransientMiddleware::new_with_policy_and_strategy( self.retry_policy(), UvRetryableStrategy, ); client = client.with(retry_strategy); } // When supplied, add the extra middleware. if let Some(extra_middleware) = &self.extra_middleware { for middleware in &extra_middleware.0 { client = client.with_arc(middleware.clone()); } } client.build() }; let mut client = reqwest_middleware::ClientBuilder::new(client); // Avoid uncloneable errors with a streaming body during publish. if self.retries > 0 { // Initialize the retry strategy. let retry_strategy = RetryTransientMiddleware::new_with_policy_and_strategy( self.retry_policy(), UvRetryableStrategy, ); client = client.with(retry_strategy); } // When supplied, add the extra middleware. if let Some(extra_middleware) = &self.extra_middleware { for middleware in &extra_middleware.0 { client = client.with_arc(middleware.clone()); } } // Initialize the authentication middleware to set headers. match self.auth_integration { AuthIntegration::Default => { let mut auth_middleware = AuthMiddleware::new() .with_base_client(base_client) .with_indexes(self.indexes.clone()) .with_keyring(self.keyring.to_provider()) .with_preview(self.preview); if let Ok(token_store) = PyxTokenStore::from_settings() { auth_middleware = auth_middleware.with_pyx_token_store(token_store); } client = client.with(auth_middleware); } AuthIntegration::OnlyAuthenticated => { let mut auth_middleware = AuthMiddleware::new() .with_base_client(base_client) .with_indexes(self.indexes.clone()) .with_keyring(self.keyring.to_provider()) .with_preview(self.preview) .with_only_authenticated(true); if let Ok(token_store) = PyxTokenStore::from_settings() { auth_middleware = auth_middleware.with_pyx_token_store(token_store); } client = client.with(auth_middleware); } AuthIntegration::NoAuthMiddleware => { // The downstream code uses custom auth logic. } } client.build() } Connectivity::Offline => reqwest_middleware::ClientBuilder::new(client) .with(OfflineMiddleware) .build(), } } } /// A base client for HTTP requests #[derive(Debug, Clone)] pub struct BaseClient { /// The underlying HTTP client that enforces valid certificates. client: RedirectClientWithMiddleware, /// The underlying HTTP client that accepts invalid certificates. dangerous_client: RedirectClientWithMiddleware, /// The HTTP client without middleware. raw_client: Client, /// The HTTP client that accepts invalid certificates without middleware. raw_dangerous_client: Client, /// The connectivity mode to use. connectivity: Connectivity, /// Configured client timeout, in seconds. timeout: Duration, /// Hosts that are trusted to use the insecure client. allow_insecure_host: Vec, /// The number of retries to attempt on transient errors. retries: u32, } #[derive(Debug, Clone, Copy)] enum Security { /// The client should use secure settings, i.e., valid certificates. Secure, /// The client should use insecure settings, i.e., skip certificate validation. Insecure, } impl BaseClient { /// Selects the appropriate client based on the host's trustworthiness. pub fn for_host(&self, url: &DisplaySafeUrl) -> &RedirectClientWithMiddleware { if self.disable_ssl(url) { &self.dangerous_client } else { &self.client } } /// Executes a request, applying redirect policy. pub async fn execute(&self, req: Request) -> reqwest_middleware::Result { let client = self.for_host(&DisplaySafeUrl::from_url(req.url().clone())); client.execute(req).await } /// Returns `true` if the host is trusted to use the insecure client. pub fn disable_ssl(&self, url: &DisplaySafeUrl) -> bool { self.allow_insecure_host .iter() .any(|allow_insecure_host| allow_insecure_host.matches(url)) } /// The configured client timeout, in seconds. pub fn timeout(&self) -> Duration { self.timeout } /// The configured connectivity mode. pub fn connectivity(&self) -> Connectivity { self.connectivity } /// The [`RetryPolicy`] for the client. pub fn retry_policy(&self) -> ExponentialBackoff { let mut builder = ExponentialBackoff::builder(); if env::var_os(EnvVars::UV_TEST_NO_HTTP_RETRY_DELAY).is_some() { builder = builder.retry_bounds(Duration::from_millis(0), Duration::from_millis(0)); } builder.build_with_max_retries(self.retries) } } /// Wrapper around [`ClientWithMiddleware`] that manages redirects. #[derive(Debug, Clone)] pub struct RedirectClientWithMiddleware { client: ClientWithMiddleware, redirect_policy: RedirectPolicy, /// Whether credentials should be preserved during cross-origin redirects. /// /// WARNING: This should only be available for tests. In production code, preserving credentials /// during cross-origin redirects can lead to security vulnerabilities including credential /// leakage to untrusted domains. cross_origin_credentials_policy: CrossOriginCredentialsPolicy, } impl RedirectClientWithMiddleware { /// Convenience method to make a `GET` request to a URL. pub fn get(&self, url: U) -> RequestBuilder<'_> { RequestBuilder::new(self.client.get(url), self) } /// Convenience method to make a `POST` request to a URL. pub fn post(&self, url: U) -> RequestBuilder<'_> { RequestBuilder::new(self.client.post(url), self) } /// Convenience method to make a `HEAD` request to a URL. pub fn head(&self, url: U) -> RequestBuilder<'_> { RequestBuilder::new(self.client.head(url), self) } /// Executes a request, applying the redirect policy. pub async fn execute(&self, req: Request) -> reqwest_middleware::Result { match self.redirect_policy { RedirectPolicy::BypassMiddleware => self.client.execute(req).await, RedirectPolicy::RetriggerMiddleware => self.execute_with_redirect_handling(req).await, } } /// Executes a request. If the response is a redirect (one of HTTP 301, 302, 303, 307, or 308), the /// request is executed again with the redirect location URL (up to a maximum number of /// redirects). /// /// Unlike the built-in reqwest redirect policies, this sends the redirect request through the /// entire middleware pipeline again. /// /// See RFC 7231 7.1.2 for details on /// redirect semantics. async fn execute_with_redirect_handling( &self, req: Request, ) -> reqwest_middleware::Result { let mut request = req; let mut redirects = 0; let max_redirects = DEFAULT_MAX_REDIRECTS; loop { let result = self .client .execute(request.try_clone().expect("HTTP request must be cloneable")) .await; let Ok(response) = result else { return result; }; if redirects >= max_redirects { return Ok(response); } let Some(redirect_request) = request_into_redirect(request, &response, self.cross_origin_credentials_policy)? else { return Ok(response); }; redirects += 1; request = redirect_request; } } pub fn raw_client(&self) -> &ClientWithMiddleware { &self.client } } impl From for ClientWithMiddleware { fn from(item: RedirectClientWithMiddleware) -> Self { item.client } } /// Check if this is should be a redirect and, if so, return a new redirect request. /// /// This implementation is based on the [`reqwest`] crate redirect implementation. /// It takes ownership of the original [`Request`] and mutates it to create the new /// redirect [`Request`]. fn request_into_redirect( mut req: Request, res: &Response, cross_origin_credentials_policy: CrossOriginCredentialsPolicy, ) -> reqwest_middleware::Result> { let original_req_url = DisplaySafeUrl::from_url(req.url().clone()); let status = res.status(); let should_redirect = match status { StatusCode::MOVED_PERMANENTLY | StatusCode::FOUND | StatusCode::TEMPORARY_REDIRECT | StatusCode::PERMANENT_REDIRECT => true, StatusCode::SEE_OTHER => { // Per RFC 7231, HTTP 303 is intended for the user agent // to perform a GET or HEAD request to the redirect target. // Historically, some browsers also changed method from POST // to GET on 301 or 302, but this is not required by RFC 7231 // and was not intended by the HTTP spec. *req.body_mut() = None; for header in &[ TRANSFER_ENCODING, CONTENT_ENCODING, CONTENT_TYPE, CONTENT_LENGTH, ] { req.headers_mut().remove(header); } match *req.method() { Method::GET | Method::HEAD => {} _ => { *req.method_mut() = Method::GET; } } true } _ => false, }; if !should_redirect { return Ok(None); } let location = res .headers() .get(LOCATION) .ok_or(reqwest_middleware::Error::Middleware(anyhow!( "Server returned redirect (HTTP {status}) without destination URL. This may indicate a server configuration issue" )))? .to_str() .map_err(|_| { reqwest_middleware::Error::Middleware(anyhow!( "Invalid HTTP {status} 'Location' value: must only contain visible ascii characters" )) })?; let mut redirect_url = match DisplaySafeUrl::parse(location) { Ok(url) => url, // Per RFC 7231, URLs should be resolved against the request URL. Err(DisplaySafeUrlError::Url(ParseError::RelativeUrlWithoutBase)) => original_req_url.join(location).map_err(|err| { reqwest_middleware::Error::Middleware(anyhow!( "Invalid HTTP {status} 'Location' value `{location}` relative to `{original_req_url}`: {err}" )) })?, Err(err) => { return Err(reqwest_middleware::Error::Middleware(anyhow!( "Invalid HTTP {status} 'Location' value `{location}`: {err}" ))); } }; // Per RFC 7231, fragments must be propagated if let Some(fragment) = original_req_url.fragment() { redirect_url.set_fragment(Some(fragment)); } // Ensure the URL is a valid HTTP URI. if let Err(err) = redirect_url.as_str().parse::() { return Err(reqwest_middleware::Error::Middleware(anyhow!( "HTTP {status} 'Location' value `{redirect_url}` is not a valid HTTP URI: {err}" ))); } if redirect_url.scheme() != "http" && redirect_url.scheme() != "https" { return Err(reqwest_middleware::Error::Middleware(anyhow!( "Invalid HTTP {status} 'Location' value `{redirect_url}`: scheme needs to be https or http" ))); } let mut headers = HeaderMap::new(); std::mem::swap(req.headers_mut(), &mut headers); let cross_host = redirect_url.host_str() != original_req_url.host_str() || redirect_url.port_or_known_default() != original_req_url.port_or_known_default(); if cross_host { if cross_origin_credentials_policy == CrossOriginCredentialsPolicy::Secure { debug!("Received a cross-origin redirect. Removing sensitive headers."); headers.remove(AUTHORIZATION); headers.remove(COOKIE); headers.remove(PROXY_AUTHORIZATION); headers.remove(WWW_AUTHENTICATE); } // If the redirect request is not a cross-origin request and the original request already // had a Referer header, attempt to set the Referer header for the redirect request. } else if headers.contains_key(REFERER) { if let Some(referer) = make_referer(&redirect_url, &original_req_url) { headers.insert(REFERER, referer); } } // Check if there are credentials on the redirect location itself. // If so, move them to Authorization header. if !redirect_url.username().is_empty() { if let Some(credentials) = Credentials::from_url(&redirect_url) { let _ = redirect_url.set_username(""); let _ = redirect_url.set_password(None); headers.insert(AUTHORIZATION, credentials.to_header_value()); } } std::mem::swap(req.headers_mut(), &mut headers); *req.url_mut() = Url::from(redirect_url); debug!( "Received HTTP {status}. Redirecting to {}", DisplaySafeUrl::ref_cast(req.url()) ); Ok(Some(req)) } /// Return a Referer [`HeaderValue`] according to RFC 7231. /// /// Return [`None`] if https has been downgraded in the redirect location. fn make_referer( redirect_url: &DisplaySafeUrl, original_url: &DisplaySafeUrl, ) -> Option { if redirect_url.scheme() == "http" && original_url.scheme() == "https" { return None; } let mut referer = original_url.clone(); referer.remove_credentials(); referer.set_fragment(None); referer.as_str().parse().ok() } #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)] pub(crate) enum CrossOriginCredentialsPolicy { /// Do not propagate credentials on cross-origin requests. #[default] Secure, /// Propagate credentials on cross-origin requests. /// /// WARNING: This should only be available for tests. In production code, preserving credentials /// during cross-origin redirects can lead to security vulnerabilities including credential /// leakage to untrusted domains. #[cfg(test)] Insecure, } /// A builder to construct the properties of a `Request`. /// /// This wraps [`reqwest_middleware::RequestBuilder`] to ensure that the [`BaseClient`] /// redirect policy is respected if `send()` is called. #[derive(Debug)] #[must_use] pub struct RequestBuilder<'a> { builder: reqwest_middleware::RequestBuilder, client: &'a RedirectClientWithMiddleware, } impl<'a> RequestBuilder<'a> { pub fn new( builder: reqwest_middleware::RequestBuilder, client: &'a RedirectClientWithMiddleware, ) -> Self { Self { builder, client } } /// Add a `Header` to this Request. pub fn header(mut self, key: K, value: V) -> Self where HeaderName: TryFrom, >::Error: Into, HeaderValue: TryFrom, >::Error: Into, { self.builder = self.builder.header(key, value); self } /// Add a set of Headers to the existing ones on this Request. /// /// The headers will be merged in to any already set. pub fn headers(mut self, headers: HeaderMap) -> Self { self.builder = self.builder.headers(headers); self } #[cfg(not(target_arch = "wasm32"))] pub fn version(mut self, version: reqwest::Version) -> Self { self.builder = self.builder.version(version); self } #[cfg_attr(docsrs, doc(cfg(feature = "multipart")))] pub fn multipart(mut self, multipart: multipart::Form) -> Self { self.builder = self.builder.multipart(multipart); self } /// Build a `Request`. pub fn build(self) -> reqwest::Result { self.builder.build() } /// Constructs the Request and sends it to the target URL, returning a /// future Response. pub async fn send(self) -> reqwest_middleware::Result { self.client.execute(self.build()?).await } pub fn raw_builder(&self) -> &reqwest_middleware::RequestBuilder { &self.builder } } /// Extends [`DefaultRetryableStrategy`], to log transient request failures and additional retry cases. pub struct UvRetryableStrategy; impl RetryableStrategy for UvRetryableStrategy { fn handle(&self, res: &Result) -> Option { // Use the default strategy and check for additional transient error cases. let retryable = match DefaultRetryableStrategy.handle(res) { None | Some(Retryable::Fatal) if res .as_ref() .is_err_and(|err| is_transient_network_error(err)) => { Some(Retryable::Transient) } default => default, }; // Log on transient errors if retryable == Some(Retryable::Transient) { match res { Ok(response) => { debug!("Transient request failure for: {}", response.url()); } Err(err) => { let context = iter::successors(err.source(), |&err| err.source()) .map(|err| format!(" Caused by: {err}")) .join("\n"); debug!( "Transient request failure for {}, retrying: {err}\n{context}", err.url().map(Url::as_str).unwrap_or("unknown URL") ); } } } retryable } } /// Whether the error looks like a network error that should be retried. /// /// There are two cases that the default retry strategy is missing: /// * Inside the reqwest or reqwest-middleware error is an `io::Error` such as a broken pipe /// * When streaming a response, a reqwest error may be hidden several layers behind errors /// of different crates processing the stream, including `io::Error` layers. pub fn is_transient_network_error(err: &(dyn Error + 'static)) -> bool { // First, try to show a nice trace log if let Some((Some(status), Some(url))) = find_source::(&err) .map(|request_err| (request_err.status(), request_err.url())) { trace!("Considering retry of response HTTP {status} for {url}"); } else { trace!("Considering retry of error: {err:?}"); } let mut has_known_error = false; // IO Errors or reqwest errors may be nested through custom IO errors or stream processing // crates let mut current_source = Some(err); while let Some(source) = current_source { if let Some(reqwest_err) = source.downcast_ref::() { has_known_error = true; if let reqwest_middleware::Error::Reqwest(reqwest_err) = &**reqwest_err { if default_on_request_error(reqwest_err) == Some(Retryable::Transient) { trace!("Retrying nested reqwest middleware error"); return true; } if is_retryable_status_error(reqwest_err) { trace!("Retrying nested reqwest middleware status code error"); return true; } } trace!("Cannot retry nested reqwest middleware error"); } else if let Some(reqwest_err) = source.downcast_ref::() { has_known_error = true; if default_on_request_error(reqwest_err) == Some(Retryable::Transient) { trace!("Retrying nested reqwest error"); return true; } if is_retryable_status_error(reqwest_err) { trace!("Retrying nested reqwest status code error"); return true; } trace!("Cannot retry nested reqwest error"); } else if source.downcast_ref::().is_some() { // All h2 errors look like errors that should be retried // https://github.com/astral-sh/uv/issues/15916 trace!("Retrying nested h2 error"); return true; } else if let Some(io_err) = source.downcast_ref::() { has_known_error = true; let retryable_io_err_kinds = [ // https://github.com/astral-sh/uv/issues/12054 io::ErrorKind::BrokenPipe, // From reqwest-middleware io::ErrorKind::ConnectionAborted, // https://github.com/astral-sh/uv/issues/3514 io::ErrorKind::ConnectionReset, // https://github.com/astral-sh/uv/issues/14699 io::ErrorKind::InvalidData, // https://github.com/astral-sh/uv/issues/9246 io::ErrorKind::UnexpectedEof, ]; if retryable_io_err_kinds.contains(&io_err.kind()) { trace!("Retrying error: `{}`", io_err.kind()); return true; } trace!( "Cannot retry IO error `{}`, not a retryable IO error kind", io_err.kind() ); } current_source = source.source(); } if !has_known_error { trace!("Cannot retry error: Neither an IO error nor a reqwest error"); } false } /// Whether the error is a status code error that is retryable. /// /// Port of `reqwest_retry::default_on_request_success`. fn is_retryable_status_error(reqwest_err: &reqwest::Error) -> bool { let Some(status) = reqwest_err.status() else { return false; }; status.is_server_error() || status == StatusCode::REQUEST_TIMEOUT || status == StatusCode::TOO_MANY_REQUESTS } /// Find the first source error of a specific type. /// /// See fn find_source(orig: &dyn Error) -> Option<&E> { let mut cause = orig.source(); while let Some(err) = cause { if let Some(typed) = err.downcast_ref() { return Some(typed); } cause = err.source(); } None } // TODO(konsti): Remove once we find a native home for `retries_from_env` #[derive(Debug, Error)] pub enum RetryParsingError { #[error("Failed to parse `UV_HTTP_RETRIES`")] ParseInt(#[from] ParseIntError), } #[cfg(test)] mod tests { use super::*; use anyhow::Result; use insta::assert_debug_snapshot; use reqwest::{Client, Method}; use wiremock::matchers::{method, path}; use wiremock::{Mock, MockServer, ResponseTemplate}; use crate::base_client::request_into_redirect; #[tokio::test] async fn test_redirect_preserves_authorization_header_on_same_origin() -> Result<()> { for status in &[301, 302, 303, 307, 308] { let server = MockServer::start().await; Mock::given(method("GET")) .respond_with( ResponseTemplate::new(*status) .insert_header("location", format!("{}/redirect", server.uri())), ) .mount(&server) .await; let request = Client::new() .get(server.uri()) .basic_auth("username", Some("password")) .build() .unwrap(); assert!(request.headers().contains_key(AUTHORIZATION)); let response = Client::builder() .redirect(reqwest::redirect::Policy::none()) .build() .unwrap() .execute(request.try_clone().unwrap()) .await .unwrap(); let redirect_request = request_into_redirect(request, &response, CrossOriginCredentialsPolicy::Secure)? .unwrap(); assert!(redirect_request.headers().contains_key(AUTHORIZATION)); } Ok(()) } #[tokio::test] async fn test_redirect_preserves_fragment() -> Result<()> { for status in &[301, 302, 303, 307, 308] { let server = MockServer::start().await; Mock::given(method("GET")) .respond_with( ResponseTemplate::new(*status) .insert_header("location", format!("{}/redirect", server.uri())), ) .mount(&server) .await; let request = Client::new() .get(format!("{}#fragment", server.uri())) .build() .unwrap(); let response = Client::builder() .redirect(reqwest::redirect::Policy::none()) .build() .unwrap() .execute(request.try_clone().unwrap()) .await .unwrap(); let redirect_request = request_into_redirect(request, &response, CrossOriginCredentialsPolicy::Secure)? .unwrap(); assert!( redirect_request .url() .fragment() .is_some_and(|fragment| fragment == "fragment") ); } Ok(()) } #[tokio::test] async fn test_redirect_removes_authorization_header_on_cross_origin() -> Result<()> { for status in &[301, 302, 303, 307, 308] { let server = MockServer::start().await; Mock::given(method("GET")) .respond_with( ResponseTemplate::new(*status) .insert_header("location", "https://cross-origin.com/simple"), ) .mount(&server) .await; let request = Client::new() .get(server.uri()) .basic_auth("username", Some("password")) .build() .unwrap(); assert!(request.headers().contains_key(AUTHORIZATION)); let response = Client::builder() .redirect(reqwest::redirect::Policy::none()) .build() .unwrap() .execute(request.try_clone().unwrap()) .await .unwrap(); let redirect_request = request_into_redirect(request, &response, CrossOriginCredentialsPolicy::Secure)? .unwrap(); assert!(!redirect_request.headers().contains_key(AUTHORIZATION)); } Ok(()) } #[tokio::test] async fn test_redirect_303_changes_post_to_get() -> Result<()> { let server = MockServer::start().await; Mock::given(method("POST")) .respond_with( ResponseTemplate::new(303) .insert_header("location", format!("{}/redirect", server.uri())), ) .mount(&server) .await; let request = Client::new() .post(server.uri()) .basic_auth("username", Some("password")) .build() .unwrap(); assert_eq!(request.method(), Method::POST); let response = Client::builder() .redirect(reqwest::redirect::Policy::none()) .build() .unwrap() .execute(request.try_clone().unwrap()) .await .unwrap(); let redirect_request = request_into_redirect(request, &response, CrossOriginCredentialsPolicy::Secure)? .unwrap(); assert_eq!(redirect_request.method(), Method::GET); Ok(()) } #[tokio::test] async fn test_redirect_no_referer_if_disabled() -> Result<()> { for status in &[301, 302, 303, 307, 308] { let server = MockServer::start().await; Mock::given(method("GET")) .respond_with( ResponseTemplate::new(*status) .insert_header("location", format!("{}/redirect", server.uri())), ) .mount(&server) .await; let request = Client::builder() .referer(false) .build() .unwrap() .get(server.uri()) .basic_auth("username", Some("password")) .build() .unwrap(); assert!(!request.headers().contains_key(REFERER)); let response = Client::builder() .redirect(reqwest::redirect::Policy::none()) .build() .unwrap() .execute(request.try_clone().unwrap()) .await .unwrap(); let redirect_request = request_into_redirect(request, &response, CrossOriginCredentialsPolicy::Secure)? .unwrap(); assert!(!redirect_request.headers().contains_key(REFERER)); } Ok(()) } /// Enumerate which status codes we are retrying. #[tokio::test] async fn retried_status_codes() -> Result<()> { let server = MockServer::start().await; let client = Client::default(); let middleware_client = ClientWithMiddleware::default(); let mut retried = Vec::new(); for status in 100..599 { // Test all standard status codes and and example for a non-RFC code used in the wild. if StatusCode::from_u16(status)?.canonical_reason().is_none() && status != 420 { continue; } Mock::given(path(format!("/{status}"))) .respond_with(ResponseTemplate::new(status)) .mount(&server) .await; let response = middleware_client .get(format!("{}/{}", server.uri(), status)) .send() .await; let middleware_retry = DefaultRetryableStrategy.handle(&response) == Some(Retryable::Transient); let response = client .get(format!("{}/{}", server.uri(), status)) .send() .await?; let uv_retry = match response.error_for_status() { Ok(_) => false, Err(err) => is_transient_network_error(&err), }; // Ensure we're retrying the same status code as the reqwest_retry crate. We may choose // to deviate from this later. assert_eq!(middleware_retry, uv_retry); if uv_retry { retried.push(status); } } assert_debug_snapshot!(retried, @r" [ 100, 102, 408, 429, 500, 501, 502, 503, 504, 505, 506, 507, 508, 510, 511, ] "); Ok(()) } }