diff --git a/Cargo.lock b/Cargo.lock index 991f8a208..6f033fa40 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -811,7 +811,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "117725a109d387c937a1533ce01b450cbde6b88abceea8473c4d7a85853cda3c" dependencies = [ "lazy_static", - "windows-sys 0.48.0", + "windows-sys 0.59.0", ] [[package]] @@ -1855,7 +1855,7 @@ dependencies = [ "libc", "percent-encoding", "pin-project-lite", - "socket2 0.5.10", + "socket2 0.6.0", "system-configuration", "tokio", "tower-service", @@ -2085,6 +2085,15 @@ dependencies = [ "serde", ] +[[package]] +name = "is-docker" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "928bae27f42bc99b60d9ac7334e3a21d10ad8f1835a4e12ec3ec0464765ed1b3" +dependencies = [ + "once_cell", +] + [[package]] name = "is-terminal" version = "0.4.16" @@ -2093,7 +2102,17 @@ checksum = "e04d7f318608d35d4b61ddd75cbdaee86b023ebe2bd5a66ee0915f0bf93095a9" dependencies = [ "hermit-abi", "libc", - "windows-sys 0.52.0", + "windows-sys 0.59.0", +] + +[[package]] +name = "is-wsl" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "173609498df190136aa7dea1a91db051746d339e18476eed5ca40521f02d7aa5" +dependencies = [ + "is-docker", + "once_cell", ] [[package]] @@ -2153,7 +2172,7 @@ dependencies = [ "portable-atomic", "portable-atomic-util", "serde", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -2673,6 +2692,17 @@ version = "11.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" +[[package]] +name = "open" +version = "5.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2483562e62ea94312f3576a7aca397306df7990b8d89033e18766744377ef95" +dependencies = [ + "is-wsl", + "libc", + "pathdiff", +] + [[package]] name = "openssl-probe" version = "0.1.6" @@ -3101,7 +3131,7 @@ dependencies = [ "once_cell", "socket2 0.5.10", "tracing", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -3553,7 +3583,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys 0.4.15", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -4214,7 +4244,7 @@ dependencies = [ "getrandom 0.3.3", "once_cell", "rustix 1.0.8", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -4985,6 +5015,7 @@ dependencies = [ "itertools 0.14.0", "miette", "nix 0.30.1", + "open", "owo-colors", "petgraph", "predicates", @@ -5012,6 +5043,7 @@ dependencies = [ "tracing-tree", "unicode-width 0.2.1", "url", + "uuid", "uv-auth", "uv-bin-install", "uv-build-backend", @@ -5075,12 +5107,15 @@ name = "uv-auth" version = "0.0.1" dependencies = [ "anyhow", + "arcstr", "async-trait", "base64 0.22.1", + "etcetera", "fs-err", "futures", "http", "insta", + "jiff", "percent-encoding", "reqwest", "reqwest-middleware", @@ -5088,6 +5123,7 @@ dependencies = [ "rustc-hash", "schemars", "serde", + "serde_json", "tempfile", "test-log", "thiserror 2.0.16", @@ -5095,6 +5131,7 @@ dependencies = [ "toml", "tracing", "url", + "uv-cache-key", "uv-fs", "uv-keyring", "uv-once-map", @@ -6757,7 +6794,7 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.59.0", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 96c0b38d1..bea4616a8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -135,6 +135,7 @@ miette = { version = "7.2.0", features = ["fancy-no-backtrace"] } nanoid = { version = "0.4.0" } nix = { version = "0.30.0", features = ["signal"] } once_cell = { version = "1.20.2" } +open = { version = "5.3.2" } owo-colors = { version = "4.1.0" } path-slash = { version = "0.2.1" } pathdiff = { version = "0.2.1" } @@ -188,6 +189,7 @@ tracing-tree = { version = "0.4.0" } unicode-width = { version = "0.2.0" } unscanny = { version = "0.1.0" } url = { version = "2.5.2", features = ["serde"] } +uuid = { version = "1.16.0" } version-ranges = { git = "https://github.com/astral-sh/pubgrub", rev = "06ec5a5f59ffaeb6cf5079c6cb184467da06c9db" } walkdir = { version = "2.5.0" } which = { version = "8.0.0", features = ["regex"] } diff --git a/crates/uv-auth/Cargo.toml b/crates/uv-auth/Cargo.toml index ba194e801..17d37057f 100644 --- a/crates/uv-auth/Cargo.toml +++ b/crates/uv-auth/Cargo.toml @@ -10,22 +10,26 @@ doctest = false workspace = true [dependencies] +uv-cache-key = { workspace = true } uv-fs = { workspace = true } uv-keyring = { workspace = true, features = ["apple-native", "secret-service", "windows-native"] } uv-once-map = { workspace = true } uv-preview = { workspace = true } uv-redacted = { workspace = true } uv-small-str = { workspace = true } -uv-static = { workspace = true } uv-state = { workspace = true } +uv-static = { workspace = true } uv-warnings = { workspace = true } anyhow = { workspace = true } +arcstr = { workspace = true } async-trait = { workspace = true } base64 = { workspace = true } -fs-err = { workspace = true } +etcetera = { workspace = true } +fs-err = { workspace = true, features = ["tokio"] } futures = { workspace = true } http = { workspace = true } +jiff = { workspace = true } percent-encoding = { workspace = true } reqwest = { workspace = true } reqwest-middleware = { workspace = true } @@ -33,6 +37,7 @@ rust-netrc = { workspace = true } rustc-hash = { workspace = true } schemars = { workspace = true, optional = true } serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } thiserror = { workspace = true } tokio = { workspace = true } toml = { workspace = true } diff --git a/crates/uv-auth/src/access_token.rs b/crates/uv-auth/src/access_token.rs new file mode 100644 index 000000000..76f7d3ef0 --- /dev/null +++ b/crates/uv-auth/src/access_token.rs @@ -0,0 +1,34 @@ +/// An encoded JWT access token. +#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)] +#[serde(transparent)] +pub struct AccessToken(String); + +impl AccessToken { + /// Return the [`AccessToken`] as a vector of bytes. + pub fn into_bytes(self) -> Vec { + self.0.into_bytes() + } + + /// Return the [`AccessToken`] as a string slice. + pub fn as_str(&self) -> &str { + &self.0 + } +} + +impl From for AccessToken { + fn from(value: String) -> Self { + Self(value) + } +} + +impl AsRef<[u8]> for AccessToken { + fn as_ref(&self) -> &[u8] { + self.0.as_bytes() + } +} + +impl std::fmt::Display for AccessToken { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} diff --git a/crates/uv-auth/src/lib.rs b/crates/uv-auth/src/lib.rs index 962c6fcc9..9f878b500 100644 --- a/crates/uv-auth/src/lib.rs +++ b/crates/uv-auth/src/lib.rs @@ -2,25 +2,30 @@ use std::sync::{Arc, LazyLock}; use tracing::trace; +use uv_redacted::DisplaySafeUrl; + +pub use access_token::AccessToken; use cache::CredentialsCache; pub use credentials::{Credentials, Username}; pub use index::{AuthPolicy, Index, Indexes}; pub use keyring::KeyringProvider; pub use middleware::AuthMiddleware; -use realm::Realm; +pub use pyx::{DEFAULT_TOLERANCE_SECS, PyxOAuthTokens, PyxTokenStore, PyxTokens, TokenStoreError}; +pub use realm::Realm; pub use service::{Service, ServiceParseError}; -pub use store::{AuthScheme, TextCredentialStore, TomlCredentialError}; -use uv_redacted::DisplaySafeUrl; +pub use store::{AuthBackend, AuthScheme, TextCredentialStore, TomlCredentialError}; +mod access_token; mod cache; mod credentials; mod index; mod keyring; mod middleware; mod providers; +mod pyx; mod realm; mod service; -pub mod store; +mod store; // TODO(zanieb): Consider passing a cache explicitly throughout diff --git a/crates/uv-auth/src/middleware.rs b/crates/uv-auth/src/middleware.rs index 841998397..6136d9cff 100644 --- a/crates/uv-auth/src/middleware.rs +++ b/crates/uv-auth/src/middleware.rs @@ -4,20 +4,24 @@ use anyhow::{anyhow, format_err}; use http::{Extensions, StatusCode}; use netrc::Netrc; use reqwest::{Request, Response}; -use reqwest_middleware::{Error, Middleware, Next}; +use reqwest_middleware::{ClientWithMiddleware, Error, Middleware, Next}; +use tokio::sync::Mutex; use tracing::{debug, trace, warn}; use uv_preview::{Preview, PreviewFeatures}; use uv_redacted::DisplaySafeUrl; +use uv_warnings::owo_colors::OwoColorize; use crate::providers::HuggingFaceProvider; +use crate::pyx::{DEFAULT_TOLERANCE_SECS, PyxTokenStore}; use crate::{ - CREDENTIALS_CACHE, CredentialsCache, KeyringProvider, + AccessToken, CREDENTIALS_CACHE, CredentialsCache, KeyringProvider, cache::FetchUrl, credentials::{Credentials, Username}, index::{AuthPolicy, Indexes}, realm::Realm, }; + use crate::{TextCredentialStore, TomlCredentialError}; /// Strategy for loading netrc files. @@ -105,6 +109,15 @@ impl TextStoreMode { } } +#[derive(Debug, Clone)] +enum TokenState { + /// The token state has not yet been initialized from the store. + Uninitialized, + /// The token state has been initialized, and the store either returned tokens or `None` if + /// the user has not yet authenticated. + Initialized(Option), +} + /// A middleware that adds basic authentication to requests. /// /// Uses a cache to propagate credentials from previously seen requests and @@ -119,6 +132,12 @@ pub struct AuthMiddleware { /// Set all endpoints as needing authentication. We never try to send an /// unauthenticated request, avoiding cloning an uncloneable request. only_authenticated: bool, + /// The base client to use for requests within the middleware. + base_client: Option, + /// The pyx token store to use for persistent credentials. + pyx_token_store: Option, + /// Tokens to use for persistent credentials. + pyx_token_state: Mutex, preview: Preview, } @@ -131,6 +150,9 @@ impl AuthMiddleware { cache: None, indexes: Indexes::new(), only_authenticated: false, + base_client: None, + pyx_token_store: None, + pyx_token_state: Mutex::new(TokenState::Uninitialized), preview: Preview::default(), } } @@ -197,6 +219,20 @@ impl AuthMiddleware { self } + /// Configure the [`ClientWithMiddleware`] to use for requests within the middleware. + #[must_use] + pub fn with_base_client(mut self, client: ClientWithMiddleware) -> Self { + self.base_client = Some(client); + self + } + + /// Configure the [`PyxTokenStore`] to use for persistent credentials. + #[must_use] + pub fn with_pyx_token_store(mut self, token_store: PyxTokenStore) -> Self { + self.pyx_token_store = Some(token_store); + self + } + /// Get the configured authentication store. /// /// If not set, the global store is used. @@ -309,9 +345,20 @@ impl Middleware for AuthMiddleware { .as_ref() .is_some_and(|credentials| credentials.username().is_some()); - let retry_unauthenticated = - !self.only_authenticated && !matches!(auth_policy, AuthPolicy::Always); - let (mut retry_request, response) = if retry_unauthenticated { + // Determine whether this is a "known" URL. + let is_known_url = self + .pyx_token_store + .as_ref() + .is_some_and(|token_store| token_store.is_known_url(request.url())); + + let must_authenticate = self.only_authenticated + || match auth_policy { + AuthPolicy::Auto => is_known_url, + AuthPolicy::Always => true, + AuthPolicy::Never => false, + }; + + let (mut retry_request, response) = if !must_authenticate { let url = tracing_url(&request, credentials.as_deref()); if credentials.is_none() { trace!("Attempting unauthenticated request for {url}"); @@ -419,9 +466,16 @@ impl Middleware for AuthMiddleware { if let Some(response) = response { Ok(response) } else { - Err(Error::Middleware(format_err!( - "Missing credentials for {url}" - ))) + if is_known_url { + Err(Error::Middleware(format_err!( + "Run `{}` to authenticate the uv CLI", + "uv auth login pyx.dev".green() + ))) + } else { + Err(Error::Middleware(format_err!( + "Missing credentials for {url}" + ))) + } } } } @@ -589,6 +643,46 @@ impl AuthMiddleware { return Some(credentials); } + // If this is a known URL, authenticate it via the token store. + if let Some(base_client) = self.base_client.as_ref() { + if let Some(token_store) = self.pyx_token_store.as_ref() { + if token_store.is_known_url(url) { + let mut token_state = self.pyx_token_state.lock().await; + + // If the token store is uninitialized, initialize it. + let token = match *token_state { + TokenState::Uninitialized => { + trace!("Initializing token store for {url}"); + let generated = match token_store + .access_token(base_client, DEFAULT_TOLERANCE_SECS) + .await + { + Ok(Some(token)) => Some(token), + Ok(None) => None, + Err(err) => { + warn!("Failed to generate access tokens: {err}"); + None + } + }; + *token_state = TokenState::Initialized(generated.clone()); + generated + } + TokenState::Initialized(ref tokens) => tokens.clone(), + }; + + let credentials = token.map(|token| { + trace!("Using credentials from token store for {url}"); + Arc::new(Credentials::from(token)) + }); + + // Register the fetch for this key + self.cache().fetches.done(key.clone(), credentials.clone()); + + return credentials; + } + } + } + // Netrc support based on: . let credentials = if let Some(credentials) = self.netrc.get().and_then(|netrc| { debug!("Checking netrc for credentials for {url}"); diff --git a/crates/uv-auth/src/pyx.rs b/crates/uv-auth/src/pyx.rs new file mode 100644 index 000000000..5b2602346 --- /dev/null +++ b/crates/uv-auth/src/pyx.rs @@ -0,0 +1,682 @@ +use std::io; +use std::path::PathBuf; +use std::time::Duration; + +use base64::Engine; +use base64::prelude::BASE64_URL_SAFE_NO_PAD; +use etcetera::BaseStrategy; +use reqwest_middleware::ClientWithMiddleware; +use tracing::debug; +use url::Url; + +use uv_cache_key::CanonicalUrl; +use uv_redacted::DisplaySafeUrl; +use uv_small_str::SmallString; +use uv_state::{StateBucket, StateStore}; +use uv_static::EnvVars; + +use crate::{AccessToken, Credentials, Realm}; + +/// Retrieve the pyx API key from the environment variable, or return `None`. +fn read_pyx_api_key() -> Option { + std::env::var(EnvVars::PYX_API_KEY) + .ok() + .or_else(|| std::env::var(EnvVars::UV_API_KEY).ok()) +} + +/// Retrieve the pyx authentication token (JWT) from the environment variable, or return `None`. +fn read_pyx_auth_token() -> Option { + std::env::var(EnvVars::PYX_AUTH_TOKEN) + .ok() + .or_else(|| std::env::var(EnvVars::UV_AUTH_TOKEN).ok()) + .map(AccessToken::from) +} + +/// An access token with an accompanying refresh token. +/// +/// Refresh tokens are single-use tokens that can be exchanged for a renewed access token +/// and a new refresh token. +#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)] +pub struct PyxOAuthTokens { + pub access_token: AccessToken, + pub refresh_token: String, +} + +/// An access token with an accompanying API key. +#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)] +pub struct PyxApiKeyTokens { + pub access_token: AccessToken, + pub api_key: String, +} + +#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)] +pub enum PyxTokens { + /// An access token with an accompanying refresh token. + /// + /// Refresh tokens are single-use tokens that can be exchanged for a renewed access token + /// and a new refresh token. + OAuth(PyxOAuthTokens), + /// An access token with an accompanying API key. + /// + /// API keys are long-lived tokens that can be exchanged for an access token. + ApiKey(PyxApiKeyTokens), +} + +impl From for AccessToken { + fn from(tokens: PyxTokens) -> Self { + match tokens { + PyxTokens::OAuth(PyxOAuthTokens { access_token, .. }) => access_token, + PyxTokens::ApiKey(PyxApiKeyTokens { access_token, .. }) => access_token, + } + } +} + +impl From for Credentials { + fn from(tokens: PyxTokens) -> Self { + let access_token = match tokens { + PyxTokens::OAuth(PyxOAuthTokens { access_token, .. }) => access_token, + PyxTokens::ApiKey(PyxApiKeyTokens { access_token, .. }) => access_token, + }; + Self::from(access_token) + } +} + +impl From for Credentials { + fn from(access_token: AccessToken) -> Self { + Self::Bearer { + token: access_token.into_bytes(), + } + } +} + +/// The default tolerance for the access token expiration. +pub const DEFAULT_TOLERANCE_SECS: u64 = 60 * 5; + +/// The root directory for the pyx token store. +fn root_dir(api: &DisplaySafeUrl) -> Result { + // Store credentials in a subdirectory based on the API URL. + let digest = uv_cache_key::cache_digest(&CanonicalUrl::new(api)); + + // If the user explicitly set `PYX_CREDENTIALS_DIR`, use that. + if let Some(tool_dir) = std::env::var_os(EnvVars::PYX_CREDENTIALS_DIR) { + return std::path::absolute(tool_dir).map(|dir| dir.join(&digest)); + } + + // If the user has pyx credentials in their uv credentials directory, read them for + // backwards compatibility. + let credentials_dir = if let Some(tool_dir) = std::env::var_os(EnvVars::UV_CREDENTIALS_DIR) { + std::path::absolute(tool_dir)? + } else { + StateStore::from_settings(None)?.bucket(StateBucket::Credentials) + }; + let credentials_dir = credentials_dir.join(&digest); + if credentials_dir.exists() { + return Ok(credentials_dir); + } + + // Otherwise, use (e.g.) `~/.local/share/pyx`. + let Ok(xdg) = etcetera::base_strategy::choose_base_strategy() else { + return Err(io::Error::new( + io::ErrorKind::NotFound, + "Could not determine user data directory", + )); + }; + + Ok(xdg.data_dir().join("pyx").join("credentials").join(&digest)) +} + +#[derive(Debug, Clone)] +pub struct PyxTokenStore { + /// The root directory for the token store (e.g., `/Users/ferris/.local/share/pyx/credentials/3859a629b26fda96`). + root: PathBuf, + /// The API URL for the token store (e.g., `https://api.pyx.dev`). + api: DisplaySafeUrl, + /// The CDN domain for the token store (e.g., `astralhosted.com`). + cdn: SmallString, +} + +impl PyxTokenStore { + /// Create a new [`PyxTokenStore`] from settings. + pub fn from_settings() -> Result { + // Read the API URL and CDN domain from the environment variables, or fallback to the + // defaults. + let api = if let Ok(api_url) = std::env::var(EnvVars::PYX_API_URL) { + DisplaySafeUrl::parse(&api_url) + } else { + DisplaySafeUrl::parse("https://api.pyx.dev") + }?; + let cdn = std::env::var(EnvVars::PYX_CDN_DOMAIN) + .ok() + .map(SmallString::from) + .unwrap_or_else(|| SmallString::from(arcstr::literal!("astralhosted.com"))); + + // Determine the root directory for the token store. + let root = root_dir(&api)?; + + Ok(Self { root, api, cdn }) + } + + /// Return the API URL for the token store. + pub fn api(&self) -> &DisplaySafeUrl { + &self.api + } + + /// Get or initialize an [`AccessToken`] from the store. + /// + /// If an access token is set in the environment, it will be returned as-is. + /// + /// If an access token is present on-disk, it will be returned (and refreshed, if necessary). + /// + /// If no access token is found, but an API key is present, the API key will be used to + /// bootstrap an access token. + pub async fn access_token( + &self, + client: &ClientWithMiddleware, + tolerance_secs: u64, + ) -> Result, TokenStoreError> { + // If the access token is already set in the environment, return it. + if let Some(access_token) = read_pyx_auth_token() { + return Ok(Some(access_token)); + } + + // Initialize the tokens from the store. + let tokens = self.init(client, tolerance_secs).await?; + + // Extract the access token from the OAuth tokens or API key. + Ok(tokens.map(AccessToken::from)) + } + + /// Initialize the [`PyxTokens`] from the store. + /// + /// If an access token is already present, it will be returned (and refreshed, if necessary). + /// + /// If no access token is found, but an API key is present, the API key will be used to + /// bootstrap an access token. + pub async fn init( + &self, + client: &ClientWithMiddleware, + tolerance_secs: u64, + ) -> Result, TokenStoreError> { + match self.read().await? { + Some(tokens) => { + // Refresh the tokens if they are expired. + let tokens = self.refresh(tokens, client, tolerance_secs).await?; + Ok(Some(tokens)) + } + None => { + // If no tokens are present, bootstrap them from an API key. + self.bootstrap(client).await + } + } + } + + /// Write the tokens to the store. + pub async fn write(&self, tokens: &PyxTokens) -> Result<(), TokenStoreError> { + fs_err::tokio::create_dir_all(&self.root).await?; + match tokens { + PyxTokens::OAuth(tokens) => { + // Write OAuth tokens to a generic `tokens.json` file. + fs_err::tokio::write(self.root.join("tokens.json"), serde_json::to_vec(tokens)?) + .await?; + } + PyxTokens::ApiKey(tokens) => { + // Write API key tokens to a file based on the API key. + let digest = uv_cache_key::cache_digest(&tokens.api_key); + fs_err::tokio::write( + self.root.join(format!("{digest}.json")), + &tokens.access_token, + ) + .await?; + } + } + Ok(()) + } + + /// Returns `true` if the user appears to have credentials (which may be invalid). + pub fn has_credentials(&self) -> bool { + read_pyx_auth_token().is_some() + || read_pyx_api_key().is_some() + || self.root.join("tokens.json").is_file() + } + + /// Read the tokens from the store. + pub async fn read(&self) -> Result, TokenStoreError> { + // Retrieve the API URL from the environment variable, or error if unset. + if let Some(api_key) = read_pyx_api_key() { + // Read the API key tokens from a file based on the API key. + let digest = uv_cache_key::cache_digest(&api_key); + match fs_err::tokio::read(self.root.join(format!("{digest}.json"))).await { + Ok(data) => { + let access_token = + AccessToken::from(String::from_utf8(data).expect("Invalid UTF-8")); + Ok(Some(PyxTokens::ApiKey(PyxApiKeyTokens { + access_token, + api_key, + }))) + } + Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(None), + Err(err) => Err(err.into()), + } + } else { + match fs_err::tokio::read(self.root.join("tokens.json")).await { + Ok(data) => { + let tokens: PyxOAuthTokens = serde_json::from_slice(&data)?; + Ok(Some(PyxTokens::OAuth(tokens))) + } + Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(None), + Err(err) => Err(err.into()), + } + } + } + + /// Remove the tokens from the store. + pub async fn delete(&self) -> Result<(), io::Error> { + fs_err::tokio::remove_dir_all(&self.root).await?; + Ok(()) + } + + /// Bootstrap the tokens from the store. + async fn bootstrap( + &self, + client: &ClientWithMiddleware, + ) -> Result, TokenStoreError> { + #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)] + struct Payload { + access_token: AccessToken, + } + + // Retrieve the API key from the environment variable, if set. + let Some(api_key) = read_pyx_api_key() else { + return Ok(None); + }; + + debug!("Bootstrapping access token from an API key"); + + // Parse the API URL. + let mut url = self.api.clone(); + url.set_path("auth/cli/access-token"); + + let mut request = reqwest::Request::new(reqwest::Method::POST, Url::from(url)); + request.headers_mut().insert( + "Authorization", + reqwest::header::HeaderValue::from_str(&format!("Bearer {api_key}"))?, + ); + + let response = client.execute(request).await?; + let Payload { access_token } = response.error_for_status()?.json::().await?; + let tokens = PyxTokens::ApiKey(PyxApiKeyTokens { + access_token, + api_key, + }); + + // Write the tokens to disk. + self.write(&tokens).await?; + + Ok(Some(tokens)) + } + + /// Refresh the tokens in the store, if they are expired. + /// + /// In theory, we should _also_ refresh if we hit a 401; but for now, we only refresh ahead of + /// time. + async fn refresh( + &self, + tokens: PyxTokens, + client: &ClientWithMiddleware, + tolerance_secs: u64, + ) -> Result { + // Decode the access token. + let jwt = Jwt::decode(match &tokens { + PyxTokens::OAuth(PyxOAuthTokens { access_token, .. }) => access_token.as_str(), + PyxTokens::ApiKey(PyxApiKeyTokens { access_token, .. }) => access_token.as_str(), + })?; + + // If the access token is expired, refresh it. + let is_up_to_date = match jwt.exp { + None => { + debug!("Access token has no expiration; refreshing..."); + false + } + Some(..) if tolerance_secs == 0 => { + debug!("Refreshing access token due to zero tolerance..."); + false + } + Some(jwt) => { + let exp = jiff::Timestamp::from_second(jwt)?; + let now = jiff::Timestamp::now(); + if exp < now { + debug!("Access token is expired (`{exp}`); refreshing..."); + false + } else if exp < now + Duration::from_secs(tolerance_secs) { + debug!( + "Access token will expire within the tolerance (`{exp}`); refreshing..." + ); + false + } else { + debug!("Access token is up-to-date (`{exp}`)"); + true + } + } + }; + + if is_up_to_date { + return Ok(tokens); + } + + let tokens = match tokens { + PyxTokens::OAuth(PyxOAuthTokens { refresh_token, .. }) => { + // Parse the API URL. + let mut url = self.api.clone(); + url.set_path("auth/cli/refresh"); + + let mut request = reqwest::Request::new(reqwest::Method::POST, Url::from(url)); + let body = serde_json::json!({ + "refresh_token": refresh_token + }); + *request.body_mut() = Some(body.to_string().into()); + + let response = client.execute(request).await?; + let tokens = response + .error_for_status()? + .json::() + .await?; + PyxTokens::OAuth(tokens) + } + PyxTokens::ApiKey(PyxApiKeyTokens { api_key, .. }) => { + #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)] + struct Payload { + access_token: AccessToken, + } + + // Parse the API URL. + let mut url = self.api.clone(); + url.set_path("auth/cli/access-token"); + + let mut request = reqwest::Request::new(reqwest::Method::POST, Url::from(url)); + request.headers_mut().insert( + "Authorization", + reqwest::header::HeaderValue::from_str(&format!("Bearer {api_key}"))?, + ); + + let response = client.execute(request).await?; + let Payload { access_token } = + response.error_for_status()?.json::().await?; + PyxTokens::ApiKey(PyxApiKeyTokens { + access_token, + api_key, + }) + } + }; + + // Write the new tokens to disk. + self.write(&tokens).await?; + Ok(tokens) + } + + /// Returns `true` if the given URL is "known" to this token store (i.e., should be + /// authenticated using the store's tokens). + pub fn is_known_url(&self, url: &Url) -> bool { + is_known_url(url, &self.api, &self.cdn) + } + + /// Returns `true` if the URL is on a "known" domain (i.e., the same domain as the API or CDN). + /// + /// Like [`is_known_url`](Self::is_known_url), but also returns `true` if the API is on the + /// subdomain of the URL (e.g., if the API is `api.pyx.dev` and the URL is `pyx.dev`). + pub fn is_known_domain(&self, url: &Url) -> bool { + is_known_domain(url, &self.api, &self.cdn) + } +} + +#[derive(thiserror::Error, Debug)] +pub enum TokenStoreError { + #[error(transparent)] + Url(#[from] url::ParseError), + #[error(transparent)] + Io(#[from] io::Error), + #[error(transparent)] + Serialization(#[from] serde_json::Error), + #[error(transparent)] + Reqwest(#[from] reqwest::Error), + #[error(transparent)] + ReqwestMiddleware(#[from] reqwest_middleware::Error), + #[error(transparent)] + InvalidHeaderValue(#[from] reqwest::header::InvalidHeaderValue), + #[error(transparent)] + Jiff(#[from] jiff::Error), + #[error(transparent)] + Jwt(#[from] JwtError), +} + +impl TokenStoreError { + /// Returns `true` if the error is a 401 (Unauthorized) error. + pub fn is_unauthorized(&self) -> bool { + match self { + Self::Reqwest(err) => err.status() == Some(reqwest::StatusCode::UNAUTHORIZED), + Self::ReqwestMiddleware(err) => err.status() == Some(reqwest::StatusCode::UNAUTHORIZED), + _ => false, + } + } +} + +/// The payload of the JWT. +#[derive(Debug, serde::Deserialize)] +struct Jwt { + exp: Option, +} + +impl Jwt { + /// Decode the JWT from the access token. + fn decode(access_token: &str) -> Result { + let mut token_segments = access_token.splitn(3, '.'); + + let _header = token_segments.next().ok_or(JwtError::MissingHeader)?; + let payload = token_segments.next().ok_or(JwtError::MissingPayload)?; + let _signature = token_segments.next().ok_or(JwtError::MissingSignature)?; + if token_segments.next().is_some() { + return Err(JwtError::TooManySegments); + } + + let decoded = BASE64_URL_SAFE_NO_PAD.decode(payload)?; + + let jwt = serde_json::from_slice::(&decoded)?; + Ok(jwt) + } +} + +#[derive(thiserror::Error, Debug)] +pub enum JwtError { + #[error("JWT is missing a header")] + MissingHeader, + #[error("JWT is missing a payload")] + MissingPayload, + #[error("JWT is missing a signature")] + MissingSignature, + #[error("JWT has too many segments")] + TooManySegments, + #[error(transparent)] + Base64(#[from] base64::DecodeError), + #[error(transparent)] + Serde(#[from] serde_json::Error), +} + +fn is_known_url(url: &Url, api: &DisplaySafeUrl, cdn: &str) -> bool { + // Determine whether the URL matches the API realm. + if Realm::from(url) == Realm::from(&**api) { + return true; + } + + // Determine whether the URL matches the CDN domain (or a subdomain of it). + // + // For example, if URL is on `files.astralhosted.com` and the CDN domain is + // `astralhosted.com`, consider it known. + if matches_domain(url, cdn) { + return true; + } + + false +} + +fn is_known_domain(url: &Url, api: &DisplaySafeUrl, cdn: &str) -> bool { + // Determine whether the URL matches the API domain. + if let Some(domain) = url.domain() { + if matches_domain(api, domain) { + return true; + } + } + is_known_url(url, api, cdn) +} + +/// Returns `true` if the target URL is on the given domain. +fn matches_domain(url: &Url, domain: &str) -> bool { + url.domain().is_some_and(|subdomain| { + subdomain == domain + || subdomain + .strip_suffix(domain) + .is_some_and(|prefix| prefix.ends_with('.')) + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_is_known_url() { + let api_url = DisplaySafeUrl::from(Url::parse("https://api.pyx.dev").unwrap()); + let cdn_domain = "astralhosted.com"; + + // Same realm as API. + assert!(is_known_url( + &Url::parse("https://api.pyx.dev/simple/").unwrap(), + &api_url, + cdn_domain + )); + + // Different path on same API domain + assert!(is_known_url( + &Url::parse("https://api.pyx.dev/v1/").unwrap(), + &api_url, + cdn_domain + )); + + // CDN domain. + assert!(is_known_url( + &Url::parse("https://astralhosted.com/packages/").unwrap(), + &api_url, + cdn_domain + )); + + // CDN subdomain. + assert!(is_known_url( + &Url::parse("https://files.astralhosted.com/packages/").unwrap(), + &api_url, + cdn_domain + )); + + // Unknown domain. + assert!(!is_known_url( + &Url::parse("https://pypi.org/simple/").unwrap(), + &api_url, + cdn_domain + )); + + // Similar but not matching domain. + assert!(!is_known_url( + &Url::parse("https://badastralhosted.com/packages/").unwrap(), + &api_url, + cdn_domain + )); + } + + #[test] + fn test_is_known_domain() { + let api_url = DisplaySafeUrl::from(Url::parse("https://api.pyx.dev").unwrap()); + let cdn_domain = "astralhosted.com"; + + // Same realm as API. + assert!(is_known_domain( + &Url::parse("https://api.pyx.dev/simple/").unwrap(), + &api_url, + cdn_domain + )); + + // API super-domain. + assert!(is_known_domain( + &Url::parse("https://pyx.dev").unwrap(), + &api_url, + cdn_domain + )); + + // API subdomain. + assert!(!is_known_domain( + &Url::parse("https://foo.api.pyx.dev").unwrap(), + &api_url, + cdn_domain + )); + + // Different subdomain. + assert!(!is_known_domain( + &Url::parse("https://beta.pyx.dev/").unwrap(), + &api_url, + cdn_domain + )); + + // CDN domain. + assert!(is_known_domain( + &Url::parse("https://astralhosted.com/packages/").unwrap(), + &api_url, + cdn_domain + )); + + // CDN subdomain. + assert!(is_known_domain( + &Url::parse("https://files.astralhosted.com/packages/").unwrap(), + &api_url, + cdn_domain + )); + + // Unknown domain. + assert!(!is_known_domain( + &Url::parse("https://pypi.org/simple/").unwrap(), + &api_url, + cdn_domain + )); + + // Different TLD. + assert!(!is_known_domain( + &Url::parse("https://pyx.com/").unwrap(), + &api_url, + cdn_domain + )); + } + + #[test] + fn test_matches_domain() { + assert!(matches_domain( + &Url::parse("https://example.com").unwrap(), + "example.com" + )); + assert!(matches_domain( + &Url::parse("https://foo.example.com").unwrap(), + "example.com" + )); + assert!(matches_domain( + &Url::parse("https://bar.foo.example.com").unwrap(), + "example.com" + )); + + assert!(!matches_domain( + &Url::parse("https://example.com").unwrap(), + "other.com" + )); + assert!(!matches_domain( + &Url::parse("https://example.org").unwrap(), + "example.com" + )); + assert!(!matches_domain( + &Url::parse("https://badexample.com").unwrap(), + "example.com" + )); + } +} diff --git a/crates/uv-auth/src/realm.rs b/crates/uv-auth/src/realm.rs index 03b3c8fcf..b2abd1266 100644 --- a/crates/uv-auth/src/realm.rs +++ b/crates/uv-auth/src/realm.rs @@ -23,7 +23,7 @@ use uv_small_str::SmallString; // However, `url` (and therefore `reqwest`) sets the `port` to `None` if it matches the default port // so we do not need any special handling here. #[derive(Debug, Clone)] -pub(crate) struct Realm { +pub struct Realm { scheme: SmallString, host: Option, port: Option, diff --git a/crates/uv-auth/src/store.rs b/crates/uv-auth/src/store.rs index a761bbaf1..5964b3cc8 100644 --- a/crates/uv-auth/src/store.rs +++ b/crates/uv-auth/src/store.rs @@ -19,6 +19,7 @@ use crate::service::Service; use crate::{Credentials, KeyringProvider}; /// The storage backend to use in `uv auth` commands. +#[derive(Debug)] pub enum AuthBackend { // TODO(zanieb): Right now, we're using a keyring provider for the system store but that's just // where the native implementation is living at the moment. We should consider refactoring these @@ -104,11 +105,11 @@ pub enum BearerAuthError { /// A single credential entry in a TOML credentials file. #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(try_from = "TomlCredentialWire", into = "TomlCredentialWire")] -pub struct TomlCredential { +struct TomlCredential { /// The service URL for this credential. - pub service: Service, + service: Service, /// The credentials for this entry. - pub credentials: Credentials, + credentials: Credentials, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -380,11 +381,13 @@ impl TextCredentialStore { #[cfg(test)] mod tests { - use super::*; use std::io::Write; use std::str::FromStr; + use tempfile::NamedTempFile; + use super::*; + #[test] fn test_toml_serialization() { let credentials = TomlCredentials { diff --git a/crates/uv-build-frontend/src/lib.rs b/crates/uv-build-frontend/src/lib.rs index 2941b8f49..524cb96c0 100644 --- a/crates/uv-build-frontend/src/lib.rs +++ b/crates/uv-build-frontend/src/lib.rs @@ -1164,6 +1164,11 @@ impl PythonRunner { // tools, which might mess with wrappers trying to parse their // output. .env(EnvVars::PYTHONIOENCODING, "utf-8:backslashreplace") + // Remove potentially-sensitive environment variables. + .env_remove(EnvVars::PYX_API_KEY) + .env_remove(EnvVars::UV_API_KEY) + .env_remove(EnvVars::PYX_AUTH_TOKEN) + .env_remove(EnvVars::UV_AUTH_TOKEN) .stdout(std::process::Stdio::piped()) .stderr(std::process::Stdio::piped()) .spawn() diff --git a/crates/uv-client/src/base_client.rs b/crates/uv-client/src/base_client.rs index d4efe49cb..20ef8f7b6 100644 --- a/crates/uv-client/src/base_client.rs +++ b/crates/uv-client/src/base_client.rs @@ -28,7 +28,7 @@ use tracing::{debug, trace}; use url::ParseError; use url::Url; -use uv_auth::{AuthMiddleware, Credentials, Indexes}; +use uv_auth::{AuthMiddleware, Credentials, Indexes, PyxTokenStore}; use uv_configuration::{KeyringProviderType, TrustedHost}; use uv_fs::Simplified; use uv_pep508::MarkerEnvironment; @@ -472,6 +472,30 @@ impl<'a> BaseClientBuilder<'a> { fn apply_middleware(&self, client: Client) -> ClientWithMiddleware { match self.connectivity { Connectivity::Online => { + // Create a base client to using in the authentication middleware. + let base_client = { + let mut client = reqwest_middleware::ClientBuilder::new(client.clone()); + + // Avoid uncloneable errors with a streaming body during publish. + if self.retries > 0 { + // Initialize the retry strategy. + let retry_strategy = RetryTransientMiddleware::new_with_policy_and_strategy( + self.retry_policy(), + UvRetryableStrategy, + ); + client = client.with(retry_strategy); + } + + // When supplied, add the extra middleware. + if let Some(extra_middleware) = &self.extra_middleware { + for middleware in &extra_middleware.0 { + client = client.with_arc(middleware.clone()); + } + } + + client.build() + }; + let mut client = reqwest_middleware::ClientBuilder::new(client); // Avoid uncloneable errors with a streaming body during publish. @@ -484,22 +508,36 @@ impl<'a> BaseClientBuilder<'a> { client = client.with(retry_strategy); } + // When supplied, add the extra middleware. + if let Some(extra_middleware) = &self.extra_middleware { + for middleware in &extra_middleware.0 { + client = client.with_arc(middleware.clone()); + } + } + // Initialize the authentication middleware to set headers. match self.auth_integration { AuthIntegration::Default => { - let auth_middleware = AuthMiddleware::new() + let mut auth_middleware = AuthMiddleware::new() + .with_base_client(base_client) .with_indexes(self.indexes.clone()) .with_keyring(self.keyring.to_provider()) .with_preview(self.preview); + if let Ok(token_store) = PyxTokenStore::from_settings() { + auth_middleware = auth_middleware.with_pyx_token_store(token_store); + } client = client.with(auth_middleware); } AuthIntegration::OnlyAuthenticated => { - let auth_middleware = AuthMiddleware::new() + let mut auth_middleware = AuthMiddleware::new() + .with_base_client(base_client) .with_indexes(self.indexes.clone()) .with_keyring(self.keyring.to_provider()) .with_preview(self.preview) .with_only_authenticated(true); - + if let Ok(token_store) = PyxTokenStore::from_settings() { + auth_middleware = auth_middleware.with_pyx_token_store(token_store); + } client = client.with(auth_middleware); } AuthIntegration::NoAuthMiddleware => { @@ -507,13 +545,6 @@ impl<'a> BaseClientBuilder<'a> { } } - // When supplied add the extra middleware - if let Some(extra_middleware) = &self.extra_middleware { - for middleware in &extra_middleware.0 { - client = client.with_arc(middleware.clone()); - } - } - client.build() } Connectivity::Offline => reqwest_middleware::ClientBuilder::new(client) diff --git a/crates/uv-state/src/lib.rs b/crates/uv-state/src/lib.rs index 46b10a817..e7c0f3469 100644 --- a/crates/uv-state/src/lib.rs +++ b/crates/uv-state/src/lib.rs @@ -105,7 +105,7 @@ pub enum StateBucket { ManagedPython, /// Installed tools. Tools, - /// Stored authentication credentials. + /// Credentials. Credentials, } diff --git a/crates/uv-static/src/env_vars.rs b/crates/uv-static/src/env_vars.rs index 4419a7eeb..cae0fd06f 100644 --- a/crates/uv-static/src/env_vars.rs +++ b/crates/uv-static/src/env_vars.rs @@ -856,4 +856,27 @@ impl EnvVars { /// Disable Hugging Face authentication, even if `HF_TOKEN` is set. pub const UV_NO_HF_TOKEN: &'static str = "UV_NO_HF_TOKEN"; + + /// The URL of the pyx Simple API server. + pub const PYX_API_URL: &'static str = "PYX_API_URL"; + + /// The domain of the pyx CDN. + pub const PYX_CDN_DOMAIN: &'static str = "PYX_CDN_DOMAIN"; + + /// The pyx API key (e.g., `sk-pyx-...`). + pub const PYX_API_KEY: &'static str = "PYX_API_KEY"; + + /// The pyx API key, for backwards compatibility. + #[attr_hidden] + pub const UV_API_KEY: &'static str = "UV_API_KEY"; + + /// The pyx authentication token (e.g., `eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9...`), as output by `uv auth token`. + pub const PYX_AUTH_TOKEN: &'static str = "PYX_AUTH_TOKEN"; + + /// The pyx authentication token, for backwards compatibility. + #[attr_hidden] + pub const UV_AUTH_TOKEN: &'static str = "UV_AUTH_TOKEN"; + + /// Specifies the directory where uv stores pyx credentials. + pub const PYX_CREDENTIALS_DIR: &'static str = "PYX_CREDENTIALS_DIR"; } diff --git a/crates/uv/Cargo.toml b/crates/uv/Cargo.toml index db55b1892..13f1949f7 100644 --- a/crates/uv/Cargo.toml +++ b/crates/uv/Cargo.toml @@ -69,6 +69,7 @@ axoupdater = { workspace = true, features = [ "github_releases", "tokio", ], optional = true } +base64 = { workspace = true } clap = { workspace = true, features = ["derive", "string", "wrap_help"] } console = { workspace = true } ctrlc = { workspace = true } @@ -84,6 +85,7 @@ indoc = { workspace = true } itertools = { workspace = true } h2 = { workspace = true } miette = { workspace = true, features = ["fancy-no-backtrace"] } +open = { workspace = true } owo-colors = { workspace = true } petgraph = { workspace = true } regex = { workspace = true } @@ -107,6 +109,7 @@ tracing-subscriber = { workspace = true, features = ["env-filter", "json", "regi tracing-tree = { workspace = true } unicode-width = { workspace = true } url = { workspace = true } +uuid = { workspace = true, features = ["v4"] } version-ranges = { workspace = true } walkdir = { workspace = true } which = { workspace = true } diff --git a/crates/uv/src/commands/auth/login.rs b/crates/uv/src/commands/auth/login.rs index c2c33c193..73c62fe80 100644 --- a/crates/uv/src/commands/auth/login.rs +++ b/crates/uv/src/commands/auth/login.rs @@ -3,15 +3,21 @@ use std::fmt::Write; use anyhow::{Result, bail}; use console::Term; use owo_colors::OwoColorize; +use url::Url; +use uuid::Uuid; -use uv_auth::Service; -use uv_auth::store::AuthBackend; -use uv_auth::{Credentials, TextCredentialStore}; +use uv_auth::{ + AccessToken, AuthBackend, Credentials, PyxOAuthTokens, PyxTokenStore, PyxTokens, Service, + TextCredentialStore, +}; +use uv_client::{AuthIntegration, BaseClient, BaseClientBuilder}; use uv_distribution_types::IndexUrl; use uv_pep508::VerbatimUrl; use uv_preview::Preview; -use crate::{commands::ExitStatus, printer::Printer}; +use crate::commands::ExitStatus; +use crate::printer::Printer; +use crate::settings::NetworkSettings; /// Login to a service. pub(crate) async fn login( @@ -19,9 +25,35 @@ pub(crate) async fn login( username: Option, password: Option, token: Option, + network_settings: &NetworkSettings, printer: Printer, preview: Preview, ) -> Result { + let pyx_store = PyxTokenStore::from_settings()?; + if pyx_store.is_known_domain(service.url()) { + if username.is_some() { + bail!("Cannot specify a username when logging in to pyx"); + } + if password.is_some() { + bail!("Cannot specify a password when logging in to pyx"); + } + + let client = BaseClientBuilder::default() + .connectivity(network_settings.connectivity) + .native_tls(network_settings.native_tls) + .allow_insecure_host(network_settings.allow_insecure_host.clone()) + .auth_integration(AuthIntegration::NoAuthMiddleware) + .build(); + + pyx_login_with_browser(&pyx_store, &client, &printer).await?; + writeln!( + printer.stderr(), + "Logged in to {}", + pyx_store.api().bold().cyan() + )?; + return Ok(ExitStatus::Success); + } + let backend = AuthBackend::from_settings(preview)?; // If the URL includes a known index URL suffix, strip it @@ -131,3 +163,64 @@ pub(crate) async fn login( )?; Ok(ExitStatus::Success) } + +/// Log in via the [`PyxTokenStore`]. +pub(crate) async fn pyx_login_with_browser( + store: &PyxTokenStore, + client: &BaseClient, + printer: &Printer, +) -> Result { + // Generate a login code, like `67e55044-10b1-426f-9247-bb680e5fe0c8`. + let cli_token = Uuid::new_v4(); + let url = { + let mut url = store.api().clone(); + url.set_path(&format!("auth/cli/login/{cli_token}")); + url + }; + match open::that(url.as_ref()) { + Ok(()) => { + writeln!(printer.stderr(), "Logging in with {}", url.cyan().bold())?; + } + Err(..) => { + writeln!( + printer.stderr(), + "Open the following URL in your browser: {}", + url.cyan().bold() + )?; + } + } + + // Poll the server for the login code. + let url = { + let mut url = store.api().clone(); + url.set_path(&format!("auth/cli/status/{cli_token}")); + url + }; + + let credentials = loop { + let response = client + .for_host(store.api()) + .get(Url::from(url.clone())) + .send() + .await?; + match response.status() { + // Retry on 404. + reqwest::StatusCode::NOT_FOUND => { + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + } + // Parse the credentials on success. + _ if response.status().is_success() => { + let credentials = response.json::().await?; + break Ok::(PyxTokens::OAuth(credentials)); + } + // Fail on any other status code (like a 500). + status => { + break Err(anyhow::anyhow!("Failed to login with code `{status}`")); + } + } + }?; + + store.write(&credentials).await?; + + Ok(AccessToken::from(credentials)) +} diff --git a/crates/uv/src/commands/auth/logout.rs b/crates/uv/src/commands/auth/logout.rs index 58e2d11f1..6e64dda49 100644 --- a/crates/uv/src/commands/auth/logout.rs +++ b/crates/uv/src/commands/auth/logout.rs @@ -3,12 +3,13 @@ use std::fmt::Write; use anyhow::{Context, Result, bail}; use owo_colors::OwoColorize; -use uv_auth::store::AuthBackend; -use uv_auth::{Credentials, Service, TextCredentialStore, Username}; +use uv_auth::{AuthBackend, Credentials, PyxTokenStore, Service, TextCredentialStore, Username}; +use uv_client::BaseClientBuilder; use uv_distribution_types::IndexUrl; use uv_pep508::VerbatimUrl; use uv_preview::Preview; +use crate::settings::NetworkSettings; use crate::{commands::ExitStatus, printer::Printer}; /// Logout from a service. @@ -17,9 +18,15 @@ use crate::{commands::ExitStatus, printer::Printer}; pub(crate) async fn logout( service: Service, username: Option, + network_settings: &NetworkSettings, printer: Printer, preview: Preview, ) -> Result { + let pyx_store = PyxTokenStore::from_settings()?; + if pyx_store.is_known_domain(service.url()) { + return pyx_logout(&pyx_store, network_settings, printer).await; + } + let backend = AuthBackend::from_settings(preview)?; // TODO(zanieb): Use a shared abstraction across `login` and `logout`? @@ -79,3 +86,67 @@ pub(crate) async fn logout( Ok(ExitStatus::Success) } + +/// Log out via the [`PyxTokenStore`], invalidating the existing tokens. +async fn pyx_logout( + store: &PyxTokenStore, + network_settings: &NetworkSettings, + printer: Printer, +) -> Result { + // Initialize the client. + let client = BaseClientBuilder::default() + .connectivity(network_settings.connectivity) + .native_tls(network_settings.native_tls) + .allow_insecure_host(network_settings.allow_insecure_host.clone()) + .build(); + + // Retrieve the token store. + let Some(tokens) = store.read().await? else { + writeln!( + printer.stderr(), + "{}", + format_args!("No credentials found for {}", store.api().bold().cyan()) + )?; + return Ok(ExitStatus::Success); + }; + + // Add the token to the request. + let url = { + let mut url = store.api().clone(); + url.set_path("auth/cli/logout"); + url + }; + + // Build a basic request first, then authenticate it + let request = reqwest::Request::new(reqwest::Method::GET, url.into()); + let request = Credentials::from(tokens).authenticate(request); + + // Hit the logout endpoint using the client's execute method + let response = client.execute(request).await?; + match response.error_for_status_ref() { + Ok(..) => {} + Err(err) if matches!(err.status(), Some(reqwest::StatusCode::UNAUTHORIZED)) => { + tracing::debug!( + "Received 401 (Unauthorized) response from logout endpoint; removing tokens..." + ); + } + Err(err) => { + return Err(err.into()); + } + } + + // Remove the tokens from the store. + match store.delete().await { + Ok(..) => {} + Err(err) if matches!(err.kind(), std::io::ErrorKind::NotFound) => {} + Err(err) => return Err(err.into()), + } + + writeln!( + printer.stderr(), + "{}", + format_args!("Logged out from {}", store.api().bold().cyan()) + )?; + + Ok(ExitStatus::Success) +} diff --git a/crates/uv/src/commands/auth/token.rs b/crates/uv/src/commands/auth/token.rs index fa022c196..369b2dd55 100644 --- a/crates/uv/src/commands/auth/token.rs +++ b/crates/uv/src/commands/auth/token.rs @@ -1,21 +1,44 @@ use std::fmt::Write; use anyhow::{Result, bail}; +use owo_colors::OwoColorize; +use tracing::debug; -use uv_auth::Credentials; -use uv_auth::Service; -use uv_auth::store::AuthBackend; +use uv_auth::{AuthBackend, Service}; +use uv_auth::{Credentials, PyxTokenStore}; +use uv_client::{AuthIntegration, BaseClient, BaseClientBuilder}; use uv_preview::Preview; -use crate::{commands::ExitStatus, printer::Printer}; +use crate::commands::ExitStatus; +use crate::commands::auth::login; +use crate::printer::Printer; +use crate::settings::NetworkSettings; /// Show the token that will be used for a service. pub(crate) async fn token( service: Service, username: Option, + network_settings: &NetworkSettings, printer: Printer, preview: Preview, ) -> Result { + let pyx_store = PyxTokenStore::from_settings()?; + if pyx_store.is_known_domain(service.url()) { + if username.is_some() { + bail!("Cannot specify a username when logging in to pyx"); + } + + let client = BaseClientBuilder::default() + .connectivity(network_settings.connectivity) + .native_tls(network_settings.native_tls) + .allow_insecure_host(network_settings.allow_insecure_host.clone()) + .auth_integration(AuthIntegration::NoAuthMiddleware) + .build(); + + pyx_refresh(&pyx_store, &client, printer).await?; + return Ok(ExitStatus::Success); + } + let backend = AuthBackend::from_settings(preview)?; let url = service.url(); @@ -65,3 +88,36 @@ pub(crate) async fn token( writeln!(printer.stdout(), "{password}")?; Ok(ExitStatus::Success) } + +/// Refresh the authentication tokens in the [`PyxTokenStore`], prompting for login if necessary. +async fn pyx_refresh(store: &PyxTokenStore, client: &BaseClient, printer: Printer) -> Result<()> { + // Retrieve the token store. + let token = match store + .access_token(client.for_host(store.api()).raw_client(), 0) + .await + { + // If the tokens were successfully refreshed, return them. + Ok(Some(token)) => token, + + // If the token store is empty, prompt for login. + Ok(None) => { + debug!("Token store is empty; prompting for login..."); + login::pyx_login_with_browser(store, client, &printer).await? + } + + // Similarly, if the refresh token expired, prompt for login. + Err(err) if err.is_unauthorized() => { + debug!( + "Received 401 (Unauthorized) response from refresh endpoint; prompting for login..." + ); + login::pyx_login_with_browser(store, client, &printer).await? + } + + Err(err) => { + return Err(err.into()); + } + }; + + writeln!(printer.stdout(), "{}", token.cyan())?; + Ok(()) +} diff --git a/crates/uv/src/lib.rs b/crates/uv/src/lib.rs index 8016c7349..538f44538 100644 --- a/crates/uv/src/lib.rs +++ b/crates/uv/src/lib.rs @@ -443,7 +443,11 @@ async fn run(mut cli: Cli) -> Result { command: AuthCommand::Login(args), }) => { // Resolve the settings from the command-line arguments and workspace configuration. - let args = settings::AuthLoginSettings::resolve(args, filesystem); + let args = settings::AuthLoginSettings::resolve( + args, + &cli.top_level.global_args, + filesystem.as_ref(), + ); show_settings!(args); commands::auth_login( @@ -451,6 +455,7 @@ async fn run(mut cli: Cli) -> Result { args.username, args.password, args.token, + &args.network_settings, printer, globals.preview, ) @@ -460,19 +465,41 @@ async fn run(mut cli: Cli) -> Result { command: AuthCommand::Logout(args), }) => { // Resolve the settings from the command-line arguments and workspace configuration. - let args = settings::AuthLogoutSettings::resolve(args, filesystem); + let args = settings::AuthLogoutSettings::resolve( + args, + &cli.top_level.global_args, + filesystem.as_ref(), + ); show_settings!(args); - commands::auth_logout(args.service, args.username, printer, globals.preview).await + commands::auth_logout( + args.service, + args.username, + &args.network_settings, + printer, + globals.preview, + ) + .await } Commands::Auth(AuthNamespace { command: AuthCommand::Token(args), }) => { // Resolve the settings from the command-line arguments and workspace configuration. - let args = settings::AuthTokenSettings::resolve(args, filesystem); + let args = settings::AuthTokenSettings::resolve( + args, + &cli.top_level.global_args, + filesystem.as_ref(), + ); show_settings!(args); - commands::auth_token(args.service, args.username, printer, globals.preview).await + commands::auth_token( + args.service, + args.username, + &args.network_settings, + printer, + globals.preview, + ) + .await } Commands::Auth(AuthNamespace { command: AuthCommand::Dir, diff --git a/crates/uv/src/settings.rs b/crates/uv/src/settings.rs index 3d3cc55fc..9a742321a 100644 --- a/crates/uv/src/settings.rs +++ b/crates/uv/src/settings.rs @@ -3490,14 +3490,22 @@ impl PublishSettings { pub(crate) struct AuthLogoutSettings { pub(crate) service: Service, pub(crate) username: Option, + + // Both CLI and configuration. + pub(crate) network_settings: NetworkSettings, } impl AuthLogoutSettings { /// Resolve the [`AuthLogoutSettings`] from the CLI and filesystem configuration. - pub(crate) fn resolve(args: AuthLogoutArgs, _filesystem: Option) -> Self { + pub(crate) fn resolve( + args: AuthLogoutArgs, + global_args: &GlobalArgs, + filesystem: Option<&FilesystemOptions>, + ) -> Self { Self { service: args.service, username: args.username, + network_settings: NetworkSettings::resolve(global_args, filesystem), } } } @@ -3507,14 +3515,22 @@ impl AuthLogoutSettings { pub(crate) struct AuthTokenSettings { pub(crate) service: Service, pub(crate) username: Option, + + // Both CLI and configuration. + pub(crate) network_settings: NetworkSettings, } impl AuthTokenSettings { /// Resolve the [`AuthTokenSettings`] from the CLI and filesystem configuration. - pub(crate) fn resolve(args: AuthTokenArgs, _filesystem: Option) -> Self { + pub(crate) fn resolve( + args: AuthTokenArgs, + global_args: &GlobalArgs, + filesystem: Option<&FilesystemOptions>, + ) -> Self { Self { service: args.service, username: args.username, + network_settings: NetworkSettings::resolve(global_args, filesystem), } } } @@ -3526,16 +3542,24 @@ pub(crate) struct AuthLoginSettings { pub(crate) username: Option, pub(crate) password: Option, pub(crate) token: Option, + + // Both CLI and configuration. + pub(crate) network_settings: NetworkSettings, } impl AuthLoginSettings { /// Resolve the [`AuthLoginSettings`] from the CLI and filesystem configuration. - pub(crate) fn resolve(args: AuthLoginArgs, _filesystem: Option) -> Self { + pub(crate) fn resolve( + args: AuthLoginArgs, + global_args: &GlobalArgs, + filesystem: Option<&FilesystemOptions>, + ) -> Self { Self { service: args.service, username: args.username, password: args.password, token: args.token, + network_settings: NetworkSettings::resolve(global_args, filesystem), } } } diff --git a/docs/reference/environment.md b/docs/reference/environment.md index 00cd39760..47c82562e 100644 --- a/docs/reference/environment.md +++ b/docs/reference/environment.md @@ -685,6 +685,26 @@ See [`PycInvalidationMode`](https://docs.python.org/3/library/py_compile.html#py Adds directories to Python module search path (e.g., `PYTHONPATH=/path/to/modules`). +### `PYX_API_KEY` + +The pyx API key (e.g., `sk-pyx-...`). + +### `PYX_API_URL` + +The URL of the pyx Simple API server. + +### `PYX_AUTH_TOKEN` + +The pyx authentication token (e.g., `eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9...`), as output by `uv auth token`. + +### `PYX_CDN_DOMAIN` + +The domain of the pyx CDN. + +### `PYX_CREDENTIALS_DIR` + +Specifies the directory where uv stores pyx credentials. + ### `RUST_BACKTRACE` If set, it can be used to display more stack trace details when a panic occurs.