Add `Bearer` support to `Credentials` (#12610)

## Summary

I noticed that these only support Basic credentials, but we may want to
allow users to provide Bearer tokens? This PR just generalizes the type.
This commit is contained in:
Charlie Marsh 2025-04-01 17:48:21 -04:00 committed by GitHub
parent 878457b5dd
commit f491aa0f58
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 125 additions and 69 deletions

View File

@ -245,19 +245,19 @@ mod tests {
#[test] #[test]
fn test_trie() { fn test_trie() {
let credentials1 = Arc::new(Credentials::new( let credentials1 = Arc::new(Credentials::basic(
Some("username1".to_string()), Some("username1".to_string()),
Some("password1".to_string()), Some("password1".to_string()),
)); ));
let credentials2 = Arc::new(Credentials::new( let credentials2 = Arc::new(Credentials::basic(
Some("username2".to_string()), Some("username2".to_string()),
Some("password2".to_string()), Some("password2".to_string()),
)); ));
let credentials3 = Arc::new(Credentials::new( let credentials3 = Arc::new(Credentials::basic(
Some("username3".to_string()), Some("username3".to_string()),
Some("password3".to_string()), Some("password3".to_string()),
)); ));
let credentials4 = Arc::new(Credentials::new( let credentials4 = Arc::new(Credentials::basic(
Some("username4".to_string()), Some("username4".to_string()),
Some("password4".to_string()), Some("password4".to_string()),
)); ));

View File

@ -1,6 +1,7 @@
use base64::prelude::BASE64_STANDARD; use base64::prelude::BASE64_STANDARD;
use base64::read::DecoderReader; use base64::read::DecoderReader;
use base64::write::EncoderWriter; use base64::write::EncoderWriter;
use std::borrow::Cow;
use netrc::Netrc; use netrc::Netrc;
use reqwest::header::HeaderValue; use reqwest::header::HeaderValue;
@ -12,15 +13,21 @@ use url::Url;
use uv_static::EnvVars; use uv_static::EnvVars;
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
pub struct Credentials { pub enum Credentials {
/// The name of the user for authentication. Basic {
/// The username to use for authentication.
username: Username, username: Username,
/// The password to use for authentication. /// The password to use for authentication.
password: Option<String>, password: Option<String>,
},
Bearer {
/// The token to use for authentication.
token: Vec<u8>,
},
} }
#[derive(Clone, Debug, PartialEq, Eq, Ord, PartialOrd, Hash, Default)] #[derive(Clone, Debug, PartialEq, Eq, Ord, PartialOrd, Hash, Default)]
pub(crate) struct Username(Option<String>); pub struct Username(Option<String>);
impl Username { impl Username {
/// Create a new username. /// Create a new username.
@ -61,31 +68,54 @@ impl From<Option<String>> for Username {
} }
impl Credentials { impl Credentials {
pub(crate) fn new(username: Option<String>, password: Option<String>) -> Self { /// Create a set of HTTP Basic Authentication credentials.
Self { #[allow(dead_code)]
pub(crate) fn basic(username: Option<String>, password: Option<String>) -> Self {
Self::Basic {
username: Username::new(username), username: Username::new(username),
password, password,
} }
} }
/// Create a set of Bearer Authentication credentials.
#[allow(dead_code)]
pub(crate) fn bearer(token: Vec<u8>) -> Self {
Self::Bearer { token }
}
pub fn username(&self) -> Option<&str> { pub fn username(&self) -> Option<&str> {
self.username.as_deref() match self {
Self::Basic { username, .. } => username.as_deref(),
Self::Bearer { .. } => None,
}
} }
pub(crate) fn to_username(&self) -> Username { pub(crate) fn to_username(&self) -> Username {
self.username.clone() match self {
Self::Basic { username, .. } => username.clone(),
Self::Bearer { .. } => Username::none(),
}
} }
pub(crate) fn as_username(&self) -> &Username { pub(crate) fn as_username(&self) -> Cow<'_, Username> {
&self.username match self {
Self::Basic { username, .. } => Cow::Borrowed(username),
Self::Bearer { .. } => Cow::Owned(Username::none()),
}
} }
pub fn password(&self) -> Option<&str> { pub fn password(&self) -> Option<&str> {
self.password.as_deref() match self {
Self::Basic { password, .. } => password.as_deref(),
Self::Bearer { .. } => None,
}
} }
pub(crate) fn is_empty(&self) -> bool { pub(crate) fn is_empty(&self) -> bool {
self.password.is_none() && self.username.is_none() match self {
Self::Basic { username, password } => username.is_none() && password.is_none(),
Self::Bearer { token } => token.is_empty(),
}
} }
/// Return [`Credentials`] for a [`Url`] from a [`Netrc`] file, if any. /// Return [`Credentials`] for a [`Url`] from a [`Netrc`] file, if any.
@ -103,7 +133,7 @@ impl Credentials {
return None; return None;
}; };
Some(Credentials { Some(Credentials::Basic {
username: Username::new(Some(entry.login.clone())), username: Username::new(Some(entry.login.clone())),
password: Some(entry.password.clone()), password: Some(entry.password.clone()),
}) })
@ -116,7 +146,7 @@ impl Credentials {
if url.username().is_empty() && url.password().is_none() { if url.username().is_empty() && url.password().is_none() {
return None; return None;
} }
Some(Self { Some(Self::Basic {
// Remove percent-encoding from URL credentials // Remove percent-encoding from URL credentials
// See <https://github.com/pypa/pip/blob/06d21db4ff1ab69665c22a88718a4ea9757ca293/src/pip/_internal/utils/misc.py#L497-L499> // See <https://github.com/pypa/pip/blob/06d21db4ff1ab69665c22a88718a4ea9757ca293/src/pip/_internal/utils/misc.py#L497-L499>
username: if url.username().is_empty() { username: if url.username().is_empty() {
@ -149,7 +179,7 @@ impl Credentials {
if username.is_none() && password.is_none() { if username.is_none() && password.is_none() {
None None
} else { } else {
Some(Self::new(username, password)) Some(Self::basic(username, password))
} }
} }
@ -169,14 +199,15 @@ impl Credentials {
/// Parse [`Credentials`] from an authorization header, if any. /// Parse [`Credentials`] from an authorization header, if any.
/// ///
/// Only HTTP Basic Authentication is supported. /// HTTP Basic and Bearer Authentication are both supported.
/// [`None`] will be returned if another authorization scheme is detected. /// [`None`] will be returned if another authorization scheme is detected.
/// ///
/// Panics if the authentication is not conformant to the HTTP Basic Authentication scheme: /// Panics if the authentication is not conformant to the HTTP Basic Authentication scheme:
/// - The contents must be base64 encoded /// - The contents must be base64 encoded
/// - There must be a `:` separator /// - There must be a `:` separator
pub(crate) fn from_header_value(header: &HeaderValue) -> Option<Self> { pub(crate) fn from_header_value(header: &HeaderValue) -> Option<Self> {
let mut value = header.as_bytes().strip_prefix(b"Basic ")?; // Parse a `Basic` authentication header.
if let Some(mut value) = header.as_bytes().strip_prefix(b"Basic ") {
let mut decoder = DecoderReader::new(&mut value, &BASE64_STANDARD); let mut decoder = DecoderReader::new(&mut value, &BASE64_STANDARD);
let mut buf = String::new(); let mut buf = String::new();
decoder decoder
@ -195,13 +226,28 @@ impl Credentials {
} else { } else {
Some(password.to_string()) Some(password.to_string())
}; };
Some(Self::new(username, password)) return Some(Self::Basic {
username: Username::new(username),
password,
});
}
// Parse a `Bearer` authentication header.
if let Some(token) = header.as_bytes().strip_prefix(b"Bearer ") {
return Some(Self::Bearer {
token: token.to_vec(),
});
}
None
} }
/// Create an HTTP Basic Authentication header for the credentials. /// Create an HTTP Basic Authentication header for the credentials.
/// ///
/// Panics if the username or password cannot be base64 encoded. /// Panics if the username or password cannot be base64 encoded.
pub(crate) fn to_header_value(&self) -> HeaderValue { pub(crate) fn to_header_value(&self) -> HeaderValue {
match self {
Self::Basic { .. } => {
// See: <https://github.com/seanmonstar/reqwest/blob/2c11ef000b151c2eebeed2c18a7b81042220c6b0/src/util.rs#L3> // See: <https://github.com/seanmonstar/reqwest/blob/2c11ef000b151c2eebeed2c18a7b81042220c6b0/src/util.rs#L3>
let mut buf = b"Basic ".to_vec(); let mut buf = b"Basic ".to_vec();
{ {
@ -209,13 +255,23 @@ impl Credentials {
write!(encoder, "{}:", self.username().unwrap_or_default()) write!(encoder, "{}:", self.username().unwrap_or_default())
.expect("Write to base64 encoder should succeed"); .expect("Write to base64 encoder should succeed");
if let Some(password) = self.password() { if let Some(password) = self.password() {
write!(encoder, "{password}").expect("Write to base64 encoder should succeed"); write!(encoder, "{password}")
.expect("Write to base64 encoder should succeed");
} }
} }
let mut header = HeaderValue::from_bytes(&buf).expect("base64 is always valid HeaderValue"); let mut header =
HeaderValue::from_bytes(&buf).expect("base64 is always valid HeaderValue");
header.set_sensitive(true); header.set_sensitive(true);
header header
} }
Self::Bearer { token } => {
let mut header = HeaderValue::from_bytes(&[b"Bearer ", token.as_slice()].concat())
.expect("Bearer token is always valid HeaderValue");
header.set_sensitive(true);
header
}
}
}
/// Apply the credentials to the given URL. /// Apply the credentials to the given URL.
/// ///

View File

@ -80,7 +80,7 @@ impl KeyringProvider {
}; };
} }
credentials.map(|(username, password)| Credentials::new(Some(username), Some(password))) credentials.map(|(username, password)| Credentials::basic(Some(username), Some(password)))
} }
#[instrument(skip(self))] #[instrument(skip(self))]
@ -265,7 +265,7 @@ mod tests {
let keyring = KeyringProvider::dummy([(url.host_str().unwrap(), "user", "password")]); let keyring = KeyringProvider::dummy([(url.host_str().unwrap(), "user", "password")]);
assert_eq!( assert_eq!(
keyring.fetch(&url, Some("user")).await, keyring.fetch(&url, Some("user")).await,
Some(Credentials::new( Some(Credentials::basic(
Some("user".to_string()), Some("user".to_string()),
Some("password".to_string()) Some("password".to_string())
)) ))
@ -274,7 +274,7 @@ mod tests {
keyring keyring
.fetch(&url.join("test").unwrap(), Some("user")) .fetch(&url.join("test").unwrap(), Some("user"))
.await, .await,
Some(Credentials::new( Some(Credentials::basic(
Some("user".to_string()), Some("user".to_string()),
Some("password".to_string()) Some("password".to_string())
)) ))
@ -298,21 +298,21 @@ mod tests {
]); ]);
assert_eq!( assert_eq!(
keyring.fetch(&url.join("foo").unwrap(), Some("user")).await, keyring.fetch(&url.join("foo").unwrap(), Some("user")).await,
Some(Credentials::new( Some(Credentials::basic(
Some("user".to_string()), Some("user".to_string()),
Some("password".to_string()) Some("password".to_string())
)) ))
); );
assert_eq!( assert_eq!(
keyring.fetch(&url, Some("user")).await, keyring.fetch(&url, Some("user")).await,
Some(Credentials::new( Some(Credentials::basic(
Some("user".to_string()), Some("user".to_string()),
Some("other-password".to_string()) Some("other-password".to_string())
)) ))
); );
assert_eq!( assert_eq!(
keyring.fetch(&url.join("bar").unwrap(), Some("user")).await, keyring.fetch(&url.join("bar").unwrap(), Some("user")).await,
Some(Credentials::new( Some(Credentials::basic(
Some("user".to_string()), Some("user".to_string()),
Some("other-password".to_string()) Some("other-password".to_string())
)) ))
@ -326,7 +326,7 @@ mod tests {
let credentials = keyring.fetch(&url, Some("user")).await; let credentials = keyring.fetch(&url, Some("user")).await;
assert_eq!( assert_eq!(
credentials, credentials,
Some(Credentials::new( Some(Credentials::basic(
Some("user".to_string()), Some("user".to_string()),
Some("password".to_string()) Some("password".to_string())
)) ))
@ -340,7 +340,7 @@ mod tests {
let credentials = keyring.fetch(&url, None).await; let credentials = keyring.fetch(&url, None).await;
assert_eq!( assert_eq!(
credentials, credentials,
Some(Credentials::new( Some(Credentials::basic(
Some("user".to_string()), Some("user".to_string()),
Some("password".to_string()) Some("password".to_string())
)) ))

View File

@ -397,7 +397,7 @@ impl AuthMiddleware {
None None
} else if let Some(credentials) = self } else if let Some(credentials) = self
.cache() .cache()
.get_url(request.url(), credentials.as_username()) .get_url(request.url(), credentials.as_username().as_ref())
{ {
request = credentials.authenticate(request); request = credentials.authenticate(request);
// Do not insert already-cached credentials // Do not insert already-cached credentials
@ -653,7 +653,7 @@ mod tests {
let cache = CredentialsCache::new(); let cache = CredentialsCache::new();
cache.insert( cache.insert(
&base_url, &base_url,
Arc::new(Credentials::new( Arc::new(Credentials::basic(
Some(username.to_string()), Some(username.to_string()),
Some(password.to_string()), Some(password.to_string()),
)), )),
@ -707,7 +707,7 @@ mod tests {
let cache = CredentialsCache::new(); let cache = CredentialsCache::new();
cache.insert( cache.insert(
&base_url, &base_url,
Arc::new(Credentials::new(Some(username.to_string()), None)), Arc::new(Credentials::basic(Some(username.to_string()), None)),
); );
let client = test_client_builder() let client = test_client_builder()
@ -1097,7 +1097,7 @@ mod tests {
// URL. // URL.
cache.insert( cache.insert(
&base_url, &base_url,
Arc::new(Credentials::new(Some(username.to_string()), None)), Arc::new(Credentials::basic(Some(username.to_string()), None)),
); );
let client = test_client_builder() let client = test_client_builder()
.with(AuthMiddleware::new().with_cache(cache).with_keyring(Some( .with(AuthMiddleware::new().with_cache(cache).with_keyring(Some(
@ -1146,14 +1146,14 @@ mod tests {
// Seed the cache with our credentials // Seed the cache with our credentials
cache.insert( cache.insert(
&base_url_1, &base_url_1,
Arc::new(Credentials::new( Arc::new(Credentials::basic(
Some(username_1.to_string()), Some(username_1.to_string()),
Some(password_1.to_string()), Some(password_1.to_string()),
)), )),
); );
cache.insert( cache.insert(
&base_url_2, &base_url_2,
Arc::new(Credentials::new( Arc::new(Credentials::basic(
Some(username_2.to_string()), Some(username_2.to_string()),
Some(password_2.to_string()), Some(password_2.to_string()),
)), )),
@ -1341,14 +1341,14 @@ mod tests {
// Seed the cache with our credentials // Seed the cache with our credentials
cache.insert( cache.insert(
&base_url_1, &base_url_1,
Arc::new(Credentials::new( Arc::new(Credentials::basic(
Some(username_1.to_string()), Some(username_1.to_string()),
Some(password_1.to_string()), Some(password_1.to_string()),
)), )),
); );
cache.insert( cache.insert(
&base_url_2, &base_url_2,
Arc::new(Credentials::new( Arc::new(Credentials::basic(
Some(username_2.to_string()), Some(username_2.to_string()),
Some(password_2.to_string()), Some(password_2.to_string()),
)), )),