Respect `--torch-backend` in `uv tool` commands (#17117)

## Summary

Like `uv pip`, these don't require a universal resolution, so
`--torch-backend` is easy to support.
This commit is contained in:
Charlie Marsh 2025-12-16 19:23:50 -05:00 committed by GitHub
parent e603761862
commit 0a83bf7dd5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 168 additions and 23 deletions

View File

@ -5371,6 +5371,21 @@ pub struct ToolRunArgs {
#[arg(long)]
pub python_platform: Option<TargetTriple>,
/// The backend to use when fetching packages in the PyTorch ecosystem (e.g., `cpu`, `cu126`, or `auto`)
///
/// When set, uv will ignore the configured index URLs for packages in the PyTorch ecosystem,
/// and will instead use the defined backend.
///
/// For example, when set to `cpu`, uv will use the CPU-only PyTorch index; when set to `cu126`,
/// uv will use the PyTorch index for CUDA 12.6.
///
/// The `auto` mode will attempt to detect the appropriate PyTorch index based on the currently
/// installed CUDA drivers.
///
/// This option is in preview and may change in any future release.
#[arg(long, value_enum, env = EnvVars::UV_TORCH_BACKEND)]
pub torch_backend: Option<TorchMode>,
#[arg(long, hide = true)]
pub generate_shell_completion: Option<clap_complete_command::Shell>,
}
@ -5547,6 +5562,21 @@ pub struct ToolInstallArgs {
/// `--python-platform` option is intended for advanced use cases.
#[arg(long)]
pub python_platform: Option<TargetTriple>,
/// The backend to use when fetching packages in the PyTorch ecosystem (e.g., `cpu`, `cu126`, or `auto`)
///
/// When set, uv will ignore the configured index URLs for packages in the PyTorch ecosystem,
/// and will instead use the defined backend.
///
/// For example, when set to `cpu`, uv will use the CPU-only PyTorch index; when set to `cu126`,
/// uv will use the PyTorch index for CUDA 12.6.
///
/// The `auto` mode will attempt to detect the appropriate PyTorch index based on the currently
/// installed CUDA drivers.
///
/// This option is in preview and may change in any future release.
#[arg(long, value_enum, env = EnvVars::UV_TORCH_BACKEND)]
pub torch_backend: Option<TorchMode>,
}
#[derive(Args)]

View File

@ -366,6 +366,7 @@ pub fn resolver_options(
exclude_newer_package.unwrap_or_default(),
),
link_mode,
torch_backend: None,
no_build: flag(no_build, build, "build"),
no_build_package: Some(no_build_package),
no_binary: flag(no_binary, binary, "binary"),

View File

@ -370,6 +370,7 @@ pub struct ResolverOptions {
pub config_settings_package: Option<PackageConfigSettings>,
pub exclude_newer: ExcludeNewer,
pub link_mode: Option<LinkMode>,
pub torch_backend: Option<TorchMode>,
pub upgrade: Option<Upgrade>,
pub build_isolation: Option<BuildIsolation>,
pub no_build: Option<bool>,
@ -404,6 +405,7 @@ pub struct ResolverInstallerOptions {
pub exclude_newer: Option<ExcludeNewerValue>,
pub exclude_newer_package: Option<ExcludeNewerPackage>,
pub link_mode: Option<LinkMode>,
pub torch_backend: Option<TorchMode>,
pub compile_bytecode: Option<bool>,
pub no_sources: Option<bool>,
pub upgrade: Option<Upgrade>,
@ -412,7 +414,6 @@ pub struct ResolverInstallerOptions {
pub no_build_package: Option<Vec<PackageName>>,
pub no_binary: Option<bool>,
pub no_binary_package: Option<Vec<PackageName>>,
pub torch_backend: Option<TorchMode>,
}
impl From<ResolverInstallerSchema> for ResolverInstallerOptions {
@ -438,6 +439,7 @@ impl From<ResolverInstallerSchema> for ResolverInstallerOptions {
exclude_newer,
exclude_newer_package,
link_mode,
torch_backend,
compile_bytecode,
no_sources,
upgrade,
@ -448,7 +450,6 @@ impl From<ResolverInstallerSchema> for ResolverInstallerOptions {
no_build_package,
no_binary,
no_binary_package,
torch_backend,
} = value;
Self {
index,
@ -473,6 +474,7 @@ impl From<ResolverInstallerSchema> for ResolverInstallerOptions {
exclude_newer,
exclude_newer_package,
link_mode,
torch_backend,
compile_bytecode,
no_sources,
upgrade: Upgrade::from_args(
@ -488,7 +490,6 @@ impl From<ResolverInstallerSchema> for ResolverInstallerOptions {
no_build_package,
no_binary,
no_binary_package,
torch_backend,
}
}
}
@ -1925,6 +1926,7 @@ impl From<ResolverInstallerSchema> for ResolverOptions {
extra_build_dependencies: value.extra_build_dependencies,
extra_build_variables: value.extra_build_variables,
no_sources: value.no_sources,
torch_backend: value.torch_backend,
}
}
}
@ -2004,6 +2006,7 @@ pub struct ToolOptions {
pub no_build_package: Option<Vec<PackageName>>,
pub no_binary: Option<bool>,
pub no_binary_package: Option<Vec<PackageName>>,
pub torch_backend: Option<TorchMode>,
}
impl From<ResolverInstallerOptions> for ToolOptions {
@ -2034,6 +2037,7 @@ impl From<ResolverInstallerOptions> for ToolOptions {
no_build_package: value.no_build_package,
no_binary: value.no_binary,
no_binary_package: value.no_binary_package,
torch_backend: value.torch_backend,
}
}
}
@ -2068,7 +2072,7 @@ impl From<ToolOptions> for ResolverInstallerOptions {
no_build_package: value.no_build_package,
no_binary: value.no_binary,
no_binary_package: value.no_binary_package,
torch_backend: None,
torch_backend: value.torch_backend,
}
}
}
@ -2150,7 +2154,7 @@ pub struct OptionsWire {
// `crates/uv-workspace/src/pyproject.rs`. The documentation lives on that struct.
// They're respected in both `pyproject.toml` and `uv.toml` files.
override_dependencies: Option<Vec<Requirement<VerbatimParsedUrl>>>,
exclude_dependencies: Option<Vec<uv_normalize::PackageName>>,
exclude_dependencies: Option<Vec<PackageName>>,
constraint_dependencies: Option<Vec<Requirement<VerbatimParsedUrl>>>,
build_constraint_dependencies: Option<Vec<Requirement<VerbatimParsedUrl>>>,
environments: Option<SupportedEnvironments>,

View File

@ -216,6 +216,7 @@ async fn build_impl(
upgrade: _,
build_options,
sources,
torch_backend: _,
} = settings;
// Determine the source to build.

View File

@ -470,6 +470,7 @@ async fn do_lock(
upgrade,
build_options,
sources,
torch_backend: _,
} = settings;
if !preview.is_enabled(PreviewFeatures::EXTRA_BUILD_DEPENDENCIES)

View File

@ -43,6 +43,7 @@ use uv_resolver::{
use uv_scripts::Pep723ItemRef;
use uv_settings::PythonInstallMirrors;
use uv_static::EnvVars;
use uv_torch::{TorchSource, TorchStrategy};
use uv_types::{BuildIsolation, EmptyInstalledPackages, HashStrategy};
use uv_virtualenv::remove_virtualenv;
use uv_warnings::{warn_user, warn_user_once};
@ -278,6 +279,9 @@ pub(crate) enum ProjectError {
#[error(transparent)]
RetryParsing(#[from] uv_client::RetryParsingError),
#[error(transparent)]
Accelerator(#[from] uv_torch::AcceleratorError),
#[error(transparent)]
Anyhow(#[from] anyhow::Error),
}
@ -1723,6 +1727,7 @@ pub(crate) async fn resolve_names(
prerelease: _,
resolution: _,
sources,
torch_backend,
upgrade: _,
},
compile_bytecode: _,
@ -1731,10 +1736,27 @@ pub(crate) async fn resolve_names(
let client_builder = client_builder.clone().keyring(*keyring_provider);
// Determine the PyTorch backend.
let torch_backend = torch_backend
.map(|mode| {
let source = if uv_auth::PyxTokenStore::from_settings()
.is_ok_and(|store| store.has_credentials())
{
TorchSource::Pyx
} else {
TorchSource::default()
};
TorchStrategy::from_mode(mode, source, interpreter.platform().os())
})
.transpose()
.ok()
.flatten();
// Initialize the registry client.
let client = RegistryClientBuilder::new(client_builder, cache.clone())
.index_locations(index_locations.clone())
.index_strategy(*index_strategy)
.torch_backend(torch_backend.clone())
.markers(interpreter.markers())
.platform(interpreter.platform())
.build();
@ -1880,6 +1902,7 @@ pub(crate) async fn resolve_environment(
upgrade: _,
build_options,
sources,
torch_backend,
} = settings;
// Respect all requirements from the provided sources.
@ -1900,10 +1923,33 @@ pub(crate) async fn resolve_environment(
let marker_env = pip::resolution_markers(None, python_platform, interpreter);
let python_requirement = PythonRequirement::from_interpreter(interpreter);
// Determine the PyTorch backend.
let torch_backend = torch_backend
.map(|mode| {
let source = if uv_auth::PyxTokenStore::from_settings()
.is_ok_and(|store| store.has_credentials())
{
TorchSource::Pyx
} else {
TorchSource::default()
};
TorchStrategy::from_mode(
mode,
source,
python_platform
.map(|t| t.platform())
.as_ref()
.unwrap_or(interpreter.platform())
.os(),
)
})
.transpose()?;
// Initialize the registry client.
let client = RegistryClientBuilder::new(client_builder, cache.clone())
.index_locations(index_locations.clone())
.index_strategy(*index_strategy)
.torch_backend(torch_backend.clone())
.markers(interpreter.markers())
.platform(interpreter.platform())
.build();
@ -2232,6 +2278,7 @@ pub(crate) async fn update_environment(
prerelease,
resolution,
sources,
torch_backend,
upgrade,
},
compile_bytecode,
@ -2302,10 +2349,33 @@ pub(crate) async fn update_environment(
}
}
// Determine the PyTorch backend.
let torch_backend = torch_backend
.map(|mode| {
let source = if uv_auth::PyxTokenStore::from_settings()
.is_ok_and(|store| store.has_credentials())
{
TorchSource::Pyx
} else {
TorchSource::default()
};
TorchStrategy::from_mode(
mode,
source,
python_platform
.map(|t| t.platform())
.as_ref()
.unwrap_or(interpreter.platform())
.os(),
)
})
.transpose()?;
// Initialize the registry client.
let client = RegistryClientBuilder::new(client_builder, cache.clone())
.index_locations(index_locations.clone())
.index_strategy(*index_strategy)
.torch_backend(torch_backend.clone())
.markers(interpreter.markers())
.platform(interpreter.platform())
.build();

View File

@ -676,6 +676,7 @@ pub(super) async fn do_sync(
prerelease: PrereleaseMode::default(),
resolution: ResolutionMode::default(),
sources,
torch_backend: None,
upgrade: Upgrade::default(),
};
script_extra_build_requires(

View File

@ -212,6 +212,7 @@ pub(crate) async fn tree(
upgrade: _,
build_options: _,
sources: _,
torch_backend: _,
} = &settings;
let capabilities = IndexCapabilities::default();

View File

@ -28,7 +28,7 @@ use uv_python::{
use uv_requirements::{RequirementsSource, RequirementsSpecification};
use uv_settings::{PythonInstallMirrors, ResolverInstallerOptions, ToolOptions};
use uv_tool::InstalledTools;
use uv_warnings::warn_user;
use uv_warnings::{warn_user, warn_user_once};
use uv_workspace::WorkspaceCache;
use crate::commands::ExitStatus;
@ -76,6 +76,12 @@ pub(crate) async fn install(
printer: Printer,
preview: Preview,
) -> Result<ExitStatus> {
if settings.resolver.torch_backend.is_some() {
warn_user_once!(
"The `--torch-backend` option is experimental and may change without warning."
);
}
let reporter = PythonDownloadReporter::single(printer);
let python_request = python.as_deref().map(PythonRequest::parse);

View File

@ -129,6 +129,12 @@ pub(crate) async fn run(
.is_some_and(|ext| ext.eq_ignore_ascii_case("py") || ext.eq_ignore_ascii_case("pyw"))
}
if settings.resolver.torch_backend.is_some() {
warn_user_once!(
"The `--torch-backend` option is experimental and may change without warning."
);
}
// Read from the `.env` file, if necessary.
if !no_env_file {
for env_file_path in env_file.iter().rev().map(PathBuf::as_path) {

View File

@ -586,6 +586,7 @@ impl ToolRunSettings {
lfs,
python,
python_platform,
torch_backend,
generate_shell_completion: _,
} = args;
@ -615,21 +616,24 @@ impl ToolRunSettings {
}
}
let filesystem_options = filesystem.map(FilesystemOptions::into_options);
let options =
resolver_installer_options(installer, build).combine(ResolverInstallerOptions::from(
filesystem
.clone()
.map(FilesystemOptions::into_options)
.map(|options| options.top_level)
filesystem_options
.as_ref()
.map(|options| options.top_level.clone())
.unwrap_or_default(),
));
let filesystem_install_mirrors = filesystem
.map(FilesystemOptions::into_options)
.map(|options| options.install_mirrors)
let filesystem_install_mirrors = filesystem_options
.map(|options| options.install_mirrors.clone())
.unwrap_or_default();
let settings = ResolverInstallerSettings::from(options.clone());
let mut settings = ResolverInstallerSettings::from(options.clone());
if torch_backend.is_some() {
settings.resolver.torch_backend = torch_backend;
}
let lfs = GitLfsSetting::new(lfs.then_some(true), environment.lfs);
Self {
@ -727,23 +731,27 @@ impl ToolInstallSettings {
refresh,
python,
python_platform,
torch_backend,
} = args;
let filesystem_options = filesystem.map(FilesystemOptions::into_options);
let options =
resolver_installer_options(installer, build).combine(ResolverInstallerOptions::from(
filesystem
.clone()
.map(FilesystemOptions::into_options)
.map(|options| options.top_level)
filesystem_options
.as_ref()
.map(|options| options.top_level.clone())
.unwrap_or_default(),
));
let filesystem_install_mirrors = filesystem
.map(FilesystemOptions::into_options)
.map(|options| options.install_mirrors)
let filesystem_install_mirrors = filesystem_options
.map(|options| options.install_mirrors.clone())
.unwrap_or_default();
let settings = ResolverInstallerSettings::from(options.clone());
let mut settings = ResolverInstallerSettings::from(options.clone());
if torch_backend.is_some() {
settings.resolver.torch_backend = torch_backend;
}
let lfs = GitLfsSetting::new(lfs.then_some(true), environment.lfs);
Self {
@ -3199,6 +3207,7 @@ pub(crate) struct ResolverSettings {
pub(crate) prerelease: PrereleaseMode,
pub(crate) resolution: ResolutionMode,
pub(crate) sources: SourceStrategy,
pub(crate) torch_backend: Option<TorchMode>,
pub(crate) upgrade: Upgrade,
}
@ -3253,6 +3262,7 @@ impl From<ResolverOptions> for ResolverSettings {
extra_build_variables: value.extra_build_variables.unwrap_or_default(),
exclude_newer: value.exclude_newer,
link_mode: value.link_mode.unwrap_or_default(),
torch_backend: value.torch_backend,
sources: SourceStrategy::from_args(value.no_sources.unwrap_or_default()),
upgrade: value.upgrade.unwrap_or_default(),
build_options: BuildOptions::new(
@ -3344,6 +3354,7 @@ impl From<ResolverInstallerOptions> for ResolverInstallerSettings {
prerelease: value.prerelease.unwrap_or_default(),
resolution: value.resolution.unwrap_or_default(),
sources: SourceStrategy::from_args(value.no_sources.unwrap_or_default()),
torch_backend: value.torch_backend,
upgrade: value.upgrade.unwrap_or_default(),
},
compile_bytecode: value.compile_bytecode.unwrap_or_default(),

View File

@ -3565,6 +3565,7 @@ fn resolve_tool() -> anyhow::Result<()> {
link_mode: Some(
Clone,
),
torch_backend: None,
compile_bytecode: None,
no_sources: None,
upgrade: None,
@ -3573,7 +3574,6 @@ fn resolve_tool() -> anyhow::Result<()> {
no_build_package: None,
no_binary: None,
no_binary_package: None,
torch_backend: None,
},
settings: ResolverInstallerSettings {
resolver: ResolverSettings {
@ -3615,6 +3615,7 @@ fn resolve_tool() -> anyhow::Result<()> {
prerelease: IfNecessaryOrExplicit,
resolution: LowestDirect,
sources: Enabled,
torch_backend: None,
upgrade: None,
},
compile_bytecode: false,
@ -7912,6 +7913,7 @@ fn preview_features() {
prerelease: IfNecessaryOrExplicit,
resolution: Highest,
sources: Enabled,
torch_backend: None,
upgrade: None,
},
compile_bytecode: false,
@ -8026,6 +8028,7 @@ fn preview_features() {
prerelease: IfNecessaryOrExplicit,
resolution: Highest,
sources: Enabled,
torch_backend: None,
upgrade: None,
},
compile_bytecode: false,
@ -8140,6 +8143,7 @@ fn preview_features() {
prerelease: IfNecessaryOrExplicit,
resolution: Highest,
sources: Enabled,
torch_backend: None,
upgrade: None,
},
compile_bytecode: false,
@ -8254,6 +8258,7 @@ fn preview_features() {
prerelease: IfNecessaryOrExplicit,
resolution: Highest,
sources: Enabled,
torch_backend: None,
upgrade: None,
},
compile_bytecode: false,
@ -8368,6 +8373,7 @@ fn preview_features() {
prerelease: IfNecessaryOrExplicit,
resolution: Highest,
sources: Enabled,
torch_backend: None,
upgrade: None,
},
compile_bytecode: false,
@ -8484,6 +8490,7 @@ fn preview_features() {
prerelease: IfNecessaryOrExplicit,
resolution: Highest,
sources: Enabled,
torch_backend: None,
upgrade: None,
},
compile_bytecode: false,
@ -9738,6 +9745,7 @@ fn upgrade_project_cli_config_interaction() -> anyhow::Result<()> {
prerelease: IfNecessaryOrExplicit,
resolution: Highest,
sources: Enabled,
torch_backend: None,
upgrade: None,
},
}
@ -9857,6 +9865,7 @@ fn upgrade_project_cli_config_interaction() -> anyhow::Result<()> {
prerelease: IfNecessaryOrExplicit,
resolution: Highest,
sources: Enabled,
torch_backend: None,
upgrade: Packages(
{
PackageName(
@ -9999,6 +10008,7 @@ fn upgrade_project_cli_config_interaction() -> anyhow::Result<()> {
prerelease: IfNecessaryOrExplicit,
resolution: Highest,
sources: Enabled,
torch_backend: None,
upgrade: All,
},
}
@ -10116,6 +10126,7 @@ fn upgrade_project_cli_config_interaction() -> anyhow::Result<()> {
prerelease: IfNecessaryOrExplicit,
resolution: Highest,
sources: Enabled,
torch_backend: None,
upgrade: None,
},
}
@ -10223,6 +10234,7 @@ fn upgrade_project_cli_config_interaction() -> anyhow::Result<()> {
prerelease: IfNecessaryOrExplicit,
resolution: Highest,
sources: Enabled,
torch_backend: None,
upgrade: All,
},
}
@ -10331,6 +10343,7 @@ fn upgrade_project_cli_config_interaction() -> anyhow::Result<()> {
prerelease: IfNecessaryOrExplicit,
resolution: Highest,
sources: Enabled,
torch_backend: None,
upgrade: Packages(
{
PackageName(