From 4ee4a8861e0968d24bda70720f3fe685bd6c55a6 Mon Sep 17 00:00:00 2001 From: John Mumm Date: Mon, 28 Apr 2025 09:07:06 +0200 Subject: [PATCH] Implement RFC 7231 compliant relative URI and fragment handling in redirects (#13050) This PR restores #13041 and integrates two PRs from @zanieb: * #13038 * #13040 It also adds tests for relative URI and fragment handling. Closes #13037. --------- Co-authored-by: Zanie Blue --- Cargo.lock | 1 + crates/uv-client/Cargo.toml | 1 + crates/uv-client/src/base_client.rs | 267 ++++++++++++++++++++++- crates/uv-client/src/cached_client.rs | 2 - crates/uv-client/src/lib.rs | 2 +- crates/uv-client/src/registry_client.rs | 220 ++++++++++++++++++- crates/uv-distribution/src/source/mod.rs | 7 +- crates/uv-git/src/resolver.rs | 4 +- crates/uv-publish/src/lib.rs | 29 ++- crates/uv/tests/it/common/mod.rs | 3 +- crates/uv/tests/it/edit.rs | 197 ++++++++++++++++- 11 files changed, 696 insertions(+), 37 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bb496981c..07919dae9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4934,6 +4934,7 @@ dependencies = [ "uv-torch", "uv-version", "uv-warnings", + "wiremock", ] [[package]] diff --git a/crates/uv-client/Cargo.toml b/crates/uv-client/Cargo.toml index 64088676e..f9ef3220f 100644 --- a/crates/uv-client/Cargo.toml +++ b/crates/uv-client/Cargo.toml @@ -64,3 +64,4 @@ hyper = { version = "1.4.1", features = ["server", "http1"] } hyper-util = { version = "0.1.8", features = ["tokio"] } insta = { version = "1.40.0", features = ["filters", "json", "redactions"] } tokio = { workspace = true } +wiremock = { workspace = true } diff --git a/crates/uv-client/src/base_client.rs b/crates/uv-client/src/base_client.rs index 071779fbe..0297c1721 100644 --- a/crates/uv-client/src/base_client.rs +++ b/crates/uv-client/src/base_client.rs @@ -6,14 +6,17 @@ use std::sync::Arc; use std::time::Duration; use std::{env, iter}; +use anyhow::anyhow; +use http::{HeaderMap, HeaderName, HeaderValue, StatusCode}; use itertools::Itertools; -use reqwest::{Client, ClientBuilder, Proxy, Response}; +use reqwest::{multipart, Client, ClientBuilder, IntoUrl, Proxy, Request, Response}; use reqwest_middleware::{ClientWithMiddleware, Middleware}; use reqwest_retry::policies::ExponentialBackoff; use reqwest_retry::{ DefaultRetryableStrategy, RetryTransientMiddleware, Retryable, RetryableStrategy, }; use tracing::{debug, trace}; +use url::ParseError; use url::Url; use uv_auth::{AuthMiddleware, UrlAuthPolicies}; @@ -60,6 +63,24 @@ pub struct BaseClientBuilder<'a> { default_timeout: Duration, extra_middleware: Option, proxies: Vec, + redirect_policy: RedirectPolicy, +} + +/// The policy for handling redirects. +#[derive(Debug, Default, Clone, Copy)] +pub enum RedirectPolicy { + #[default] + BypassMiddleware, + RetriggerMiddleware, +} + +impl RedirectPolicy { + pub fn reqwest_policy(self) -> reqwest::redirect::Policy { + match self { + RedirectPolicy::BypassMiddleware => reqwest::redirect::Policy::default(), + RedirectPolicy::RetriggerMiddleware => reqwest::redirect::Policy::none(), + } + } } /// A list of user-defined middlewares to be applied to the client. @@ -95,6 +116,7 @@ impl BaseClientBuilder<'_> { default_timeout: Duration::from_secs(30), extra_middleware: None, proxies: vec![], + redirect_policy: RedirectPolicy::default(), } } } @@ -172,6 +194,12 @@ impl<'a> BaseClientBuilder<'a> { self } + #[must_use] + pub fn redirect(mut self, policy: RedirectPolicy) -> Self { + self.redirect_policy = policy; + self + } + pub fn is_offline(&self) -> bool { matches!(self.connectivity, Connectivity::Offline) } @@ -228,6 +256,7 @@ impl<'a> BaseClientBuilder<'a> { timeout, ssl_cert_file_exists, Security::Secure, + self.redirect_policy, ); // Create an insecure client that accepts invalid certificates. @@ -236,11 +265,18 @@ impl<'a> BaseClientBuilder<'a> { timeout, ssl_cert_file_exists, Security::Insecure, + self.redirect_policy, ); // Wrap in any relevant middleware and handle connectivity. - let client = self.apply_middleware(raw_client.clone()); - let dangerous_client = self.apply_middleware(raw_dangerous_client.clone()); + let client = RedirectClientWithMiddleware { + client: self.apply_middleware(raw_client.clone()), + redirect_policy: self.redirect_policy, + }; + let dangerous_client = RedirectClientWithMiddleware { + client: self.apply_middleware(raw_dangerous_client.clone()), + redirect_policy: self.redirect_policy, + }; BaseClient { connectivity: self.connectivity, @@ -257,8 +293,14 @@ impl<'a> BaseClientBuilder<'a> { /// Share the underlying client between two different middleware configurations. pub fn wrap_existing(&self, existing: &BaseClient) -> BaseClient { // Wrap in any relevant middleware and handle connectivity. - let client = self.apply_middleware(existing.raw_client.clone()); - let dangerous_client = self.apply_middleware(existing.raw_dangerous_client.clone()); + let client = RedirectClientWithMiddleware { + client: self.apply_middleware(existing.raw_client.clone()), + redirect_policy: self.redirect_policy, + }; + let dangerous_client = RedirectClientWithMiddleware { + client: self.apply_middleware(existing.raw_dangerous_client.clone()), + redirect_policy: self.redirect_policy, + }; BaseClient { connectivity: self.connectivity, @@ -278,6 +320,7 @@ impl<'a> BaseClientBuilder<'a> { timeout: Duration, ssl_cert_file_exists: bool, security: Security, + redirect_policy: RedirectPolicy, ) -> Client { // Configure the builder. let client_builder = ClientBuilder::new() @@ -285,7 +328,8 @@ impl<'a> BaseClientBuilder<'a> { .user_agent(user_agent) .pool_max_idle_per_host(20) .read_timeout(timeout) - .tls_built_in_root_certs(false); + .tls_built_in_root_certs(false) + .redirect(redirect_policy.reqwest_policy()); // If necessary, accept invalid certificates. let client_builder = match security { @@ -382,9 +426,9 @@ impl<'a> BaseClientBuilder<'a> { #[derive(Debug, Clone)] pub struct BaseClient { /// The underlying HTTP client that enforces valid certificates. - client: ClientWithMiddleware, + client: RedirectClientWithMiddleware, /// The underlying HTTP client that accepts invalid certificates. - dangerous_client: ClientWithMiddleware, + dangerous_client: RedirectClientWithMiddleware, /// The HTTP client without middleware. raw_client: Client, /// The HTTP client that accepts invalid certificates without middleware. @@ -409,7 +453,7 @@ enum Security { impl BaseClient { /// Selects the appropriate client based on the host's trustworthiness. - pub fn for_host(&self, url: &Url) -> &ClientWithMiddleware { + pub fn for_host(&self, url: &Url) -> &RedirectClientWithMiddleware { if self.disable_ssl(url) { &self.dangerous_client } else { @@ -417,6 +461,12 @@ impl BaseClient { } } + /// Executes a request, applying redirect policy. + pub async fn execute(&self, req: Request) -> reqwest_middleware::Result { + let client = self.for_host(req.url()); + client.execute(req).await + } + /// Returns `true` if the host is trusted to use the insecure client. pub fn disable_ssl(&self, url: &Url) -> bool { self.allow_insecure_host @@ -440,6 +490,205 @@ impl BaseClient { } } +/// Wrapper around [`ClientWithMiddleware`] that manages redirects. +#[derive(Debug, Clone)] +pub struct RedirectClientWithMiddleware { + client: ClientWithMiddleware, + redirect_policy: RedirectPolicy, +} + +impl RedirectClientWithMiddleware { + /// Convenience method to make a `GET` request to a URL. + pub fn get(&self, url: U) -> RequestBuilder { + RequestBuilder::new(self.client.get(url), self) + } + + /// Convenience method to make a `POST` request to a URL. + pub fn post(&self, url: U) -> RequestBuilder { + RequestBuilder::new(self.client.post(url), self) + } + + /// Convenience method to make a `HEAD` request to a URL. + pub fn head(&self, url: U) -> RequestBuilder { + RequestBuilder::new(self.client.head(url), self) + } + + /// Executes a request, applying the redirect policy. + pub async fn execute(&self, req: Request) -> reqwest_middleware::Result { + match self.redirect_policy { + RedirectPolicy::BypassMiddleware => self.client.execute(req).await, + RedirectPolicy::RetriggerMiddleware => self.execute_with_redirect_handling(req).await, + } + } + + /// Executes a request. If the response is a redirect (one of HTTP 301, 302, 307, or 308), the + /// request is executed again with the redirect location URL (up to a maximum number of + /// redirects). + /// + /// Unlike the built-in reqwest redirect policies, this sends the redirect request through the + /// entire middleware pipeline again. + /// + /// See RFC 7231 7.1.2 for details on + /// redirect semantics. + async fn execute_with_redirect_handling( + &self, + req: Request, + ) -> reqwest_middleware::Result { + let mut request = req; + let mut redirects = 0; + // This is the default used by reqwest. + let max_redirects = 10; + + loop { + let request_url = request.url().clone(); + let result = self + .client + .execute(request.try_clone().expect("HTTP request must be cloneable")) + .await; + if redirects == max_redirects { + return result; + } + let Ok(response) = result else { + return result; + }; + + // Handle redirect if we receive a 301, 302, 307, or 308. + let status = response.status(); + if matches!( + status, + StatusCode::MOVED_PERMANENTLY + | StatusCode::FOUND + | StatusCode::TEMPORARY_REDIRECT + | StatusCode::PERMANENT_REDIRECT + ) { + let location = response + .headers() + .get("location") + .ok_or(reqwest_middleware::Error::Middleware(anyhow!( + "Missing expected HTTP {status} 'Location' header" + )))? + .to_str() + .map_err(|_| { + reqwest_middleware::Error::Middleware(anyhow!( + "Invalid HTTP {status} 'Location' value: must only contain visible ascii characters" + )) + })?; + + let mut redirect_url = match Url::parse(location) { + Ok(url) => url, + // Per RFC 7231, URLs should be resolved against the request URL. + Err(ParseError::RelativeUrlWithoutBase) => request_url.join(location).map_err(|err| { + reqwest_middleware::Error::Middleware(anyhow!( + "Invalid HTTP {status} 'Location' value `{location}` relative to `{request_url}`: {err}" + )) + })?, + Err(err) => { + return Err(reqwest_middleware::Error::Middleware(anyhow!( + "Invalid HTTP {status} 'Location' value `{location}`: {err}" + ))); + } + }; + + // Ensure the URL is a valid HTTP URI. + if let Err(err) = redirect_url.as_str().parse::() { + return Err(reqwest_middleware::Error::Middleware(anyhow!( + "Invalid HTTP {status} 'Location' value `{location}`: {err}" + ))); + } + + // Per RFC 7231, fragments must be propagated + if let Some(fragment) = request_url.fragment() { + redirect_url.set_fragment(Some(fragment)); + } + + debug!("Received HTTP {status} to {redirect_url}"); + *request.url_mut() = redirect_url; + redirects += 1; + continue; + } + + return Ok(response); + } + } + + pub fn raw_client(&self) -> &ClientWithMiddleware { + &self.client + } +} + +impl From for ClientWithMiddleware { + fn from(item: RedirectClientWithMiddleware) -> ClientWithMiddleware { + item.client + } +} + +/// A builder to construct the properties of a `Request`. +/// +/// This wraps [`reqwest_middleware::RequestBuilder`] to ensure that the [`BaseClient`] +/// redirect policy is respected if `send()` is called. +#[derive(Debug)] +#[must_use] +pub struct RequestBuilder<'a> { + builder: reqwest_middleware::RequestBuilder, + client: &'a RedirectClientWithMiddleware, +} + +impl<'a> RequestBuilder<'a> { + pub fn new( + builder: reqwest_middleware::RequestBuilder, + client: &'a RedirectClientWithMiddleware, + ) -> Self { + Self { builder, client } + } + + /// Add a `Header` to this Request. + pub fn header(mut self, key: K, value: V) -> Self + where + HeaderName: TryFrom, + >::Error: Into, + HeaderValue: TryFrom, + >::Error: Into, + { + self.builder = self.builder.header(key, value); + self + } + + /// Add a set of Headers to the existing ones on this Request. + /// + /// The headers will be merged in to any already set. + pub fn headers(mut self, headers: HeaderMap) -> Self { + self.builder = self.builder.headers(headers); + self + } + + #[cfg(not(target_arch = "wasm32"))] + pub fn version(mut self, version: reqwest::Version) -> Self { + self.builder = self.builder.version(version); + self + } + + #[cfg_attr(docsrs, doc(cfg(feature = "multipart")))] + pub fn multipart(mut self, multipart: multipart::Form) -> Self { + self.builder = self.builder.multipart(multipart); + self + } + + /// Build a `Request`. + pub fn build(self) -> reqwest::Result { + self.builder.build() + } + + /// Constructs the Request and sends it to the target URL, returning a + /// future Response. + pub async fn send(self) -> reqwest_middleware::Result { + self.client.execute(self.build()?).await + } + + pub fn raw_builder(&self) -> &reqwest_middleware::RequestBuilder { + &self.builder + } +} + /// Extends [`DefaultRetryableStrategy`], to log transient request failures and additional retry cases. pub struct UvRetryableStrategy; diff --git a/crates/uv-client/src/cached_client.rs b/crates/uv-client/src/cached_client.rs index 3ee9e1cfc..d385a6423 100644 --- a/crates/uv-client/src/cached_client.rs +++ b/crates/uv-client/src/cached_client.rs @@ -510,7 +510,6 @@ impl CachedClient { debug!("Sending revalidation request for: {url}"); let response = self .0 - .for_host(req.url()) .execute(req) .instrument(info_span!("revalidation_request", url = url.as_str())) .await @@ -551,7 +550,6 @@ impl CachedClient { let cache_policy_builder = CachePolicyBuilder::new(&req); let response = self .0 - .for_host(&url) .execute(req) .await .map_err(|err| ErrorKind::from_reqwest_middleware(url.clone(), err))? diff --git a/crates/uv-client/src/lib.rs b/crates/uv-client/src/lib.rs index 49ee1d955..6d1266ffe 100644 --- a/crates/uv-client/src/lib.rs +++ b/crates/uv-client/src/lib.rs @@ -1,6 +1,6 @@ pub use base_client::{ is_extended_transient_error, AuthIntegration, BaseClient, BaseClientBuilder, ExtraMiddleware, - UvRetryableStrategy, DEFAULT_RETRIES, + RedirectClientWithMiddleware, RequestBuilder, UvRetryableStrategy, DEFAULT_RETRIES, }; pub use cached_client::{CacheControl, CachedClient, CachedClientError, DataWithCachePolicy}; pub use error::{Error, ErrorKind, WrappedReqwestError}; diff --git a/crates/uv-client/src/registry_client.rs b/crates/uv-client/src/registry_client.rs index 9450352a1..5aa9b230e 100644 --- a/crates/uv-client/src/registry_client.rs +++ b/crates/uv-client/src/registry_client.rs @@ -10,7 +10,6 @@ use futures::{FutureExt, StreamExt, TryStreamExt}; use http::HeaderMap; use itertools::Either; use reqwest::{Proxy, Response, StatusCode}; -use reqwest_middleware::ClientWithMiddleware; use rustc_hash::FxHashMap; use tokio::sync::{Mutex, Semaphore}; use tracing::{info_span, instrument, trace, warn, Instrument}; @@ -34,7 +33,7 @@ use uv_pypi_types::{ResolutionMetadata, SimpleJson}; use uv_small_str::SmallString; use uv_torch::TorchStrategy; -use crate::base_client::{BaseClientBuilder, ExtraMiddleware}; +use crate::base_client::{BaseClientBuilder, ExtraMiddleware, RedirectPolicy}; use crate::cached_client::CacheControl; use crate::flat_index::FlatIndexEntry; use crate::html::SimpleHtml; @@ -42,7 +41,7 @@ use crate::remote_metadata::wheel_metadata_from_remote_zip; use crate::rkyvutil::OwnedArchive; use crate::{ BaseClient, CachedClient, CachedClientError, Error, ErrorKind, FlatIndexClient, - FlatIndexEntries, + FlatIndexEntries, RedirectClientWithMiddleware, }; /// A builder for an [`RegistryClient`]. @@ -158,7 +157,9 @@ impl<'a> RegistryClientBuilder<'a> { pub fn build(self) -> RegistryClient { // Build a base client - let builder = self.base_client_builder; + let builder = self + .base_client_builder + .redirect(RedirectPolicy::RetriggerMiddleware); let client = builder.build(); @@ -255,7 +256,7 @@ impl RegistryClient { } /// Return the [`BaseClient`] used by this client. - pub fn uncached_client(&self, url: &Url) -> &ClientWithMiddleware { + pub fn uncached_client(&self, url: &Url) -> &RedirectClientWithMiddleware { self.client.uncached().for_host(url) } @@ -1175,6 +1176,215 @@ mod tests { use crate::{html::SimpleHtml, SimpleMetadata, SimpleMetadatum}; + use uv_cache::Cache; + use wiremock::matchers::{basic_auth, method, path_regex}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + use crate::RegistryClientBuilder; + + type Error = Box; + + async fn start_test_server(username: &'static str, password: &'static str) -> MockServer { + let server = MockServer::start().await; + + Mock::given(method("GET")) + .and(basic_auth(username, password)) + .respond_with(ResponseTemplate::new(200)) + .mount(&server) + .await; + + Mock::given(method("GET")) + .respond_with(ResponseTemplate::new(401)) + .mount(&server) + .await; + + server + } + + #[tokio::test] + async fn test_redirect_to_server_with_credentials() -> Result<(), Error> { + let username = "user"; + let password = "password"; + + let auth_server = start_test_server(username, password).await; + let auth_base_url = Url::parse(&auth_server.uri())?; + + let redirect_server = MockServer::start().await; + + // Configure the redirect server to respond with a 302 to the auth server + Mock::given(method("GET")) + .respond_with( + ResponseTemplate::new(302).insert_header("Location", format!("{}", &auth_base_url)), + ) + .mount(&redirect_server) + .await; + + let redirect_server_url = Url::parse(&redirect_server.uri())?; + + let cache = Cache::temp()?; + let registry_client = RegistryClientBuilder::new(cache).build(); + let client = registry_client.cached_client().uncached(); + + assert_eq!( + client + .for_host(&redirect_server_url) + .get(redirect_server.uri()) + .send() + .await? + .status(), + 401, + "Requests should fail if credentials are missing" + ); + + let mut url = redirect_server_url.clone(); + let _ = url.set_username(username); + let _ = url.set_password(Some(password)); + + assert_eq!( + client + .for_host(&redirect_server_url) + .get(format!("{url}")) + .send() + .await? + .status(), + 200, + "Requests should succeed if credentials are present" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_redirect_root_relative_url() -> Result<(), Error> { + let username = "user"; + let password = "password"; + + let redirect_server = MockServer::start().await; + + // Configure the redirect server to respond with a 307 with a relative URL. + Mock::given(method("GET")) + .and(path_regex("/foo/")) + .respond_with( + ResponseTemplate::new(307).insert_header("Location", "/bar/baz/".to_string()), + ) + .mount(&redirect_server) + .await; + + Mock::given(method("GET")) + .and(path_regex("/bar/baz/")) + .and(basic_auth(username, password)) + .respond_with(ResponseTemplate::new(200)) + .mount(&redirect_server) + .await; + + let redirect_server_url = Url::parse(&redirect_server.uri())?.join("foo/")?; + + let cache = Cache::temp()?; + let registry_client = RegistryClientBuilder::new(cache).build(); + let client = registry_client.cached_client().uncached(); + + let mut url = redirect_server_url.clone(); + let _ = url.set_username(username); + let _ = url.set_password(Some(password)); + + assert_eq!( + client + .for_host(&url) + .get(format!("{url}")) + .send() + .await? + .status(), + 200, + "Requests should succeed for relative URL" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_redirect_relative_url() -> Result<(), Error> { + let username = "user"; + let password = "password"; + + let redirect_server = MockServer::start().await; + + // Configure the redirect server to respond with a 307 with a relative URL. + Mock::given(method("GET")) + .and(path_regex("/foo/bar/baz/")) + .and(basic_auth(username, password)) + .respond_with(ResponseTemplate::new(200)) + .mount(&redirect_server) + .await; + + Mock::given(method("GET")) + .and(path_regex("/foo/")) + .respond_with( + ResponseTemplate::new(307).insert_header("Location", "bar/baz/".to_string()), + ) + .mount(&redirect_server) + .await; + + let cache = Cache::temp()?; + let registry_client = RegistryClientBuilder::new(cache).build(); + let client = registry_client.cached_client().uncached(); + + let redirect_server_url = Url::parse(&redirect_server.uri())?.join("foo/")?; + let mut url = redirect_server_url.clone(); + let _ = url.set_username(username); + let _ = url.set_password(Some(password)); + + assert_eq!( + client + .for_host(&url) + .get(format!("{url}")) + .send() + .await? + .status(), + 200, + "Requests should succeed for relative URL" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_redirect_preserve_fragment() -> Result<(), Error> { + let redirect_server = MockServer::start().await; + + // Configure the redirect server to respond with a 307 with a relative URL. + Mock::given(method("GET")) + .respond_with(ResponseTemplate::new(307).insert_header("Location", "/foo".to_string())) + .mount(&redirect_server) + .await; + + Mock::given(method("GET")) + .and(path_regex("/foo")) + .respond_with(ResponseTemplate::new(200)) + .mount(&redirect_server) + .await; + + let cache = Cache::temp()?; + let registry_client = RegistryClientBuilder::new(cache).build(); + let client = registry_client.cached_client().uncached(); + + let mut url = Url::parse(&redirect_server.uri())?; + url.set_fragment(Some("fragment")); + + assert_eq!( + client + .for_host(&url) + .get(format!("{}", url.clone())) + .send() + .await? + .url() + .to_string(), + format!("{}/foo#fragment", redirect_server.uri()), + "Requests should preserve fragment" + ); + + Ok(()) + } + #[test] fn ignore_failing_files() { // 1.7.7 has an invalid requires-python field (double comma), 1.7.8 is valid diff --git a/crates/uv-distribution/src/source/mod.rs b/crates/uv-distribution/src/source/mod.rs index 32338581b..00fee7a2e 100644 --- a/crates/uv-distribution/src/source/mod.rs +++ b/crates/uv-distribution/src/source/mod.rs @@ -1582,7 +1582,7 @@ impl<'a, T: BuildContext> SourceDistributionBuilder<'a, T> { client .unmanaged .uncached_client(resource.git.repository()) - .clone(), + .raw_client(), ) .await { @@ -1863,7 +1863,10 @@ impl<'a, T: BuildContext> SourceDistributionBuilder<'a, T> { .git() .github_fast_path( git, - client.unmanaged.uncached_client(git.repository()).clone(), + client + .unmanaged + .uncached_client(git.repository()) + .raw_client(), ) .await? .is_some() diff --git a/crates/uv-git/src/resolver.rs b/crates/uv-git/src/resolver.rs index 9d9b216b7..2ac85c5f5 100644 --- a/crates/uv-git/src/resolver.rs +++ b/crates/uv-git/src/resolver.rs @@ -52,7 +52,7 @@ impl GitResolver { pub async fn github_fast_path( &self, url: &GitUrl, - client: ClientWithMiddleware, + client: &ClientWithMiddleware, ) -> Result, GitResolverError> { let reference = RepositoryReference::from(url); @@ -112,7 +112,7 @@ impl GitResolver { pub async fn fetch( &self, url: &GitUrl, - client: ClientWithMiddleware, + client: impl Into, disable_ssl: bool, offline: bool, cache: PathBuf, diff --git a/crates/uv-publish/src/lib.rs b/crates/uv-publish/src/lib.rs index ee54cfcb1..48b134e16 100644 --- a/crates/uv-publish/src/lib.rs +++ b/crates/uv-publish/src/lib.rs @@ -12,7 +12,6 @@ use itertools::Itertools; use reqwest::header::AUTHORIZATION; use reqwest::multipart::Part; use reqwest::{Body, Response, StatusCode}; -use reqwest_middleware::RequestBuilder; use reqwest_retry::policies::ExponentialBackoff; use reqwest_retry::{RetryPolicy, Retryable, RetryableStrategy}; use rustc_hash::FxHashSet; @@ -28,8 +27,8 @@ use url::Url; use uv_auth::Credentials; use uv_cache::{Cache, Refresh}; use uv_client::{ - BaseClient, MetadataFormat, OwnedArchive, RegistryClientBuilder, UvRetryableStrategy, - DEFAULT_RETRIES, + BaseClient, MetadataFormat, OwnedArchive, RegistryClientBuilder, RequestBuilder, + UvRetryableStrategy, DEFAULT_RETRIES, }; use uv_configuration::{KeyringProviderType, TrustedPublishing}; use uv_distribution_filename::{DistFilename, SourceDistExtension, SourceDistFilename}; @@ -320,7 +319,9 @@ pub async fn check_trusted_publishing( // We could check for credentials from the keyring or netrc the auth middleware first, but // given that we are in GitHub Actions we check for trusted publishing first. debug!("Running on GitHub Actions without explicit credentials, checking for trusted publishing"); - match trusted_publishing::get_token(registry, client.for_host(registry)).await { + match trusted_publishing::get_token(registry, client.for_host(registry).raw_client()) + .await + { Ok(token) => Ok(TrustedPublishResult::Configured(token)), Err(err) => { // TODO(konsti): It would be useful if we could differentiate between actual errors @@ -354,7 +355,9 @@ pub async fn check_trusted_publishing( ); } - let token = trusted_publishing::get_token(registry, client.for_host(registry)).await?; + let token = + trusted_publishing::get_token(registry, client.for_host(registry).raw_client()) + .await?; Ok(TrustedPublishResult::Configured(token)) } TrustedPublishing::Never => Ok(TrustedPublishResult::Skipped), @@ -738,16 +741,16 @@ async fn form_metadata( /// Build the upload request. /// /// Returns the request and the reporter progress bar id. -async fn build_request( +async fn build_request<'a>( file: &Path, raw_filename: &str, filename: &DistFilename, registry: &Url, - client: &BaseClient, + client: &'a BaseClient, credentials: &Credentials, form_metadata: &[(&'static str, String)], reporter: Arc, -) -> Result<(RequestBuilder, usize), PublishPrepareError> { +) -> Result<(RequestBuilder<'a>, usize), PublishPrepareError> { let mut form = reqwest::multipart::Form::new(); for (key, value) in form_metadata { form = form.text(*key, value.clone()); @@ -959,12 +962,13 @@ mod tests { project_urls: Source, https://github.com/unknown/tqdm "###); + let client = BaseClientBuilder::new().build(); let (request, _) = build_request( &file, raw_filename, &filename, &Url::parse("https://example.org/upload").unwrap(), - &BaseClientBuilder::new().build(), + &client, &Credentials::basic(Some("ferris".to_string()), Some("F3RR!S".to_string())), &form_metadata, Arc::new(DummyReporter), @@ -975,7 +979,7 @@ mod tests { insta::with_settings!({ filters => [("boundary=[0-9a-f-]+", "boundary=[...]")], }, { - assert_debug_snapshot!(&request, @r#" + assert_debug_snapshot!(&request.raw_builder(), @r#" RequestBuilder { inner: RequestBuilder { method: POST, @@ -1109,12 +1113,13 @@ mod tests { requires_dist: requests ; extra == 'telegram' "###); + let client = BaseClientBuilder::new().build(); let (request, _) = build_request( &file, raw_filename, &filename, &Url::parse("https://example.org/upload").unwrap(), - &BaseClientBuilder::new().build(), + &client, &Credentials::basic(Some("ferris".to_string()), Some("F3RR!S".to_string())), &form_metadata, Arc::new(DummyReporter), @@ -1125,7 +1130,7 @@ mod tests { insta::with_settings!({ filters => [("boundary=[0-9a-f-]+", "boundary=[...]")], }, { - assert_debug_snapshot!(&request, @r#" + assert_debug_snapshot!(&request.raw_builder(), @r#" RequestBuilder { inner: RequestBuilder { method: POST, diff --git a/crates/uv/tests/it/common/mod.rs b/crates/uv/tests/it/common/mod.rs index e29df1688..4fa0c43f5 100644 --- a/crates/uv/tests/it/common/mod.rs +++ b/crates/uv/tests/it/common/mod.rs @@ -1605,8 +1605,7 @@ pub async fn download_to_disk(url: &str, path: &Path) { .allow_insecure_host(trusted_hosts) .build(); let url: reqwest::Url = url.parse().unwrap(); - let client = client.for_host(&url); - let response = client.request(http::Method::GET, url).send().await.unwrap(); + let response = client.for_host(&url).get(url).send().await.unwrap(); let mut file = tokio::fs::File::create(path).await.unwrap(); let mut stream = response.bytes_stream(); diff --git a/crates/uv/tests/it/edit.rs b/crates/uv/tests/it/edit.rs index 6fc33cf7c..b41c2cdac 100644 --- a/crates/uv/tests/it/edit.rs +++ b/crates/uv/tests/it/edit.rs @@ -3,22 +3,24 @@ #[cfg(feature = "git")] mod conditional_imports { pub(crate) use crate::common::{decode_token, READ_ONLY_GITHUB_TOKEN}; - pub(crate) use assert_cmd::assert::OutputAssertExt; } #[cfg(feature = "git")] use conditional_imports::*; use anyhow::Result; +use assert_cmd::assert::OutputAssertExt; use assert_fs::prelude::*; use indoc::{formatdoc, indoc}; use insta::assert_snapshot; use std::path::Path; +use url::Url; use uv_fs::Simplified; +use wiremock::{matchers::method, Mock, MockServer, ResponseTemplate}; use uv_static::EnvVars; -use crate::common::{packse_index_url, uv_snapshot, TestContext}; +use crate::common::{packse_index_url, uv_snapshot, venv_bin_path, TestContext}; /// Add a PyPI requirement. #[test] @@ -10748,6 +10750,197 @@ fn add_auth_policy_never_without_credentials() -> Result<()> { Ok(()) } +/// If uv receives a 302 redirect, it should use supplied credentials for the +/// new location. +#[tokio::test] +async fn add_redirect() -> Result<()> { + let context = TestContext::new("3.12"); + + let pyproject_toml = context.temp_dir.child("pyproject.toml"); + pyproject_toml.write_str(indoc! { r#" + [project] + name = "foo" + version = "1.0.0" + requires-python = ">=3.12" + dependencies = [] + "# + })?; + + let redirect_server = MockServer::start().await; + + Mock::given(method("GET")) + .respond_with(|req: &wiremock::Request| { + let redirect_url = redirect_url_to_pypi_proxy(req); + ResponseTemplate::new(302).insert_header("Location", &redirect_url) + }) + .mount(&redirect_server) + .await; + + let mut redirect_url = Url::parse(&redirect_server.uri())?; + let _ = redirect_url.set_username("public"); + let _ = redirect_url.set_password(Some("heron")); + + uv_snapshot!(context.add().arg("--default-index").arg(redirect_url.as_str()).arg("anyio"), @r" + success: true + exit_code: 0 + ----- stdout ----- + + ----- stderr ----- + Resolved 4 packages in [TIME] + Prepared 3 packages in [TIME] + Installed 3 packages in [TIME] + + anyio==4.3.0 + + idna==3.6 + + sniffio==1.3.1 + " + ); + + context.assert_command("import anyio").success(); + Ok(()) +} + +/// If uv receives a 302 redirect, it should use credentials from the keyring +/// for the new location. +#[tokio::test] +async fn add_redirect_with_keyring() -> Result<()> { + let keyring_context = TestContext::new("3.12"); + + // Install our keyring plugin + keyring_context + .pip_install() + .arg( + keyring_context + .workspace_root + .join("scripts") + .join("packages") + .join("keyring_test_plugin"), + ) + .assert() + .success(); + + let context = TestContext::new("3.12"); + let filters = context + .filters() + .into_iter() + .chain([(r"127\.0\.0\.1[^\r\n]*", "[LOCALHOST]")]) + .collect::>(); + + let pyproject_toml = context.temp_dir.child("pyproject.toml"); + pyproject_toml.write_str(indoc! { r#" + [project] + name = "foo" + version = "1.0.0" + requires-python = ">=3.12" + dependencies = [] + + [tool.uv] + keyring-provider = "subprocess" + "#, + })?; + + let redirect_server = MockServer::start().await; + + Mock::given(method("GET")) + .respond_with(|req: &wiremock::Request| { + let redirect_url = redirect_url_to_pypi_proxy(req); + ResponseTemplate::new(302).insert_header("Location", &redirect_url) + }) + .mount(&redirect_server) + .await; + + let mut redirect_url = Url::parse(&redirect_server.uri())?; + let _ = redirect_url.set_username("public"); + + uv_snapshot!(filters, context.add().arg("--default-index") + .arg(redirect_url.as_str()) + .arg("anyio") + .env(EnvVars::KEYRING_TEST_CREDENTIALS, r#"{"pypi-proxy.fly.dev": {"public": "heron"}}"#) + .env(EnvVars::PATH, venv_bin_path(&keyring_context.venv)), @r" + success: true + exit_code: 0 + ----- stdout ----- + + ----- stderr ----- + Request for public@http://[LOCALHOST] + Request for public@[LOCALHOST] + Request for public@https://pypi-proxy.fly.dev/basic-auth/simple/anyio/ + Request for public@pypi-proxy.fly.dev + Resolved 4 packages in [TIME] + Prepared 3 packages in [TIME] + Installed 3 packages in [TIME] + + anyio==4.3.0 + + idna==3.6 + + sniffio==1.3.1 + " + ); + + context.assert_command("import anyio").success(); + Ok(()) +} + +/// If uv receives a 302 redirect, it should use credentials from netrc +/// for the new location. +#[tokio::test] +async fn add_redirect_with_netrc() -> Result<()> { + let context = TestContext::new("3.12"); + let filters = context + .filters() + .into_iter() + .chain([(r"127\.0\.0\.1[^\r\n]*", "[LOCALHOST]")]) + .collect::>(); + + let netrc = context.temp_dir.child(".netrc"); + netrc.write_str("machine pypi-proxy.fly.dev login public password heron")?; + + let redirect_server = MockServer::start().await; + + Mock::given(method("GET")) + .respond_with(|req: &wiremock::Request| { + let redirect_url = redirect_url_to_pypi_proxy(req); + ResponseTemplate::new(302).insert_header("Location", &redirect_url) + }) + .mount(&redirect_server) + .await; + + let mut redirect_url = Url::parse(&redirect_server.uri())?; + let _ = redirect_url.set_username("public"); + + uv_snapshot!(filters, context.pip_install() + .arg("anyio") + .arg("--index-url") + .arg(redirect_url.as_str()) + .env(EnvVars::NETRC, netrc.to_str().unwrap()) + .arg("--strict"), @r###" + success: true + exit_code: 0 + ----- stdout ----- + + ----- stderr ----- + Resolved 3 packages in [TIME] + Prepared 3 packages in [TIME] + Installed 3 packages in [TIME] + + anyio==4.3.0 + + idna==3.6 + + sniffio==1.3.1 + "### + ); + + context.assert_command("import anyio").success(); + + Ok(()) +} + +fn redirect_url_to_pypi_proxy(req: &wiremock::Request) -> String { + let last_path_segment = req + .url + .path_segments() + .expect("path has segments") + .filter(|segment| !segment.is_empty()) // Filter out empty segments + .next_back() + .expect("path has a package segment"); + format!("https://pypi-proxy.fly.dev/basic-auth/simple/{last_path_segment}/") +} + /// Test the error message when adding a package with multiple existing references in /// `pyproject.toml`. #[test]