diff --git a/crates/uv-cli/src/lib.rs b/crates/uv-cli/src/lib.rs index ce763e0e5..4c2a60e57 100644 --- a/crates/uv-cli/src/lib.rs +++ b/crates/uv-cli/src/lib.rs @@ -5371,6 +5371,21 @@ pub struct ToolRunArgs { #[arg(long)] pub python_platform: Option, + /// The backend to use when fetching packages in the PyTorch ecosystem (e.g., `cpu`, `cu126`, or `auto`) + /// + /// When set, uv will ignore the configured index URLs for packages in the PyTorch ecosystem, + /// and will instead use the defined backend. + /// + /// For example, when set to `cpu`, uv will use the CPU-only PyTorch index; when set to `cu126`, + /// uv will use the PyTorch index for CUDA 12.6. + /// + /// The `auto` mode will attempt to detect the appropriate PyTorch index based on the currently + /// installed CUDA drivers. + /// + /// This option is in preview and may change in any future release. + #[arg(long, value_enum, env = EnvVars::UV_TORCH_BACKEND)] + pub torch_backend: Option, + #[arg(long, hide = true)] pub generate_shell_completion: Option, } @@ -5547,6 +5562,21 @@ pub struct ToolInstallArgs { /// `--python-platform` option is intended for advanced use cases. #[arg(long)] pub python_platform: Option, + + /// The backend to use when fetching packages in the PyTorch ecosystem (e.g., `cpu`, `cu126`, or `auto`) + /// + /// When set, uv will ignore the configured index URLs for packages in the PyTorch ecosystem, + /// and will instead use the defined backend. + /// + /// For example, when set to `cpu`, uv will use the CPU-only PyTorch index; when set to `cu126`, + /// uv will use the PyTorch index for CUDA 12.6. + /// + /// The `auto` mode will attempt to detect the appropriate PyTorch index based on the currently + /// installed CUDA drivers. + /// + /// This option is in preview and may change in any future release. + #[arg(long, value_enum, env = EnvVars::UV_TORCH_BACKEND)] + pub torch_backend: Option, } #[derive(Args)] diff --git a/crates/uv-cli/src/options.rs b/crates/uv-cli/src/options.rs index 1ed9ff851..f8d492f55 100644 --- a/crates/uv-cli/src/options.rs +++ b/crates/uv-cli/src/options.rs @@ -366,6 +366,7 @@ pub fn resolver_options( exclude_newer_package.unwrap_or_default(), ), link_mode, + torch_backend: None, no_build: flag(no_build, build, "build"), no_build_package: Some(no_build_package), no_binary: flag(no_binary, binary, "binary"), diff --git a/crates/uv-settings/src/settings.rs b/crates/uv-settings/src/settings.rs index 8d38c1607..9bf98ada9 100644 --- a/crates/uv-settings/src/settings.rs +++ b/crates/uv-settings/src/settings.rs @@ -370,6 +370,7 @@ pub struct ResolverOptions { pub config_settings_package: Option, pub exclude_newer: ExcludeNewer, pub link_mode: Option, + pub torch_backend: Option, pub upgrade: Option, pub build_isolation: Option, pub no_build: Option, @@ -404,6 +405,7 @@ pub struct ResolverInstallerOptions { pub exclude_newer: Option, pub exclude_newer_package: Option, pub link_mode: Option, + pub torch_backend: Option, pub compile_bytecode: Option, pub no_sources: Option, pub upgrade: Option, @@ -412,7 +414,6 @@ pub struct ResolverInstallerOptions { pub no_build_package: Option>, pub no_binary: Option, pub no_binary_package: Option>, - pub torch_backend: Option, } impl From for ResolverInstallerOptions { @@ -438,6 +439,7 @@ impl From for ResolverInstallerOptions { exclude_newer, exclude_newer_package, link_mode, + torch_backend, compile_bytecode, no_sources, upgrade, @@ -448,7 +450,6 @@ impl From for ResolverInstallerOptions { no_build_package, no_binary, no_binary_package, - torch_backend, } = value; Self { index, @@ -473,6 +474,7 @@ impl From for ResolverInstallerOptions { exclude_newer, exclude_newer_package, link_mode, + torch_backend, compile_bytecode, no_sources, upgrade: Upgrade::from_args( @@ -488,7 +490,6 @@ impl From for ResolverInstallerOptions { no_build_package, no_binary, no_binary_package, - torch_backend, } } } @@ -1925,6 +1926,7 @@ impl From for ResolverOptions { extra_build_dependencies: value.extra_build_dependencies, extra_build_variables: value.extra_build_variables, no_sources: value.no_sources, + torch_backend: value.torch_backend, } } } @@ -2004,6 +2006,7 @@ pub struct ToolOptions { pub no_build_package: Option>, pub no_binary: Option, pub no_binary_package: Option>, + pub torch_backend: Option, } impl From for ToolOptions { @@ -2034,6 +2037,7 @@ impl From for ToolOptions { no_build_package: value.no_build_package, no_binary: value.no_binary, no_binary_package: value.no_binary_package, + torch_backend: value.torch_backend, } } } @@ -2068,7 +2072,7 @@ impl From for ResolverInstallerOptions { no_build_package: value.no_build_package, no_binary: value.no_binary, no_binary_package: value.no_binary_package, - torch_backend: None, + torch_backend: value.torch_backend, } } } @@ -2150,7 +2154,7 @@ pub struct OptionsWire { // `crates/uv-workspace/src/pyproject.rs`. The documentation lives on that struct. // They're respected in both `pyproject.toml` and `uv.toml` files. override_dependencies: Option>>, - exclude_dependencies: Option>, + exclude_dependencies: Option>, constraint_dependencies: Option>>, build_constraint_dependencies: Option>>, environments: Option, diff --git a/crates/uv/src/commands/build_frontend.rs b/crates/uv/src/commands/build_frontend.rs index e19f51957..099ece65e 100644 --- a/crates/uv/src/commands/build_frontend.rs +++ b/crates/uv/src/commands/build_frontend.rs @@ -216,6 +216,7 @@ async fn build_impl( upgrade: _, build_options, sources, + torch_backend: _, } = settings; // Determine the source to build. diff --git a/crates/uv/src/commands/project/lock.rs b/crates/uv/src/commands/project/lock.rs index ef9537efa..9f56a0944 100644 --- a/crates/uv/src/commands/project/lock.rs +++ b/crates/uv/src/commands/project/lock.rs @@ -470,6 +470,7 @@ async fn do_lock( upgrade, build_options, sources, + torch_backend: _, } = settings; if !preview.is_enabled(PreviewFeatures::EXTRA_BUILD_DEPENDENCIES) diff --git a/crates/uv/src/commands/project/mod.rs b/crates/uv/src/commands/project/mod.rs index e3185957c..3e6b15e0e 100644 --- a/crates/uv/src/commands/project/mod.rs +++ b/crates/uv/src/commands/project/mod.rs @@ -43,6 +43,7 @@ use uv_resolver::{ use uv_scripts::Pep723ItemRef; use uv_settings::PythonInstallMirrors; use uv_static::EnvVars; +use uv_torch::{TorchSource, TorchStrategy}; use uv_types::{BuildIsolation, EmptyInstalledPackages, HashStrategy}; use uv_virtualenv::remove_virtualenv; use uv_warnings::{warn_user, warn_user_once}; @@ -278,6 +279,9 @@ pub(crate) enum ProjectError { #[error(transparent)] RetryParsing(#[from] uv_client::RetryParsingError), + #[error(transparent)] + Accelerator(#[from] uv_torch::AcceleratorError), + #[error(transparent)] Anyhow(#[from] anyhow::Error), } @@ -1723,6 +1727,7 @@ pub(crate) async fn resolve_names( prerelease: _, resolution: _, sources, + torch_backend, upgrade: _, }, compile_bytecode: _, @@ -1731,10 +1736,27 @@ pub(crate) async fn resolve_names( let client_builder = client_builder.clone().keyring(*keyring_provider); + // Determine the PyTorch backend. + let torch_backend = torch_backend + .map(|mode| { + let source = if uv_auth::PyxTokenStore::from_settings() + .is_ok_and(|store| store.has_credentials()) + { + TorchSource::Pyx + } else { + TorchSource::default() + }; + TorchStrategy::from_mode(mode, source, interpreter.platform().os()) + }) + .transpose() + .ok() + .flatten(); + // Initialize the registry client. let client = RegistryClientBuilder::new(client_builder, cache.clone()) .index_locations(index_locations.clone()) .index_strategy(*index_strategy) + .torch_backend(torch_backend.clone()) .markers(interpreter.markers()) .platform(interpreter.platform()) .build(); @@ -1880,6 +1902,7 @@ pub(crate) async fn resolve_environment( upgrade: _, build_options, sources, + torch_backend, } = settings; // Respect all requirements from the provided sources. @@ -1900,10 +1923,33 @@ pub(crate) async fn resolve_environment( let marker_env = pip::resolution_markers(None, python_platform, interpreter); let python_requirement = PythonRequirement::from_interpreter(interpreter); + // Determine the PyTorch backend. + let torch_backend = torch_backend + .map(|mode| { + let source = if uv_auth::PyxTokenStore::from_settings() + .is_ok_and(|store| store.has_credentials()) + { + TorchSource::Pyx + } else { + TorchSource::default() + }; + TorchStrategy::from_mode( + mode, + source, + python_platform + .map(|t| t.platform()) + .as_ref() + .unwrap_or(interpreter.platform()) + .os(), + ) + }) + .transpose()?; + // Initialize the registry client. let client = RegistryClientBuilder::new(client_builder, cache.clone()) .index_locations(index_locations.clone()) .index_strategy(*index_strategy) + .torch_backend(torch_backend.clone()) .markers(interpreter.markers()) .platform(interpreter.platform()) .build(); @@ -2232,6 +2278,7 @@ pub(crate) async fn update_environment( prerelease, resolution, sources, + torch_backend, upgrade, }, compile_bytecode, @@ -2302,10 +2349,33 @@ pub(crate) async fn update_environment( } } + // Determine the PyTorch backend. + let torch_backend = torch_backend + .map(|mode| { + let source = if uv_auth::PyxTokenStore::from_settings() + .is_ok_and(|store| store.has_credentials()) + { + TorchSource::Pyx + } else { + TorchSource::default() + }; + TorchStrategy::from_mode( + mode, + source, + python_platform + .map(|t| t.platform()) + .as_ref() + .unwrap_or(interpreter.platform()) + .os(), + ) + }) + .transpose()?; + // Initialize the registry client. let client = RegistryClientBuilder::new(client_builder, cache.clone()) .index_locations(index_locations.clone()) .index_strategy(*index_strategy) + .torch_backend(torch_backend.clone()) .markers(interpreter.markers()) .platform(interpreter.platform()) .build(); diff --git a/crates/uv/src/commands/project/sync.rs b/crates/uv/src/commands/project/sync.rs index 2542965e8..e7a141940 100644 --- a/crates/uv/src/commands/project/sync.rs +++ b/crates/uv/src/commands/project/sync.rs @@ -676,6 +676,7 @@ pub(super) async fn do_sync( prerelease: PrereleaseMode::default(), resolution: ResolutionMode::default(), sources, + torch_backend: None, upgrade: Upgrade::default(), }; script_extra_build_requires( diff --git a/crates/uv/src/commands/project/tree.rs b/crates/uv/src/commands/project/tree.rs index 4ee2e80f8..537216d42 100644 --- a/crates/uv/src/commands/project/tree.rs +++ b/crates/uv/src/commands/project/tree.rs @@ -212,6 +212,7 @@ pub(crate) async fn tree( upgrade: _, build_options: _, sources: _, + torch_backend: _, } = &settings; let capabilities = IndexCapabilities::default(); diff --git a/crates/uv/src/commands/tool/install.rs b/crates/uv/src/commands/tool/install.rs index 394fe93ff..48a7fd9cc 100644 --- a/crates/uv/src/commands/tool/install.rs +++ b/crates/uv/src/commands/tool/install.rs @@ -28,7 +28,7 @@ use uv_python::{ use uv_requirements::{RequirementsSource, RequirementsSpecification}; use uv_settings::{PythonInstallMirrors, ResolverInstallerOptions, ToolOptions}; use uv_tool::InstalledTools; -use uv_warnings::warn_user; +use uv_warnings::{warn_user, warn_user_once}; use uv_workspace::WorkspaceCache; use crate::commands::ExitStatus; @@ -76,6 +76,12 @@ pub(crate) async fn install( printer: Printer, preview: Preview, ) -> Result { + if settings.resolver.torch_backend.is_some() { + warn_user_once!( + "The `--torch-backend` option is experimental and may change without warning." + ); + } + let reporter = PythonDownloadReporter::single(printer); let python_request = python.as_deref().map(PythonRequest::parse); diff --git a/crates/uv/src/commands/tool/run.rs b/crates/uv/src/commands/tool/run.rs index 6eda5de02..60e96f3f8 100644 --- a/crates/uv/src/commands/tool/run.rs +++ b/crates/uv/src/commands/tool/run.rs @@ -129,6 +129,12 @@ pub(crate) async fn run( .is_some_and(|ext| ext.eq_ignore_ascii_case("py") || ext.eq_ignore_ascii_case("pyw")) } + if settings.resolver.torch_backend.is_some() { + warn_user_once!( + "The `--torch-backend` option is experimental and may change without warning." + ); + } + // Read from the `.env` file, if necessary. if !no_env_file { for env_file_path in env_file.iter().rev().map(PathBuf::as_path) { diff --git a/crates/uv/src/settings.rs b/crates/uv/src/settings.rs index 204ce41b2..53a7dd0cd 100644 --- a/crates/uv/src/settings.rs +++ b/crates/uv/src/settings.rs @@ -586,6 +586,7 @@ impl ToolRunSettings { lfs, python, python_platform, + torch_backend, generate_shell_completion: _, } = args; @@ -615,21 +616,24 @@ impl ToolRunSettings { } } + let filesystem_options = filesystem.map(FilesystemOptions::into_options); + let options = resolver_installer_options(installer, build).combine(ResolverInstallerOptions::from( - filesystem - .clone() - .map(FilesystemOptions::into_options) - .map(|options| options.top_level) + filesystem_options + .as_ref() + .map(|options| options.top_level.clone()) .unwrap_or_default(), )); - let filesystem_install_mirrors = filesystem - .map(FilesystemOptions::into_options) - .map(|options| options.install_mirrors) + let filesystem_install_mirrors = filesystem_options + .map(|options| options.install_mirrors.clone()) .unwrap_or_default(); - let settings = ResolverInstallerSettings::from(options.clone()); + let mut settings = ResolverInstallerSettings::from(options.clone()); + if torch_backend.is_some() { + settings.resolver.torch_backend = torch_backend; + } let lfs = GitLfsSetting::new(lfs.then_some(true), environment.lfs); Self { @@ -727,23 +731,27 @@ impl ToolInstallSettings { refresh, python, python_platform, + torch_backend, } = args; + let filesystem_options = filesystem.map(FilesystemOptions::into_options); + let options = resolver_installer_options(installer, build).combine(ResolverInstallerOptions::from( - filesystem - .clone() - .map(FilesystemOptions::into_options) - .map(|options| options.top_level) + filesystem_options + .as_ref() + .map(|options| options.top_level.clone()) .unwrap_or_default(), )); - let filesystem_install_mirrors = filesystem - .map(FilesystemOptions::into_options) - .map(|options| options.install_mirrors) + let filesystem_install_mirrors = filesystem_options + .map(|options| options.install_mirrors.clone()) .unwrap_or_default(); - let settings = ResolverInstallerSettings::from(options.clone()); + let mut settings = ResolverInstallerSettings::from(options.clone()); + if torch_backend.is_some() { + settings.resolver.torch_backend = torch_backend; + } let lfs = GitLfsSetting::new(lfs.then_some(true), environment.lfs); Self { @@ -3199,6 +3207,7 @@ pub(crate) struct ResolverSettings { pub(crate) prerelease: PrereleaseMode, pub(crate) resolution: ResolutionMode, pub(crate) sources: SourceStrategy, + pub(crate) torch_backend: Option, pub(crate) upgrade: Upgrade, } @@ -3253,6 +3262,7 @@ impl From for ResolverSettings { extra_build_variables: value.extra_build_variables.unwrap_or_default(), exclude_newer: value.exclude_newer, link_mode: value.link_mode.unwrap_or_default(), + torch_backend: value.torch_backend, sources: SourceStrategy::from_args(value.no_sources.unwrap_or_default()), upgrade: value.upgrade.unwrap_or_default(), build_options: BuildOptions::new( @@ -3344,6 +3354,7 @@ impl From for ResolverInstallerSettings { prerelease: value.prerelease.unwrap_or_default(), resolution: value.resolution.unwrap_or_default(), sources: SourceStrategy::from_args(value.no_sources.unwrap_or_default()), + torch_backend: value.torch_backend, upgrade: value.upgrade.unwrap_or_default(), }, compile_bytecode: value.compile_bytecode.unwrap_or_default(), diff --git a/crates/uv/tests/it/show_settings.rs b/crates/uv/tests/it/show_settings.rs index bfd03da94..f951d754d 100644 --- a/crates/uv/tests/it/show_settings.rs +++ b/crates/uv/tests/it/show_settings.rs @@ -3565,6 +3565,7 @@ fn resolve_tool() -> anyhow::Result<()> { link_mode: Some( Clone, ), + torch_backend: None, compile_bytecode: None, no_sources: None, upgrade: None, @@ -3573,7 +3574,6 @@ fn resolve_tool() -> anyhow::Result<()> { no_build_package: None, no_binary: None, no_binary_package: None, - torch_backend: None, }, settings: ResolverInstallerSettings { resolver: ResolverSettings { @@ -3615,6 +3615,7 @@ fn resolve_tool() -> anyhow::Result<()> { prerelease: IfNecessaryOrExplicit, resolution: LowestDirect, sources: Enabled, + torch_backend: None, upgrade: None, }, compile_bytecode: false, @@ -7912,6 +7913,7 @@ fn preview_features() { prerelease: IfNecessaryOrExplicit, resolution: Highest, sources: Enabled, + torch_backend: None, upgrade: None, }, compile_bytecode: false, @@ -8026,6 +8028,7 @@ fn preview_features() { prerelease: IfNecessaryOrExplicit, resolution: Highest, sources: Enabled, + torch_backend: None, upgrade: None, }, compile_bytecode: false, @@ -8140,6 +8143,7 @@ fn preview_features() { prerelease: IfNecessaryOrExplicit, resolution: Highest, sources: Enabled, + torch_backend: None, upgrade: None, }, compile_bytecode: false, @@ -8254,6 +8258,7 @@ fn preview_features() { prerelease: IfNecessaryOrExplicit, resolution: Highest, sources: Enabled, + torch_backend: None, upgrade: None, }, compile_bytecode: false, @@ -8368,6 +8373,7 @@ fn preview_features() { prerelease: IfNecessaryOrExplicit, resolution: Highest, sources: Enabled, + torch_backend: None, upgrade: None, }, compile_bytecode: false, @@ -8484,6 +8490,7 @@ fn preview_features() { prerelease: IfNecessaryOrExplicit, resolution: Highest, sources: Enabled, + torch_backend: None, upgrade: None, }, compile_bytecode: false, @@ -9738,6 +9745,7 @@ fn upgrade_project_cli_config_interaction() -> anyhow::Result<()> { prerelease: IfNecessaryOrExplicit, resolution: Highest, sources: Enabled, + torch_backend: None, upgrade: None, }, } @@ -9857,6 +9865,7 @@ fn upgrade_project_cli_config_interaction() -> anyhow::Result<()> { prerelease: IfNecessaryOrExplicit, resolution: Highest, sources: Enabled, + torch_backend: None, upgrade: Packages( { PackageName( @@ -9999,6 +10008,7 @@ fn upgrade_project_cli_config_interaction() -> anyhow::Result<()> { prerelease: IfNecessaryOrExplicit, resolution: Highest, sources: Enabled, + torch_backend: None, upgrade: All, }, } @@ -10116,6 +10126,7 @@ fn upgrade_project_cli_config_interaction() -> anyhow::Result<()> { prerelease: IfNecessaryOrExplicit, resolution: Highest, sources: Enabled, + torch_backend: None, upgrade: None, }, } @@ -10223,6 +10234,7 @@ fn upgrade_project_cli_config_interaction() -> anyhow::Result<()> { prerelease: IfNecessaryOrExplicit, resolution: Highest, sources: Enabled, + torch_backend: None, upgrade: All, }, } @@ -10331,6 +10343,7 @@ fn upgrade_project_cli_config_interaction() -> anyhow::Result<()> { prerelease: IfNecessaryOrExplicit, resolution: Highest, sources: Enabled, + torch_backend: None, upgrade: Packages( { PackageName(