diff --git a/crates/uv-torch/src/backend.rs b/crates/uv-torch/src/backend.rs index 3ad4a62ac..7b3bca27a 100644 --- a/crates/uv-torch/src/backend.rs +++ b/crates/uv-torch/src/backend.rs @@ -36,6 +36,7 @@ //! OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. //! ``` +use std::borrow::Cow; use std::str::FromStr; use std::sync::LazyLock; @@ -46,6 +47,7 @@ use uv_distribution_types::IndexUrl; use uv_normalize::PackageName; use uv_pep440::Version; use uv_platform_tags::Os; +use uv_static::EnvVars; use crate::{Accelerator, AcceleratorError, AmdGpuArchitecture}; @@ -177,104 +179,167 @@ pub enum TorchMode { Xpu, } +#[derive(Debug, Default, Copy, Clone, Eq, PartialEq)] +pub enum TorchSource { + /// Download PyTorch builds from the official PyTorch index. + #[default] + PyTorch, + /// Download PyTorch builds from the pyx index. + Pyx, +} + /// The strategy to use when determining the appropriate PyTorch index. #[derive(Debug, Clone, Eq, PartialEq)] pub enum TorchStrategy { /// Select the appropriate PyTorch index based on the operating system and CUDA driver version (e.g., `550.144.03`). - Cuda { os: Os, driver_version: Version }, + Cuda { + os: Os, + driver_version: Version, + source: TorchSource, + }, /// Select the appropriate PyTorch index based on the operating system and AMD GPU architecture (e.g., `gfx1100`). Amd { os: Os, gpu_architecture: AmdGpuArchitecture, + source: TorchSource, }, /// Select the appropriate PyTorch index based on the operating system and Intel GPU presence. - Xpu { os: Os }, + Xpu { os: Os, source: TorchSource }, /// Use the specified PyTorch index. - Backend(TorchBackend), + Backend { + backend: TorchBackend, + source: TorchSource, + }, } impl TorchStrategy { /// Determine the [`TorchStrategy`] from the given [`TorchMode`], [`Os`], and [`Accelerator`]. - pub fn from_mode(mode: TorchMode, os: &Os) -> Result { - match mode { + pub fn from_mode( + mode: TorchMode, + source: TorchSource, + os: &Os, + ) -> Result { + let backend = match mode { TorchMode::Auto => match Accelerator::detect()? { - Some(Accelerator::Cuda { driver_version }) => Ok(Self::Cuda { - os: os.clone(), - driver_version: driver_version.clone(), - }), - Some(Accelerator::Amd { gpu_architecture }) => Ok(Self::Amd { - os: os.clone(), - gpu_architecture, - }), - Some(Accelerator::Xpu) => Ok(Self::Xpu { os: os.clone() }), - None => Ok(Self::Backend(TorchBackend::Cpu)), + Some(Accelerator::Cuda { driver_version }) => { + return Ok(Self::Cuda { + os: os.clone(), + driver_version: driver_version.clone(), + source, + }); + } + Some(Accelerator::Amd { gpu_architecture }) => { + return Ok(Self::Amd { + os: os.clone(), + gpu_architecture, + source, + }); + } + Some(Accelerator::Xpu) => { + return Ok(Self::Xpu { + os: os.clone(), + source, + }); + } + None => TorchBackend::Cpu, }, - TorchMode::Cpu => Ok(Self::Backend(TorchBackend::Cpu)), - TorchMode::Cu129 => Ok(Self::Backend(TorchBackend::Cu129)), - TorchMode::Cu128 => Ok(Self::Backend(TorchBackend::Cu128)), - TorchMode::Cu126 => Ok(Self::Backend(TorchBackend::Cu126)), - TorchMode::Cu125 => Ok(Self::Backend(TorchBackend::Cu125)), - TorchMode::Cu124 => Ok(Self::Backend(TorchBackend::Cu124)), - TorchMode::Cu123 => Ok(Self::Backend(TorchBackend::Cu123)), - TorchMode::Cu122 => Ok(Self::Backend(TorchBackend::Cu122)), - TorchMode::Cu121 => Ok(Self::Backend(TorchBackend::Cu121)), - TorchMode::Cu120 => Ok(Self::Backend(TorchBackend::Cu120)), - TorchMode::Cu118 => Ok(Self::Backend(TorchBackend::Cu118)), - TorchMode::Cu117 => Ok(Self::Backend(TorchBackend::Cu117)), - TorchMode::Cu116 => Ok(Self::Backend(TorchBackend::Cu116)), - TorchMode::Cu115 => Ok(Self::Backend(TorchBackend::Cu115)), - TorchMode::Cu114 => Ok(Self::Backend(TorchBackend::Cu114)), - TorchMode::Cu113 => Ok(Self::Backend(TorchBackend::Cu113)), - TorchMode::Cu112 => Ok(Self::Backend(TorchBackend::Cu112)), - TorchMode::Cu111 => Ok(Self::Backend(TorchBackend::Cu111)), - TorchMode::Cu110 => Ok(Self::Backend(TorchBackend::Cu110)), - TorchMode::Cu102 => Ok(Self::Backend(TorchBackend::Cu102)), - TorchMode::Cu101 => Ok(Self::Backend(TorchBackend::Cu101)), - TorchMode::Cu100 => Ok(Self::Backend(TorchBackend::Cu100)), - TorchMode::Cu92 => Ok(Self::Backend(TorchBackend::Cu92)), - TorchMode::Cu91 => Ok(Self::Backend(TorchBackend::Cu91)), - TorchMode::Cu90 => Ok(Self::Backend(TorchBackend::Cu90)), - TorchMode::Cu80 => Ok(Self::Backend(TorchBackend::Cu80)), - TorchMode::Rocm63 => Ok(Self::Backend(TorchBackend::Rocm63)), - TorchMode::Rocm624 => Ok(Self::Backend(TorchBackend::Rocm624)), - TorchMode::Rocm62 => Ok(Self::Backend(TorchBackend::Rocm62)), - TorchMode::Rocm61 => Ok(Self::Backend(TorchBackend::Rocm61)), - TorchMode::Rocm60 => Ok(Self::Backend(TorchBackend::Rocm60)), - TorchMode::Rocm57 => Ok(Self::Backend(TorchBackend::Rocm57)), - TorchMode::Rocm56 => Ok(Self::Backend(TorchBackend::Rocm56)), - TorchMode::Rocm55 => Ok(Self::Backend(TorchBackend::Rocm55)), - TorchMode::Rocm542 => Ok(Self::Backend(TorchBackend::Rocm542)), - TorchMode::Rocm54 => Ok(Self::Backend(TorchBackend::Rocm54)), - TorchMode::Rocm53 => Ok(Self::Backend(TorchBackend::Rocm53)), - TorchMode::Rocm52 => Ok(Self::Backend(TorchBackend::Rocm52)), - TorchMode::Rocm511 => Ok(Self::Backend(TorchBackend::Rocm511)), - TorchMode::Rocm42 => Ok(Self::Backend(TorchBackend::Rocm42)), - TorchMode::Rocm41 => Ok(Self::Backend(TorchBackend::Rocm41)), - TorchMode::Rocm401 => Ok(Self::Backend(TorchBackend::Rocm401)), - TorchMode::Xpu => Ok(Self::Backend(TorchBackend::Xpu)), - } + TorchMode::Cpu => TorchBackend::Cpu, + TorchMode::Cu129 => TorchBackend::Cu129, + TorchMode::Cu128 => TorchBackend::Cu128, + TorchMode::Cu126 => TorchBackend::Cu126, + TorchMode::Cu125 => TorchBackend::Cu125, + TorchMode::Cu124 => TorchBackend::Cu124, + TorchMode::Cu123 => TorchBackend::Cu123, + TorchMode::Cu122 => TorchBackend::Cu122, + TorchMode::Cu121 => TorchBackend::Cu121, + TorchMode::Cu120 => TorchBackend::Cu120, + TorchMode::Cu118 => TorchBackend::Cu118, + TorchMode::Cu117 => TorchBackend::Cu117, + TorchMode::Cu116 => TorchBackend::Cu116, + TorchMode::Cu115 => TorchBackend::Cu115, + TorchMode::Cu114 => TorchBackend::Cu114, + TorchMode::Cu113 => TorchBackend::Cu113, + TorchMode::Cu112 => TorchBackend::Cu112, + TorchMode::Cu111 => TorchBackend::Cu111, + TorchMode::Cu110 => TorchBackend::Cu110, + TorchMode::Cu102 => TorchBackend::Cu102, + TorchMode::Cu101 => TorchBackend::Cu101, + TorchMode::Cu100 => TorchBackend::Cu100, + TorchMode::Cu92 => TorchBackend::Cu92, + TorchMode::Cu91 => TorchBackend::Cu91, + TorchMode::Cu90 => TorchBackend::Cu90, + TorchMode::Cu80 => TorchBackend::Cu80, + TorchMode::Rocm63 => TorchBackend::Rocm63, + TorchMode::Rocm624 => TorchBackend::Rocm624, + TorchMode::Rocm62 => TorchBackend::Rocm62, + TorchMode::Rocm61 => TorchBackend::Rocm61, + TorchMode::Rocm60 => TorchBackend::Rocm60, + TorchMode::Rocm57 => TorchBackend::Rocm57, + TorchMode::Rocm56 => TorchBackend::Rocm56, + TorchMode::Rocm55 => TorchBackend::Rocm55, + TorchMode::Rocm542 => TorchBackend::Rocm542, + TorchMode::Rocm54 => TorchBackend::Rocm54, + TorchMode::Rocm53 => TorchBackend::Rocm53, + TorchMode::Rocm52 => TorchBackend::Rocm52, + TorchMode::Rocm511 => TorchBackend::Rocm511, + TorchMode::Rocm42 => TorchBackend::Rocm42, + TorchMode::Rocm41 => TorchBackend::Rocm41, + TorchMode::Rocm401 => TorchBackend::Rocm401, + TorchMode::Xpu => TorchBackend::Xpu, + }; + Ok(Self::Backend { backend, source }) } /// Returns `true` if the [`TorchStrategy`] applies to the given [`PackageName`]. pub fn applies_to(&self, package_name: &PackageName) -> bool { - matches!( - package_name.as_str(), - "torch" - | "torch-model-archiver" - | "torch-tb-profiler" - | "torcharrow" - | "torchaudio" - | "torchcsprng" - | "torchdata" - | "torchdistx" - | "torchserve" - | "torchtext" - | "torchvision" - | "triton" - | "pytorch-triton" - | "pytorch-triton-rocm" - | "pytorch-triton-xpu" - ) + let source = match self { + Self::Cuda { source, .. } => *source, + Self::Amd { source, .. } => *source, + Self::Xpu { source, .. } => *source, + Self::Backend { source, .. } => *source, + }; + match source { + TorchSource::PyTorch => { + matches!( + package_name.as_str(), + "torch" + | "torch-model-archiver" + | "torch-tb-profiler" + | "torcharrow" + | "torchaudio" + | "torchcsprng" + | "torchdata" + | "torchdistx" + | "torchserve" + | "torchtext" + | "torchvision" + | "pytorch-triton" + ) + } + TorchSource::Pyx => { + matches!( + package_name.as_str(), + "flash-attn" + | "flash-attn-3" + | "megablocks" + | "natten" + | "deepspeed" + | "vllm" + | "torch" + | "torch-model-archiver" + | "torch-tb-profiler" + | "torcharrow" + | "torchaudio" + | "torchcsprng" + | "torchdata" + | "torchdistx" + | "torchserve" + | "torchtext" + | "torchvision" + | "pytorch-triton" + ) + } + } } /// Returns `true` if the given [`PackageName`] has a system dependency (e.g., CUDA or ROCm). @@ -285,7 +350,13 @@ impl TorchStrategy { pub fn has_system_dependency(&self, package_name: &PackageName) -> bool { matches!( package_name.as_str(), - "torch" + "flash-attn" + | "flash-attn-3" + | "megablocks" + | "natten" + | "deepspeed" + | "vllm" + | "torch" | "torcharrow" | "torchaudio" | "torchcsprng" @@ -299,7 +370,11 @@ impl TorchStrategy { /// Return the appropriate index URLs for the given [`TorchStrategy`]. pub fn index_urls(&self) -> impl Iterator { match self { - Self::Cuda { os, driver_version } => { + Self::Cuda { + os, + driver_version, + source, + } => { // If this is a GPU-enabled package, and CUDA drivers are installed, use PyTorch's CUDA // indexes. // @@ -311,12 +386,12 @@ impl TorchStrategy { .iter() .filter_map(move |(backend, version)| { if driver_version >= version { - Some(backend.index_url()) + Some(backend.index_url(*source)) } else { None } }) - .chain(std::iter::once(TorchBackend::Cpu.index_url())), + .chain(std::iter::once(TorchBackend::Cpu.index_url(*source))), ))) } Os::Windows => Either::Left(Either::Left(Either::Right( @@ -324,12 +399,12 @@ impl TorchStrategy { .iter() .filter_map(move |(backend, version)| { if driver_version >= version { - Some(backend.index_url()) + Some(backend.index_url(*source)) } else { None } }) - .chain(std::iter::once(TorchBackend::Cpu.index_url())), + .chain(std::iter::once(TorchBackend::Cpu.index_url(*source))), ))), Os::Macos { .. } | Os::FreeBsd { .. } @@ -340,26 +415,27 @@ impl TorchStrategy { | Os::Haiku { .. } | Os::Android { .. } | Os::Pyodide { .. } - | Os::Ios { .. } => { - Either::Right(Either::Left(std::iter::once(TorchBackend::Cpu.index_url()))) - } + | Os::Ios { .. } => Either::Right(Either::Left(std::iter::once( + TorchBackend::Cpu.index_url(*source), + ))), } } Self::Amd { os, gpu_architecture, + source, } => match os { Os::Manylinux { .. } | Os::Musllinux { .. } => Either::Left(Either::Right( LINUX_AMD_GPU_DRIVERS .iter() .filter_map(move |(backend, architecture)| { if gpu_architecture == architecture { - Some(backend.index_url()) + Some(backend.index_url(*source)) } else { None } }) - .chain(std::iter::once(TorchBackend::Cpu.index_url())), + .chain(std::iter::once(TorchBackend::Cpu.index_url(*source))), )), Os::Windows | Os::Macos { .. } @@ -371,13 +447,13 @@ impl TorchStrategy { | Os::Haiku { .. } | Os::Android { .. } | Os::Pyodide { .. } - | Os::Ios { .. } => { - Either::Right(Either::Left(std::iter::once(TorchBackend::Cpu.index_url()))) - } + | Os::Ios { .. } => Either::Right(Either::Left(std::iter::once( + TorchBackend::Cpu.index_url(*source), + ))), }, - Self::Xpu { os } => match os { + Self::Xpu { os, source } => match os { Os::Manylinux { .. } => Either::Right(Either::Right(Either::Left( - std::iter::once(TorchBackend::Xpu.index_url()), + std::iter::once(TorchBackend::Xpu.index_url(*source)), ))), Os::Windows | Os::Musllinux { .. } @@ -390,13 +466,13 @@ impl TorchStrategy { | Os::Haiku { .. } | Os::Android { .. } | Os::Pyodide { .. } - | Os::Ios { .. } => { - Either::Right(Either::Left(std::iter::once(TorchBackend::Cpu.index_url()))) - } + | Os::Ios { .. } => Either::Right(Either::Left(std::iter::once( + TorchBackend::Cpu.index_url(*source), + ))), }, - Self::Backend(backend) => Either::Right(Either::Right(Either::Right(std::iter::once( - backend.index_url(), - )))), + Self::Backend { backend, source } => Either::Right(Either::Right(Either::Right( + std::iter::once(backend.index_url(*source)), + ))), } } } @@ -451,51 +527,180 @@ pub enum TorchBackend { impl TorchBackend { /// Return the appropriate index URL for the given [`TorchBackend`]. - fn index_url(self) -> &'static IndexUrl { + fn index_url(self, source: TorchSource) -> &'static IndexUrl { match self { - Self::Cpu => &CPU_INDEX_URL, - Self::Cu129 => &CU129_INDEX_URL, - Self::Cu128 => &CU128_INDEX_URL, - Self::Cu126 => &CU126_INDEX_URL, - Self::Cu125 => &CU125_INDEX_URL, - Self::Cu124 => &CU124_INDEX_URL, - Self::Cu123 => &CU123_INDEX_URL, - Self::Cu122 => &CU122_INDEX_URL, - Self::Cu121 => &CU121_INDEX_URL, - Self::Cu120 => &CU120_INDEX_URL, - Self::Cu118 => &CU118_INDEX_URL, - Self::Cu117 => &CU117_INDEX_URL, - Self::Cu116 => &CU116_INDEX_URL, - Self::Cu115 => &CU115_INDEX_URL, - Self::Cu114 => &CU114_INDEX_URL, - Self::Cu113 => &CU113_INDEX_URL, - Self::Cu112 => &CU112_INDEX_URL, - Self::Cu111 => &CU111_INDEX_URL, - Self::Cu110 => &CU110_INDEX_URL, - Self::Cu102 => &CU102_INDEX_URL, - Self::Cu101 => &CU101_INDEX_URL, - Self::Cu100 => &CU100_INDEX_URL, - Self::Cu92 => &CU92_INDEX_URL, - Self::Cu91 => &CU91_INDEX_URL, - Self::Cu90 => &CU90_INDEX_URL, - Self::Cu80 => &CU80_INDEX_URL, - Self::Rocm63 => &ROCM63_INDEX_URL, - Self::Rocm624 => &ROCM624_INDEX_URL, - Self::Rocm62 => &ROCM62_INDEX_URL, - Self::Rocm61 => &ROCM61_INDEX_URL, - Self::Rocm60 => &ROCM60_INDEX_URL, - Self::Rocm57 => &ROCM57_INDEX_URL, - Self::Rocm56 => &ROCM56_INDEX_URL, - Self::Rocm55 => &ROCM55_INDEX_URL, - Self::Rocm542 => &ROCM542_INDEX_URL, - Self::Rocm54 => &ROCM54_INDEX_URL, - Self::Rocm53 => &ROCM53_INDEX_URL, - Self::Rocm52 => &ROCM52_INDEX_URL, - Self::Rocm511 => &ROCM511_INDEX_URL, - Self::Rocm42 => &ROCM42_INDEX_URL, - Self::Rocm41 => &ROCM41_INDEX_URL, - Self::Rocm401 => &ROCM401_INDEX_URL, - Self::Xpu => &XPU_INDEX_URL, + Self::Cpu => match source { + TorchSource::PyTorch => &PYTORCH_CPU_INDEX_URL, + TorchSource::Pyx => &PYX_CPU_INDEX_URL, + }, + Self::Cu129 => match source { + TorchSource::PyTorch => &PYTORCH_CU129_INDEX_URL, + TorchSource::Pyx => &PYX_CU129_INDEX_URL, + }, + Self::Cu128 => match source { + TorchSource::PyTorch => &PYTORCH_CU128_INDEX_URL, + TorchSource::Pyx => &PYX_CU128_INDEX_URL, + }, + Self::Cu126 => match source { + TorchSource::PyTorch => &PYTORCH_CU126_INDEX_URL, + TorchSource::Pyx => &PYX_CU126_INDEX_URL, + }, + Self::Cu125 => match source { + TorchSource::PyTorch => &PYTORCH_CU125_INDEX_URL, + TorchSource::Pyx => &PYX_CU125_INDEX_URL, + }, + Self::Cu124 => match source { + TorchSource::PyTorch => &PYTORCH_CU124_INDEX_URL, + TorchSource::Pyx => &PYX_CU124_INDEX_URL, + }, + Self::Cu123 => match source { + TorchSource::PyTorch => &PYTORCH_CU123_INDEX_URL, + TorchSource::Pyx => &PYX_CU123_INDEX_URL, + }, + Self::Cu122 => match source { + TorchSource::PyTorch => &PYTORCH_CU122_INDEX_URL, + TorchSource::Pyx => &PYX_CU122_INDEX_URL, + }, + Self::Cu121 => match source { + TorchSource::PyTorch => &PYTORCH_CU121_INDEX_URL, + TorchSource::Pyx => &PYX_CU121_INDEX_URL, + }, + Self::Cu120 => match source { + TorchSource::PyTorch => &PYTORCH_CU120_INDEX_URL, + TorchSource::Pyx => &PYX_CU120_INDEX_URL, + }, + Self::Cu118 => match source { + TorchSource::PyTorch => &PYTORCH_CU118_INDEX_URL, + TorchSource::Pyx => &PYX_CU118_INDEX_URL, + }, + Self::Cu117 => match source { + TorchSource::PyTorch => &PYTORCH_CU117_INDEX_URL, + TorchSource::Pyx => &PYX_CU117_INDEX_URL, + }, + Self::Cu116 => match source { + TorchSource::PyTorch => &PYTORCH_CU116_INDEX_URL, + TorchSource::Pyx => &PYX_CU116_INDEX_URL, + }, + Self::Cu115 => match source { + TorchSource::PyTorch => &PYTORCH_CU115_INDEX_URL, + TorchSource::Pyx => &PYX_CU115_INDEX_URL, + }, + Self::Cu114 => match source { + TorchSource::PyTorch => &PYTORCH_CU114_INDEX_URL, + TorchSource::Pyx => &PYX_CU114_INDEX_URL, + }, + Self::Cu113 => match source { + TorchSource::PyTorch => &PYTORCH_CU113_INDEX_URL, + TorchSource::Pyx => &PYX_CU113_INDEX_URL, + }, + Self::Cu112 => match source { + TorchSource::PyTorch => &PYTORCH_CU112_INDEX_URL, + TorchSource::Pyx => &PYX_CU112_INDEX_URL, + }, + Self::Cu111 => match source { + TorchSource::PyTorch => &PYTORCH_CU111_INDEX_URL, + TorchSource::Pyx => &PYX_CU111_INDEX_URL, + }, + Self::Cu110 => match source { + TorchSource::PyTorch => &PYTORCH_CU110_INDEX_URL, + TorchSource::Pyx => &PYX_CU110_INDEX_URL, + }, + Self::Cu102 => match source { + TorchSource::PyTorch => &PYTORCH_CU102_INDEX_URL, + TorchSource::Pyx => &PYX_CU102_INDEX_URL, + }, + Self::Cu101 => match source { + TorchSource::PyTorch => &PYTORCH_CU101_INDEX_URL, + TorchSource::Pyx => &PYX_CU101_INDEX_URL, + }, + Self::Cu100 => match source { + TorchSource::PyTorch => &PYTORCH_CU100_INDEX_URL, + TorchSource::Pyx => &PYX_CU100_INDEX_URL, + }, + Self::Cu92 => match source { + TorchSource::PyTorch => &PYTORCH_CU92_INDEX_URL, + TorchSource::Pyx => &PYX_CU92_INDEX_URL, + }, + Self::Cu91 => match source { + TorchSource::PyTorch => &PYTORCH_CU91_INDEX_URL, + TorchSource::Pyx => &PYX_CU91_INDEX_URL, + }, + Self::Cu90 => match source { + TorchSource::PyTorch => &PYTORCH_CU90_INDEX_URL, + TorchSource::Pyx => &PYX_CU90_INDEX_URL, + }, + Self::Cu80 => match source { + TorchSource::PyTorch => &PYTORCH_CU80_INDEX_URL, + TorchSource::Pyx => &PYX_CU80_INDEX_URL, + }, + Self::Rocm63 => match source { + TorchSource::PyTorch => &PYTORCH_ROCM63_INDEX_URL, + TorchSource::Pyx => &PYX_ROCM63_INDEX_URL, + }, + Self::Rocm624 => match source { + TorchSource::PyTorch => &PYTORCH_ROCM624_INDEX_URL, + TorchSource::Pyx => &PYX_ROCM624_INDEX_URL, + }, + Self::Rocm62 => match source { + TorchSource::PyTorch => &PYTORCH_ROCM62_INDEX_URL, + TorchSource::Pyx => &PYX_ROCM62_INDEX_URL, + }, + Self::Rocm61 => match source { + TorchSource::PyTorch => &PYTORCH_ROCM61_INDEX_URL, + TorchSource::Pyx => &PYX_ROCM61_INDEX_URL, + }, + Self::Rocm60 => match source { + TorchSource::PyTorch => &PYTORCH_ROCM60_INDEX_URL, + TorchSource::Pyx => &PYX_ROCM60_INDEX_URL, + }, + Self::Rocm57 => match source { + TorchSource::PyTorch => &PYTORCH_ROCM57_INDEX_URL, + TorchSource::Pyx => &PYX_ROCM57_INDEX_URL, + }, + Self::Rocm56 => match source { + TorchSource::PyTorch => &PYTORCH_ROCM56_INDEX_URL, + TorchSource::Pyx => &PYX_ROCM56_INDEX_URL, + }, + Self::Rocm55 => match source { + TorchSource::PyTorch => &PYTORCH_ROCM55_INDEX_URL, + TorchSource::Pyx => &PYX_ROCM55_INDEX_URL, + }, + Self::Rocm542 => match source { + TorchSource::PyTorch => &PYTORCH_ROCM542_INDEX_URL, + TorchSource::Pyx => &PYX_ROCM542_INDEX_URL, + }, + Self::Rocm54 => match source { + TorchSource::PyTorch => &PYTORCH_ROCM54_INDEX_URL, + TorchSource::Pyx => &PYX_ROCM54_INDEX_URL, + }, + Self::Rocm53 => match source { + TorchSource::PyTorch => &PYTORCH_ROCM53_INDEX_URL, + TorchSource::Pyx => &PYX_ROCM53_INDEX_URL, + }, + Self::Rocm52 => match source { + TorchSource::PyTorch => &PYTORCH_ROCM52_INDEX_URL, + TorchSource::Pyx => &PYX_ROCM52_INDEX_URL, + }, + Self::Rocm511 => match source { + TorchSource::PyTorch => &PYTORCH_ROCM511_INDEX_URL, + TorchSource::Pyx => &PYX_ROCM511_INDEX_URL, + }, + Self::Rocm42 => match source { + TorchSource::PyTorch => &PYTORCH_ROCM42_INDEX_URL, + TorchSource::Pyx => &PYX_ROCM42_INDEX_URL, + }, + Self::Rocm41 => match source { + TorchSource::PyTorch => &PYTORCH_ROCM41_INDEX_URL, + TorchSource::Pyx => &PYX_ROCM41_INDEX_URL, + }, + Self::Rocm401 => match source { + TorchSource::PyTorch => &PYTORCH_ROCM401_INDEX_URL, + TorchSource::Pyx => &PYX_ROCM401_INDEX_URL, + }, + Self::Xpu => match source { + TorchSource::PyTorch => &PYTORCH_XPU_INDEX_URL, + TorchSource::Pyx => &PYX_XPU_INDEX_URL, + }, } } @@ -508,6 +713,16 @@ impl TorchBackend { return None; } path_segments.next()? + } else if index.host_str() == API_BASE_URL.strip_prefix("https://") { + // E.g., `https://api.pyx.dev/simple/astral-sh/cu124` + let mut path_segments = index.path_segments()?; + if path_segments.next() != Some("simple") { + return None; + } + if path_segments.next() != Some("astral-sh") { + return None; + } + path_segments.next()? } else { return None; }; @@ -619,7 +834,6 @@ impl FromStr for TorchBackend { fn from_str(s: &str) -> Result { match s { "cpu" => Ok(Self::Cpu), - "cu129" => Ok(Self::Cu129), "cu128" => Ok(Self::Cu128), "cu126" => Ok(Self::Cu126), "cu125" => Ok(Self::Cu125), @@ -669,11 +883,10 @@ impl FromStr for TorchBackend { /// Linux CUDA driver versions and the corresponding CUDA versions. /// /// See: -static LINUX_CUDA_DRIVERS: LazyLock<[(TorchBackend, Version); 25]> = LazyLock::new(|| { +static LINUX_CUDA_DRIVERS: LazyLock<[(TorchBackend, Version); 24]> = LazyLock::new(|| { [ // Table 2 from // https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html - (TorchBackend::Cu129, Version::new([525, 60, 13])), (TorchBackend::Cu128, Version::new([525, 60, 13])), (TorchBackend::Cu126, Version::new([525, 60, 13])), (TorchBackend::Cu125, Version::new([525, 60, 13])), @@ -708,11 +921,10 @@ static LINUX_CUDA_DRIVERS: LazyLock<[(TorchBackend, Version); 25]> = LazyLock::n /// Windows CUDA driver versions and the corresponding CUDA versions. /// /// See: -static WINDOWS_CUDA_VERSIONS: LazyLock<[(TorchBackend, Version); 25]> = LazyLock::new(|| { +static WINDOWS_CUDA_VERSIONS: LazyLock<[(TorchBackend, Version); 24]> = LazyLock::new(|| { [ // Table 2 from // https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html - (TorchBackend::Cu129, Version::new([528, 33])), (TorchBackend::Cu128, Version::new([528, 33])), (TorchBackend::Cu126, Version::new([528, 33])), (TorchBackend::Cu125, Version::new([528, 33])), @@ -811,89 +1023,267 @@ static LINUX_AMD_GPU_DRIVERS: LazyLock<[(TorchBackend, AmdGpuArchitecture); 44]> ] }); -static CPU_INDEX_URL: LazyLock = +static PYTORCH_CPU_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cpu").unwrap()); -static CU129_INDEX_URL: LazyLock = +static PYTORCH_CU129_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu129").unwrap()); -static CU128_INDEX_URL: LazyLock = +static PYTORCH_CU128_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu128").unwrap()); -static CU126_INDEX_URL: LazyLock = +static PYTORCH_CU126_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu126").unwrap()); -static CU125_INDEX_URL: LazyLock = +static PYTORCH_CU125_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu125").unwrap()); -static CU124_INDEX_URL: LazyLock = +static PYTORCH_CU124_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu124").unwrap()); -static CU123_INDEX_URL: LazyLock = +static PYTORCH_CU123_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu123").unwrap()); -static CU122_INDEX_URL: LazyLock = +static PYTORCH_CU122_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu122").unwrap()); -static CU121_INDEX_URL: LazyLock = +static PYTORCH_CU121_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu121").unwrap()); -static CU120_INDEX_URL: LazyLock = +static PYTORCH_CU120_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu120").unwrap()); -static CU118_INDEX_URL: LazyLock = +static PYTORCH_CU118_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu118").unwrap()); -static CU117_INDEX_URL: LazyLock = +static PYTORCH_CU117_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu117").unwrap()); -static CU116_INDEX_URL: LazyLock = +static PYTORCH_CU116_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu116").unwrap()); -static CU115_INDEX_URL: LazyLock = +static PYTORCH_CU115_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu115").unwrap()); -static CU114_INDEX_URL: LazyLock = +static PYTORCH_CU114_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu114").unwrap()); -static CU113_INDEX_URL: LazyLock = +static PYTORCH_CU113_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu113").unwrap()); -static CU112_INDEX_URL: LazyLock = +static PYTORCH_CU112_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu112").unwrap()); -static CU111_INDEX_URL: LazyLock = +static PYTORCH_CU111_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu111").unwrap()); -static CU110_INDEX_URL: LazyLock = +static PYTORCH_CU110_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu110").unwrap()); -static CU102_INDEX_URL: LazyLock = +static PYTORCH_CU102_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu102").unwrap()); -static CU101_INDEX_URL: LazyLock = +static PYTORCH_CU101_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu101").unwrap()); -static CU100_INDEX_URL: LazyLock = +static PYTORCH_CU100_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu100").unwrap()); -static CU92_INDEX_URL: LazyLock = +static PYTORCH_CU92_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu92").unwrap()); -static CU91_INDEX_URL: LazyLock = +static PYTORCH_CU91_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu91").unwrap()); -static CU90_INDEX_URL: LazyLock = +static PYTORCH_CU90_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu90").unwrap()); -static CU80_INDEX_URL: LazyLock = +static PYTORCH_CU80_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu80").unwrap()); -static ROCM63_INDEX_URL: LazyLock = +static PYTORCH_ROCM63_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.3").unwrap()); -static ROCM624_INDEX_URL: LazyLock = +static PYTORCH_ROCM624_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.2.4").unwrap()); -static ROCM62_INDEX_URL: LazyLock = +static PYTORCH_ROCM62_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.2").unwrap()); -static ROCM61_INDEX_URL: LazyLock = +static PYTORCH_ROCM61_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.1").unwrap()); -static ROCM60_INDEX_URL: LazyLock = +static PYTORCH_ROCM60_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.0").unwrap()); -static ROCM57_INDEX_URL: LazyLock = +static PYTORCH_ROCM57_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.7").unwrap()); -static ROCM56_INDEX_URL: LazyLock = +static PYTORCH_ROCM56_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.6").unwrap()); -static ROCM55_INDEX_URL: LazyLock = +static PYTORCH_ROCM55_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.5").unwrap()); -static ROCM542_INDEX_URL: LazyLock = +static PYTORCH_ROCM542_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.4.2").unwrap()); -static ROCM54_INDEX_URL: LazyLock = +static PYTORCH_ROCM54_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.4").unwrap()); -static ROCM53_INDEX_URL: LazyLock = +static PYTORCH_ROCM53_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.3").unwrap()); -static ROCM52_INDEX_URL: LazyLock = +static PYTORCH_ROCM52_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.2").unwrap()); -static ROCM511_INDEX_URL: LazyLock = +static PYTORCH_ROCM511_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.1.1").unwrap()); -static ROCM42_INDEX_URL: LazyLock = +static PYTORCH_ROCM42_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm4.2").unwrap()); -static ROCM41_INDEX_URL: LazyLock = +static PYTORCH_ROCM41_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm4.1").unwrap()); -static ROCM401_INDEX_URL: LazyLock = +static PYTORCH_ROCM401_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm4.0.1").unwrap()); -static XPU_INDEX_URL: LazyLock = +static PYTORCH_XPU_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/xpu").unwrap()); + +static API_BASE_URL: LazyLock> = LazyLock::new(|| { + std::env::var(EnvVars::PYX_API_URL) + .map(Cow::Owned) + .unwrap_or(Cow::Borrowed("https://api.pyx.dev")) +}); +static PYX_CPU_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cpu")).unwrap() +}); +static PYX_CU129_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu129")).unwrap() +}); +static PYX_CU128_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu128")).unwrap() +}); +static PYX_CU126_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu126")).unwrap() +}); +static PYX_CU125_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu125")).unwrap() +}); +static PYX_CU124_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu124")).unwrap() +}); +static PYX_CU123_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu123")).unwrap() +}); +static PYX_CU122_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu122")).unwrap() +}); +static PYX_CU121_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu121")).unwrap() +}); +static PYX_CU120_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu120")).unwrap() +}); +static PYX_CU118_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu118")).unwrap() +}); +static PYX_CU117_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu117")).unwrap() +}); +static PYX_CU116_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu116")).unwrap() +}); +static PYX_CU115_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu115")).unwrap() +}); +static PYX_CU114_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu114")).unwrap() +}); +static PYX_CU113_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu113")).unwrap() +}); +static PYX_CU112_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu112")).unwrap() +}); +static PYX_CU111_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu111")).unwrap() +}); +static PYX_CU110_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu110")).unwrap() +}); +static PYX_CU102_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu102")).unwrap() +}); +static PYX_CU101_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu101")).unwrap() +}); +static PYX_CU100_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu100")).unwrap() +}); +static PYX_CU92_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu92")).unwrap() +}); +static PYX_CU91_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu91")).unwrap() +}); +static PYX_CU90_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu90")).unwrap() +}); +static PYX_CU80_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu80")).unwrap() +}); +static PYX_ROCM63_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.3")).unwrap() +}); +static PYX_ROCM624_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.2.4")).unwrap() +}); +static PYX_ROCM62_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.2")).unwrap() +}); +static PYX_ROCM61_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.1")).unwrap() +}); +static PYX_ROCM60_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.0")).unwrap() +}); +static PYX_ROCM57_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.7")).unwrap() +}); +static PYX_ROCM56_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.6")).unwrap() +}); +static PYX_ROCM55_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.5")).unwrap() +}); +static PYX_ROCM542_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.4.2")).unwrap() +}); +static PYX_ROCM54_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.4")).unwrap() +}); +static PYX_ROCM53_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.3")).unwrap() +}); +static PYX_ROCM52_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.2")).unwrap() +}); +static PYX_ROCM511_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.1.1")).unwrap() +}); +static PYX_ROCM42_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm4.2")).unwrap() +}); +static PYX_ROCM41_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm4.1")).unwrap() +}); +static PYX_ROCM401_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm4.0.1")).unwrap() +}); +static PYX_XPU_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/xpu")).unwrap() +}); diff --git a/crates/uv/src/commands/pip/compile.rs b/crates/uv/src/commands/pip/compile.rs index ad054d1b6..aa89884ea 100644 --- a/crates/uv/src/commands/pip/compile.rs +++ b/crates/uv/src/commands/pip/compile.rs @@ -45,7 +45,7 @@ use uv_resolver::{ ResolverEnvironment, }; use uv_static::EnvVars; -use uv_torch::{TorchMode, TorchStrategy}; +use uv_torch::{TorchMode, TorchSource, TorchStrategy}; use uv_types::{EmptyInstalledPackages, HashStrategy}; use uv_warnings::{warn_user, warn_user_once}; use uv_workspace::WorkspaceCache; @@ -402,8 +402,16 @@ pub(crate) async fn pip_compile( // Determine the PyTorch backend. let torch_backend = torch_backend .map(|mode| { + let source = if uv_auth::PyxTokenStore::from_settings() + .is_ok_and(|store| store.has_credentials()) + { + TorchSource::Pyx + } else { + TorchSource::default() + }; TorchStrategy::from_mode( mode, + source, python_platform .map(TargetTriple::platform) .as_ref() diff --git a/crates/uv/src/commands/pip/install.rs b/crates/uv/src/commands/pip/install.rs index 67427f698..fa3dfb53f 100644 --- a/crates/uv/src/commands/pip/install.rs +++ b/crates/uv/src/commands/pip/install.rs @@ -35,7 +35,7 @@ use uv_resolver::{ DependencyMode, ExcludeNewer, FlatIndex, OptionsBuilder, PrereleaseMode, PylockToml, PythonRequirement, ResolutionMode, ResolverEnvironment, }; -use uv_torch::{TorchMode, TorchStrategy}; +use uv_torch::{TorchMode, TorchSource, TorchStrategy}; use uv_types::HashStrategy; use uv_warnings::{warn_user, warn_user_once}; use uv_workspace::WorkspaceCache; @@ -365,8 +365,16 @@ pub(crate) async fn pip_install( // Determine the PyTorch backend. let torch_backend = torch_backend .map(|mode| { + let source = if uv_auth::PyxTokenStore::from_settings() + .is_ok_and(|store| store.has_credentials()) + { + TorchSource::Pyx + } else { + TorchSource::default() + }; TorchStrategy::from_mode( mode, + source, python_platform .map(TargetTriple::platform) .as_ref() diff --git a/crates/uv/src/commands/pip/sync.rs b/crates/uv/src/commands/pip/sync.rs index 4dc8c529c..23406972d 100644 --- a/crates/uv/src/commands/pip/sync.rs +++ b/crates/uv/src/commands/pip/sync.rs @@ -33,7 +33,7 @@ use uv_resolver::{ DependencyMode, ExcludeNewer, FlatIndex, OptionsBuilder, PrereleaseMode, PylockToml, PythonRequirement, ResolutionMode, ResolverEnvironment, }; -use uv_torch::{TorchMode, TorchStrategy}; +use uv_torch::{TorchMode, TorchSource, TorchStrategy}; use uv_types::HashStrategy; use uv_warnings::{warn_user, warn_user_once}; use uv_workspace::WorkspaceCache; @@ -289,8 +289,16 @@ pub(crate) async fn pip_sync( // Determine the PyTorch backend. let torch_backend = torch_backend .map(|mode| { + let source = if uv_auth::PyxTokenStore::from_settings() + .is_ok_and(|store| store.has_credentials()) + { + TorchSource::Pyx + } else { + TorchSource::default() + }; TorchStrategy::from_mode( mode, + source, python_platform .map(TargetTriple::platform) .as_ref()