Add pyx as a supported PyTorch index URL (#15769)

## Summary

If the user explicitly authenticated to pyx, then we attempt to use the
pyx PyTorch URLs; otherwise, we stick to `download.pytorch.org` as the
default.
This commit is contained in:
Charlie Marsh 2025-09-10 15:38:00 -04:00 committed by GitHub
parent 0d174b79e2
commit b195d523d5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 608 additions and 194 deletions

View File

@ -36,6 +36,7 @@
//! OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//! ```
use std::borrow::Cow;
use std::str::FromStr;
use std::sync::LazyLock;
@ -46,6 +47,7 @@ use uv_distribution_types::IndexUrl;
use uv_normalize::PackageName;
use uv_pep440::Version;
use uv_platform_tags::Os;
use uv_static::EnvVars;
use crate::{Accelerator, AcceleratorError, AmdGpuArchitecture};
@ -177,86 +179,127 @@ pub enum TorchMode {
Xpu,
}
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq)]
pub enum TorchSource {
/// Download PyTorch builds from the official PyTorch index.
#[default]
PyTorch,
/// Download PyTorch builds from the pyx index.
Pyx,
}
/// The strategy to use when determining the appropriate PyTorch index.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum TorchStrategy {
/// Select the appropriate PyTorch index based on the operating system and CUDA driver version (e.g., `550.144.03`).
Cuda { os: Os, driver_version: Version },
Cuda {
os: Os,
driver_version: Version,
source: TorchSource,
},
/// Select the appropriate PyTorch index based on the operating system and AMD GPU architecture (e.g., `gfx1100`).
Amd {
os: Os,
gpu_architecture: AmdGpuArchitecture,
source: TorchSource,
},
/// Select the appropriate PyTorch index based on the operating system and Intel GPU presence.
Xpu { os: Os },
Xpu { os: Os, source: TorchSource },
/// Use the specified PyTorch index.
Backend(TorchBackend),
Backend {
backend: TorchBackend,
source: TorchSource,
},
}
impl TorchStrategy {
/// Determine the [`TorchStrategy`] from the given [`TorchMode`], [`Os`], and [`Accelerator`].
pub fn from_mode(mode: TorchMode, os: &Os) -> Result<Self, AcceleratorError> {
match mode {
pub fn from_mode(
mode: TorchMode,
source: TorchSource,
os: &Os,
) -> Result<Self, AcceleratorError> {
let backend = match mode {
TorchMode::Auto => match Accelerator::detect()? {
Some(Accelerator::Cuda { driver_version }) => Ok(Self::Cuda {
Some(Accelerator::Cuda { driver_version }) => {
return Ok(Self::Cuda {
os: os.clone(),
driver_version: driver_version.clone(),
}),
Some(Accelerator::Amd { gpu_architecture }) => Ok(Self::Amd {
source,
});
}
Some(Accelerator::Amd { gpu_architecture }) => {
return Ok(Self::Amd {
os: os.clone(),
gpu_architecture,
}),
Some(Accelerator::Xpu) => Ok(Self::Xpu { os: os.clone() }),
None => Ok(Self::Backend(TorchBackend::Cpu)),
},
TorchMode::Cpu => Ok(Self::Backend(TorchBackend::Cpu)),
TorchMode::Cu129 => Ok(Self::Backend(TorchBackend::Cu129)),
TorchMode::Cu128 => Ok(Self::Backend(TorchBackend::Cu128)),
TorchMode::Cu126 => Ok(Self::Backend(TorchBackend::Cu126)),
TorchMode::Cu125 => Ok(Self::Backend(TorchBackend::Cu125)),
TorchMode::Cu124 => Ok(Self::Backend(TorchBackend::Cu124)),
TorchMode::Cu123 => Ok(Self::Backend(TorchBackend::Cu123)),
TorchMode::Cu122 => Ok(Self::Backend(TorchBackend::Cu122)),
TorchMode::Cu121 => Ok(Self::Backend(TorchBackend::Cu121)),
TorchMode::Cu120 => Ok(Self::Backend(TorchBackend::Cu120)),
TorchMode::Cu118 => Ok(Self::Backend(TorchBackend::Cu118)),
TorchMode::Cu117 => Ok(Self::Backend(TorchBackend::Cu117)),
TorchMode::Cu116 => Ok(Self::Backend(TorchBackend::Cu116)),
TorchMode::Cu115 => Ok(Self::Backend(TorchBackend::Cu115)),
TorchMode::Cu114 => Ok(Self::Backend(TorchBackend::Cu114)),
TorchMode::Cu113 => Ok(Self::Backend(TorchBackend::Cu113)),
TorchMode::Cu112 => Ok(Self::Backend(TorchBackend::Cu112)),
TorchMode::Cu111 => Ok(Self::Backend(TorchBackend::Cu111)),
TorchMode::Cu110 => Ok(Self::Backend(TorchBackend::Cu110)),
TorchMode::Cu102 => Ok(Self::Backend(TorchBackend::Cu102)),
TorchMode::Cu101 => Ok(Self::Backend(TorchBackend::Cu101)),
TorchMode::Cu100 => Ok(Self::Backend(TorchBackend::Cu100)),
TorchMode::Cu92 => Ok(Self::Backend(TorchBackend::Cu92)),
TorchMode::Cu91 => Ok(Self::Backend(TorchBackend::Cu91)),
TorchMode::Cu90 => Ok(Self::Backend(TorchBackend::Cu90)),
TorchMode::Cu80 => Ok(Self::Backend(TorchBackend::Cu80)),
TorchMode::Rocm63 => Ok(Self::Backend(TorchBackend::Rocm63)),
TorchMode::Rocm624 => Ok(Self::Backend(TorchBackend::Rocm624)),
TorchMode::Rocm62 => Ok(Self::Backend(TorchBackend::Rocm62)),
TorchMode::Rocm61 => Ok(Self::Backend(TorchBackend::Rocm61)),
TorchMode::Rocm60 => Ok(Self::Backend(TorchBackend::Rocm60)),
TorchMode::Rocm57 => Ok(Self::Backend(TorchBackend::Rocm57)),
TorchMode::Rocm56 => Ok(Self::Backend(TorchBackend::Rocm56)),
TorchMode::Rocm55 => Ok(Self::Backend(TorchBackend::Rocm55)),
TorchMode::Rocm542 => Ok(Self::Backend(TorchBackend::Rocm542)),
TorchMode::Rocm54 => Ok(Self::Backend(TorchBackend::Rocm54)),
TorchMode::Rocm53 => Ok(Self::Backend(TorchBackend::Rocm53)),
TorchMode::Rocm52 => Ok(Self::Backend(TorchBackend::Rocm52)),
TorchMode::Rocm511 => Ok(Self::Backend(TorchBackend::Rocm511)),
TorchMode::Rocm42 => Ok(Self::Backend(TorchBackend::Rocm42)),
TorchMode::Rocm41 => Ok(Self::Backend(TorchBackend::Rocm41)),
TorchMode::Rocm401 => Ok(Self::Backend(TorchBackend::Rocm401)),
TorchMode::Xpu => Ok(Self::Backend(TorchBackend::Xpu)),
source,
});
}
Some(Accelerator::Xpu) => {
return Ok(Self::Xpu {
os: os.clone(),
source,
});
}
None => TorchBackend::Cpu,
},
TorchMode::Cpu => TorchBackend::Cpu,
TorchMode::Cu129 => TorchBackend::Cu129,
TorchMode::Cu128 => TorchBackend::Cu128,
TorchMode::Cu126 => TorchBackend::Cu126,
TorchMode::Cu125 => TorchBackend::Cu125,
TorchMode::Cu124 => TorchBackend::Cu124,
TorchMode::Cu123 => TorchBackend::Cu123,
TorchMode::Cu122 => TorchBackend::Cu122,
TorchMode::Cu121 => TorchBackend::Cu121,
TorchMode::Cu120 => TorchBackend::Cu120,
TorchMode::Cu118 => TorchBackend::Cu118,
TorchMode::Cu117 => TorchBackend::Cu117,
TorchMode::Cu116 => TorchBackend::Cu116,
TorchMode::Cu115 => TorchBackend::Cu115,
TorchMode::Cu114 => TorchBackend::Cu114,
TorchMode::Cu113 => TorchBackend::Cu113,
TorchMode::Cu112 => TorchBackend::Cu112,
TorchMode::Cu111 => TorchBackend::Cu111,
TorchMode::Cu110 => TorchBackend::Cu110,
TorchMode::Cu102 => TorchBackend::Cu102,
TorchMode::Cu101 => TorchBackend::Cu101,
TorchMode::Cu100 => TorchBackend::Cu100,
TorchMode::Cu92 => TorchBackend::Cu92,
TorchMode::Cu91 => TorchBackend::Cu91,
TorchMode::Cu90 => TorchBackend::Cu90,
TorchMode::Cu80 => TorchBackend::Cu80,
TorchMode::Rocm63 => TorchBackend::Rocm63,
TorchMode::Rocm624 => TorchBackend::Rocm624,
TorchMode::Rocm62 => TorchBackend::Rocm62,
TorchMode::Rocm61 => TorchBackend::Rocm61,
TorchMode::Rocm60 => TorchBackend::Rocm60,
TorchMode::Rocm57 => TorchBackend::Rocm57,
TorchMode::Rocm56 => TorchBackend::Rocm56,
TorchMode::Rocm55 => TorchBackend::Rocm55,
TorchMode::Rocm542 => TorchBackend::Rocm542,
TorchMode::Rocm54 => TorchBackend::Rocm54,
TorchMode::Rocm53 => TorchBackend::Rocm53,
TorchMode::Rocm52 => TorchBackend::Rocm52,
TorchMode::Rocm511 => TorchBackend::Rocm511,
TorchMode::Rocm42 => TorchBackend::Rocm42,
TorchMode::Rocm41 => TorchBackend::Rocm41,
TorchMode::Rocm401 => TorchBackend::Rocm401,
TorchMode::Xpu => TorchBackend::Xpu,
};
Ok(Self::Backend { backend, source })
}
/// Returns `true` if the [`TorchStrategy`] applies to the given [`PackageName`].
pub fn applies_to(&self, package_name: &PackageName) -> bool {
let source = match self {
Self::Cuda { source, .. } => *source,
Self::Amd { source, .. } => *source,
Self::Xpu { source, .. } => *source,
Self::Backend { source, .. } => *source,
};
match source {
TorchSource::PyTorch => {
matches!(
package_name.as_str(),
"torch"
@ -270,12 +313,34 @@ impl TorchStrategy {
| "torchserve"
| "torchtext"
| "torchvision"
| "triton"
| "pytorch-triton"
| "pytorch-triton-rocm"
| "pytorch-triton-xpu"
)
}
TorchSource::Pyx => {
matches!(
package_name.as_str(),
"flash-attn"
| "flash-attn-3"
| "megablocks"
| "natten"
| "deepspeed"
| "vllm"
| "torch"
| "torch-model-archiver"
| "torch-tb-profiler"
| "torcharrow"
| "torchaudio"
| "torchcsprng"
| "torchdata"
| "torchdistx"
| "torchserve"
| "torchtext"
| "torchvision"
| "pytorch-triton"
)
}
}
}
/// Returns `true` if the given [`PackageName`] has a system dependency (e.g., CUDA or ROCm).
///
@ -285,7 +350,13 @@ impl TorchStrategy {
pub fn has_system_dependency(&self, package_name: &PackageName) -> bool {
matches!(
package_name.as_str(),
"torch"
"flash-attn"
| "flash-attn-3"
| "megablocks"
| "natten"
| "deepspeed"
| "vllm"
| "torch"
| "torcharrow"
| "torchaudio"
| "torchcsprng"
@ -299,7 +370,11 @@ impl TorchStrategy {
/// Return the appropriate index URLs for the given [`TorchStrategy`].
pub fn index_urls(&self) -> impl Iterator<Item = &IndexUrl> {
match self {
Self::Cuda { os, driver_version } => {
Self::Cuda {
os,
driver_version,
source,
} => {
// If this is a GPU-enabled package, and CUDA drivers are installed, use PyTorch's CUDA
// indexes.
//
@ -311,12 +386,12 @@ impl TorchStrategy {
.iter()
.filter_map(move |(backend, version)| {
if driver_version >= version {
Some(backend.index_url())
Some(backend.index_url(*source))
} else {
None
}
})
.chain(std::iter::once(TorchBackend::Cpu.index_url())),
.chain(std::iter::once(TorchBackend::Cpu.index_url(*source))),
)))
}
Os::Windows => Either::Left(Either::Left(Either::Right(
@ -324,12 +399,12 @@ impl TorchStrategy {
.iter()
.filter_map(move |(backend, version)| {
if driver_version >= version {
Some(backend.index_url())
Some(backend.index_url(*source))
} else {
None
}
})
.chain(std::iter::once(TorchBackend::Cpu.index_url())),
.chain(std::iter::once(TorchBackend::Cpu.index_url(*source))),
))),
Os::Macos { .. }
| Os::FreeBsd { .. }
@ -340,26 +415,27 @@ impl TorchStrategy {
| Os::Haiku { .. }
| Os::Android { .. }
| Os::Pyodide { .. }
| Os::Ios { .. } => {
Either::Right(Either::Left(std::iter::once(TorchBackend::Cpu.index_url())))
}
| Os::Ios { .. } => Either::Right(Either::Left(std::iter::once(
TorchBackend::Cpu.index_url(*source),
))),
}
}
Self::Amd {
os,
gpu_architecture,
source,
} => match os {
Os::Manylinux { .. } | Os::Musllinux { .. } => Either::Left(Either::Right(
LINUX_AMD_GPU_DRIVERS
.iter()
.filter_map(move |(backend, architecture)| {
if gpu_architecture == architecture {
Some(backend.index_url())
Some(backend.index_url(*source))
} else {
None
}
})
.chain(std::iter::once(TorchBackend::Cpu.index_url())),
.chain(std::iter::once(TorchBackend::Cpu.index_url(*source))),
)),
Os::Windows
| Os::Macos { .. }
@ -371,13 +447,13 @@ impl TorchStrategy {
| Os::Haiku { .. }
| Os::Android { .. }
| Os::Pyodide { .. }
| Os::Ios { .. } => {
Either::Right(Either::Left(std::iter::once(TorchBackend::Cpu.index_url())))
}
| Os::Ios { .. } => Either::Right(Either::Left(std::iter::once(
TorchBackend::Cpu.index_url(*source),
))),
},
Self::Xpu { os } => match os {
Self::Xpu { os, source } => match os {
Os::Manylinux { .. } => Either::Right(Either::Right(Either::Left(
std::iter::once(TorchBackend::Xpu.index_url()),
std::iter::once(TorchBackend::Xpu.index_url(*source)),
))),
Os::Windows
| Os::Musllinux { .. }
@ -390,13 +466,13 @@ impl TorchStrategy {
| Os::Haiku { .. }
| Os::Android { .. }
| Os::Pyodide { .. }
| Os::Ios { .. } => {
Either::Right(Either::Left(std::iter::once(TorchBackend::Cpu.index_url())))
}
| Os::Ios { .. } => Either::Right(Either::Left(std::iter::once(
TorchBackend::Cpu.index_url(*source),
))),
},
Self::Backend(backend) => Either::Right(Either::Right(Either::Right(std::iter::once(
backend.index_url(),
)))),
Self::Backend { backend, source } => Either::Right(Either::Right(Either::Right(
std::iter::once(backend.index_url(*source)),
))),
}
}
}
@ -451,51 +527,180 @@ pub enum TorchBackend {
impl TorchBackend {
/// Return the appropriate index URL for the given [`TorchBackend`].
fn index_url(self) -> &'static IndexUrl {
fn index_url(self, source: TorchSource) -> &'static IndexUrl {
match self {
Self::Cpu => &CPU_INDEX_URL,
Self::Cu129 => &CU129_INDEX_URL,
Self::Cu128 => &CU128_INDEX_URL,
Self::Cu126 => &CU126_INDEX_URL,
Self::Cu125 => &CU125_INDEX_URL,
Self::Cu124 => &CU124_INDEX_URL,
Self::Cu123 => &CU123_INDEX_URL,
Self::Cu122 => &CU122_INDEX_URL,
Self::Cu121 => &CU121_INDEX_URL,
Self::Cu120 => &CU120_INDEX_URL,
Self::Cu118 => &CU118_INDEX_URL,
Self::Cu117 => &CU117_INDEX_URL,
Self::Cu116 => &CU116_INDEX_URL,
Self::Cu115 => &CU115_INDEX_URL,
Self::Cu114 => &CU114_INDEX_URL,
Self::Cu113 => &CU113_INDEX_URL,
Self::Cu112 => &CU112_INDEX_URL,
Self::Cu111 => &CU111_INDEX_URL,
Self::Cu110 => &CU110_INDEX_URL,
Self::Cu102 => &CU102_INDEX_URL,
Self::Cu101 => &CU101_INDEX_URL,
Self::Cu100 => &CU100_INDEX_URL,
Self::Cu92 => &CU92_INDEX_URL,
Self::Cu91 => &CU91_INDEX_URL,
Self::Cu90 => &CU90_INDEX_URL,
Self::Cu80 => &CU80_INDEX_URL,
Self::Rocm63 => &ROCM63_INDEX_URL,
Self::Rocm624 => &ROCM624_INDEX_URL,
Self::Rocm62 => &ROCM62_INDEX_URL,
Self::Rocm61 => &ROCM61_INDEX_URL,
Self::Rocm60 => &ROCM60_INDEX_URL,
Self::Rocm57 => &ROCM57_INDEX_URL,
Self::Rocm56 => &ROCM56_INDEX_URL,
Self::Rocm55 => &ROCM55_INDEX_URL,
Self::Rocm542 => &ROCM542_INDEX_URL,
Self::Rocm54 => &ROCM54_INDEX_URL,
Self::Rocm53 => &ROCM53_INDEX_URL,
Self::Rocm52 => &ROCM52_INDEX_URL,
Self::Rocm511 => &ROCM511_INDEX_URL,
Self::Rocm42 => &ROCM42_INDEX_URL,
Self::Rocm41 => &ROCM41_INDEX_URL,
Self::Rocm401 => &ROCM401_INDEX_URL,
Self::Xpu => &XPU_INDEX_URL,
Self::Cpu => match source {
TorchSource::PyTorch => &PYTORCH_CPU_INDEX_URL,
TorchSource::Pyx => &PYX_CPU_INDEX_URL,
},
Self::Cu129 => match source {
TorchSource::PyTorch => &PYTORCH_CU129_INDEX_URL,
TorchSource::Pyx => &PYX_CU129_INDEX_URL,
},
Self::Cu128 => match source {
TorchSource::PyTorch => &PYTORCH_CU128_INDEX_URL,
TorchSource::Pyx => &PYX_CU128_INDEX_URL,
},
Self::Cu126 => match source {
TorchSource::PyTorch => &PYTORCH_CU126_INDEX_URL,
TorchSource::Pyx => &PYX_CU126_INDEX_URL,
},
Self::Cu125 => match source {
TorchSource::PyTorch => &PYTORCH_CU125_INDEX_URL,
TorchSource::Pyx => &PYX_CU125_INDEX_URL,
},
Self::Cu124 => match source {
TorchSource::PyTorch => &PYTORCH_CU124_INDEX_URL,
TorchSource::Pyx => &PYX_CU124_INDEX_URL,
},
Self::Cu123 => match source {
TorchSource::PyTorch => &PYTORCH_CU123_INDEX_URL,
TorchSource::Pyx => &PYX_CU123_INDEX_URL,
},
Self::Cu122 => match source {
TorchSource::PyTorch => &PYTORCH_CU122_INDEX_URL,
TorchSource::Pyx => &PYX_CU122_INDEX_URL,
},
Self::Cu121 => match source {
TorchSource::PyTorch => &PYTORCH_CU121_INDEX_URL,
TorchSource::Pyx => &PYX_CU121_INDEX_URL,
},
Self::Cu120 => match source {
TorchSource::PyTorch => &PYTORCH_CU120_INDEX_URL,
TorchSource::Pyx => &PYX_CU120_INDEX_URL,
},
Self::Cu118 => match source {
TorchSource::PyTorch => &PYTORCH_CU118_INDEX_URL,
TorchSource::Pyx => &PYX_CU118_INDEX_URL,
},
Self::Cu117 => match source {
TorchSource::PyTorch => &PYTORCH_CU117_INDEX_URL,
TorchSource::Pyx => &PYX_CU117_INDEX_URL,
},
Self::Cu116 => match source {
TorchSource::PyTorch => &PYTORCH_CU116_INDEX_URL,
TorchSource::Pyx => &PYX_CU116_INDEX_URL,
},
Self::Cu115 => match source {
TorchSource::PyTorch => &PYTORCH_CU115_INDEX_URL,
TorchSource::Pyx => &PYX_CU115_INDEX_URL,
},
Self::Cu114 => match source {
TorchSource::PyTorch => &PYTORCH_CU114_INDEX_URL,
TorchSource::Pyx => &PYX_CU114_INDEX_URL,
},
Self::Cu113 => match source {
TorchSource::PyTorch => &PYTORCH_CU113_INDEX_URL,
TorchSource::Pyx => &PYX_CU113_INDEX_URL,
},
Self::Cu112 => match source {
TorchSource::PyTorch => &PYTORCH_CU112_INDEX_URL,
TorchSource::Pyx => &PYX_CU112_INDEX_URL,
},
Self::Cu111 => match source {
TorchSource::PyTorch => &PYTORCH_CU111_INDEX_URL,
TorchSource::Pyx => &PYX_CU111_INDEX_URL,
},
Self::Cu110 => match source {
TorchSource::PyTorch => &PYTORCH_CU110_INDEX_URL,
TorchSource::Pyx => &PYX_CU110_INDEX_URL,
},
Self::Cu102 => match source {
TorchSource::PyTorch => &PYTORCH_CU102_INDEX_URL,
TorchSource::Pyx => &PYX_CU102_INDEX_URL,
},
Self::Cu101 => match source {
TorchSource::PyTorch => &PYTORCH_CU101_INDEX_URL,
TorchSource::Pyx => &PYX_CU101_INDEX_URL,
},
Self::Cu100 => match source {
TorchSource::PyTorch => &PYTORCH_CU100_INDEX_URL,
TorchSource::Pyx => &PYX_CU100_INDEX_URL,
},
Self::Cu92 => match source {
TorchSource::PyTorch => &PYTORCH_CU92_INDEX_URL,
TorchSource::Pyx => &PYX_CU92_INDEX_URL,
},
Self::Cu91 => match source {
TorchSource::PyTorch => &PYTORCH_CU91_INDEX_URL,
TorchSource::Pyx => &PYX_CU91_INDEX_URL,
},
Self::Cu90 => match source {
TorchSource::PyTorch => &PYTORCH_CU90_INDEX_URL,
TorchSource::Pyx => &PYX_CU90_INDEX_URL,
},
Self::Cu80 => match source {
TorchSource::PyTorch => &PYTORCH_CU80_INDEX_URL,
TorchSource::Pyx => &PYX_CU80_INDEX_URL,
},
Self::Rocm63 => match source {
TorchSource::PyTorch => &PYTORCH_ROCM63_INDEX_URL,
TorchSource::Pyx => &PYX_ROCM63_INDEX_URL,
},
Self::Rocm624 => match source {
TorchSource::PyTorch => &PYTORCH_ROCM624_INDEX_URL,
TorchSource::Pyx => &PYX_ROCM624_INDEX_URL,
},
Self::Rocm62 => match source {
TorchSource::PyTorch => &PYTORCH_ROCM62_INDEX_URL,
TorchSource::Pyx => &PYX_ROCM62_INDEX_URL,
},
Self::Rocm61 => match source {
TorchSource::PyTorch => &PYTORCH_ROCM61_INDEX_URL,
TorchSource::Pyx => &PYX_ROCM61_INDEX_URL,
},
Self::Rocm60 => match source {
TorchSource::PyTorch => &PYTORCH_ROCM60_INDEX_URL,
TorchSource::Pyx => &PYX_ROCM60_INDEX_URL,
},
Self::Rocm57 => match source {
TorchSource::PyTorch => &PYTORCH_ROCM57_INDEX_URL,
TorchSource::Pyx => &PYX_ROCM57_INDEX_URL,
},
Self::Rocm56 => match source {
TorchSource::PyTorch => &PYTORCH_ROCM56_INDEX_URL,
TorchSource::Pyx => &PYX_ROCM56_INDEX_URL,
},
Self::Rocm55 => match source {
TorchSource::PyTorch => &PYTORCH_ROCM55_INDEX_URL,
TorchSource::Pyx => &PYX_ROCM55_INDEX_URL,
},
Self::Rocm542 => match source {
TorchSource::PyTorch => &PYTORCH_ROCM542_INDEX_URL,
TorchSource::Pyx => &PYX_ROCM542_INDEX_URL,
},
Self::Rocm54 => match source {
TorchSource::PyTorch => &PYTORCH_ROCM54_INDEX_URL,
TorchSource::Pyx => &PYX_ROCM54_INDEX_URL,
},
Self::Rocm53 => match source {
TorchSource::PyTorch => &PYTORCH_ROCM53_INDEX_URL,
TorchSource::Pyx => &PYX_ROCM53_INDEX_URL,
},
Self::Rocm52 => match source {
TorchSource::PyTorch => &PYTORCH_ROCM52_INDEX_URL,
TorchSource::Pyx => &PYX_ROCM52_INDEX_URL,
},
Self::Rocm511 => match source {
TorchSource::PyTorch => &PYTORCH_ROCM511_INDEX_URL,
TorchSource::Pyx => &PYX_ROCM511_INDEX_URL,
},
Self::Rocm42 => match source {
TorchSource::PyTorch => &PYTORCH_ROCM42_INDEX_URL,
TorchSource::Pyx => &PYX_ROCM42_INDEX_URL,
},
Self::Rocm41 => match source {
TorchSource::PyTorch => &PYTORCH_ROCM41_INDEX_URL,
TorchSource::Pyx => &PYX_ROCM41_INDEX_URL,
},
Self::Rocm401 => match source {
TorchSource::PyTorch => &PYTORCH_ROCM401_INDEX_URL,
TorchSource::Pyx => &PYX_ROCM401_INDEX_URL,
},
Self::Xpu => match source {
TorchSource::PyTorch => &PYTORCH_XPU_INDEX_URL,
TorchSource::Pyx => &PYX_XPU_INDEX_URL,
},
}
}
@ -508,6 +713,16 @@ impl TorchBackend {
return None;
}
path_segments.next()?
} else if index.host_str() == API_BASE_URL.strip_prefix("https://") {
// E.g., `https://api.pyx.dev/simple/astral-sh/cu124`
let mut path_segments = index.path_segments()?;
if path_segments.next() != Some("simple") {
return None;
}
if path_segments.next() != Some("astral-sh") {
return None;
}
path_segments.next()?
} else {
return None;
};
@ -619,7 +834,6 @@ impl FromStr for TorchBackend {
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"cpu" => Ok(Self::Cpu),
"cu129" => Ok(Self::Cu129),
"cu128" => Ok(Self::Cu128),
"cu126" => Ok(Self::Cu126),
"cu125" => Ok(Self::Cu125),
@ -669,11 +883,10 @@ impl FromStr for TorchBackend {
/// Linux CUDA driver versions and the corresponding CUDA versions.
///
/// See: <https://github.com/pmeier/light-the-torch/blob/33397cbe45d07b51ad8ee76b004571a4c236e37f/light_the_torch/_cb.py#L150-L213>
static LINUX_CUDA_DRIVERS: LazyLock<[(TorchBackend, Version); 25]> = LazyLock::new(|| {
static LINUX_CUDA_DRIVERS: LazyLock<[(TorchBackend, Version); 24]> = LazyLock::new(|| {
[
// Table 2 from
// https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html
(TorchBackend::Cu129, Version::new([525, 60, 13])),
(TorchBackend::Cu128, Version::new([525, 60, 13])),
(TorchBackend::Cu126, Version::new([525, 60, 13])),
(TorchBackend::Cu125, Version::new([525, 60, 13])),
@ -708,11 +921,10 @@ static LINUX_CUDA_DRIVERS: LazyLock<[(TorchBackend, Version); 25]> = LazyLock::n
/// Windows CUDA driver versions and the corresponding CUDA versions.
///
/// See: <https://github.com/pmeier/light-the-torch/blob/33397cbe45d07b51ad8ee76b004571a4c236e37f/light_the_torch/_cb.py#L150-L213>
static WINDOWS_CUDA_VERSIONS: LazyLock<[(TorchBackend, Version); 25]> = LazyLock::new(|| {
static WINDOWS_CUDA_VERSIONS: LazyLock<[(TorchBackend, Version); 24]> = LazyLock::new(|| {
[
// Table 2 from
// https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html
(TorchBackend::Cu129, Version::new([528, 33])),
(TorchBackend::Cu128, Version::new([528, 33])),
(TorchBackend::Cu126, Version::new([528, 33])),
(TorchBackend::Cu125, Version::new([528, 33])),
@ -811,89 +1023,267 @@ static LINUX_AMD_GPU_DRIVERS: LazyLock<[(TorchBackend, AmdGpuArchitecture); 44]>
]
});
static CPU_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_CPU_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cpu").unwrap());
static CU129_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_CU129_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu129").unwrap());
static CU128_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_CU128_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu128").unwrap());
static CU126_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_CU126_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu126").unwrap());
static CU125_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_CU125_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu125").unwrap());
static CU124_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_CU124_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu124").unwrap());
static CU123_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_CU123_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu123").unwrap());
static CU122_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_CU122_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu122").unwrap());
static CU121_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_CU121_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu121").unwrap());
static CU120_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_CU120_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu120").unwrap());
static CU118_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_CU118_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu118").unwrap());
static CU117_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_CU117_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu117").unwrap());
static CU116_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_CU116_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu116").unwrap());
static CU115_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_CU115_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu115").unwrap());
static CU114_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_CU114_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu114").unwrap());
static CU113_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_CU113_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu113").unwrap());
static CU112_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_CU112_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu112").unwrap());
static CU111_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_CU111_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu111").unwrap());
static CU110_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_CU110_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu110").unwrap());
static CU102_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_CU102_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu102").unwrap());
static CU101_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_CU101_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu101").unwrap());
static CU100_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_CU100_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu100").unwrap());
static CU92_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_CU92_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu92").unwrap());
static CU91_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_CU91_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu91").unwrap());
static CU90_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_CU90_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu90").unwrap());
static CU80_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_CU80_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu80").unwrap());
static ROCM63_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_ROCM63_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.3").unwrap());
static ROCM624_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_ROCM624_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.2.4").unwrap());
static ROCM62_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_ROCM62_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.2").unwrap());
static ROCM61_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_ROCM61_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.1").unwrap());
static ROCM60_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_ROCM60_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.0").unwrap());
static ROCM57_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_ROCM57_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.7").unwrap());
static ROCM56_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_ROCM56_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.6").unwrap());
static ROCM55_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_ROCM55_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.5").unwrap());
static ROCM542_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_ROCM542_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.4.2").unwrap());
static ROCM54_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_ROCM54_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.4").unwrap());
static ROCM53_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_ROCM53_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.3").unwrap());
static ROCM52_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_ROCM52_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.2").unwrap());
static ROCM511_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_ROCM511_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm5.1.1").unwrap());
static ROCM42_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_ROCM42_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm4.2").unwrap());
static ROCM41_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_ROCM41_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm4.1").unwrap());
static ROCM401_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_ROCM401_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm4.0.1").unwrap());
static XPU_INDEX_URL: LazyLock<IndexUrl> =
static PYTORCH_XPU_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/xpu").unwrap());
static API_BASE_URL: LazyLock<Cow<'static, str>> = LazyLock::new(|| {
std::env::var(EnvVars::PYX_API_URL)
.map(Cow::Owned)
.unwrap_or(Cow::Borrowed("https://api.pyx.dev"))
});
static PYX_CPU_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cpu")).unwrap()
});
static PYX_CU129_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu129")).unwrap()
});
static PYX_CU128_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu128")).unwrap()
});
static PYX_CU126_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu126")).unwrap()
});
static PYX_CU125_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu125")).unwrap()
});
static PYX_CU124_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu124")).unwrap()
});
static PYX_CU123_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu123")).unwrap()
});
static PYX_CU122_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu122")).unwrap()
});
static PYX_CU121_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu121")).unwrap()
});
static PYX_CU120_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu120")).unwrap()
});
static PYX_CU118_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu118")).unwrap()
});
static PYX_CU117_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu117")).unwrap()
});
static PYX_CU116_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu116")).unwrap()
});
static PYX_CU115_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu115")).unwrap()
});
static PYX_CU114_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu114")).unwrap()
});
static PYX_CU113_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu113")).unwrap()
});
static PYX_CU112_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu112")).unwrap()
});
static PYX_CU111_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu111")).unwrap()
});
static PYX_CU110_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu110")).unwrap()
});
static PYX_CU102_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu102")).unwrap()
});
static PYX_CU101_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu101")).unwrap()
});
static PYX_CU100_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu100")).unwrap()
});
static PYX_CU92_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu92")).unwrap()
});
static PYX_CU91_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu91")).unwrap()
});
static PYX_CU90_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu90")).unwrap()
});
static PYX_CU80_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu80")).unwrap()
});
static PYX_ROCM63_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.3")).unwrap()
});
static PYX_ROCM624_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.2.4")).unwrap()
});
static PYX_ROCM62_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.2")).unwrap()
});
static PYX_ROCM61_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.1")).unwrap()
});
static PYX_ROCM60_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.0")).unwrap()
});
static PYX_ROCM57_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.7")).unwrap()
});
static PYX_ROCM56_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.6")).unwrap()
});
static PYX_ROCM55_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.5")).unwrap()
});
static PYX_ROCM542_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.4.2")).unwrap()
});
static PYX_ROCM54_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.4")).unwrap()
});
static PYX_ROCM53_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.3")).unwrap()
});
static PYX_ROCM52_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.2")).unwrap()
});
static PYX_ROCM511_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm5.1.1")).unwrap()
});
static PYX_ROCM42_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm4.2")).unwrap()
});
static PYX_ROCM41_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm4.1")).unwrap()
});
static PYX_ROCM401_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm4.0.1")).unwrap()
});
static PYX_XPU_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/xpu")).unwrap()
});

View File

@ -45,7 +45,7 @@ use uv_resolver::{
ResolverEnvironment,
};
use uv_static::EnvVars;
use uv_torch::{TorchMode, TorchStrategy};
use uv_torch::{TorchMode, TorchSource, TorchStrategy};
use uv_types::{EmptyInstalledPackages, HashStrategy};
use uv_warnings::{warn_user, warn_user_once};
use uv_workspace::WorkspaceCache;
@ -402,8 +402,16 @@ pub(crate) async fn pip_compile(
// Determine the PyTorch backend.
let torch_backend = torch_backend
.map(|mode| {
let source = if uv_auth::PyxTokenStore::from_settings()
.is_ok_and(|store| store.has_credentials())
{
TorchSource::Pyx
} else {
TorchSource::default()
};
TorchStrategy::from_mode(
mode,
source,
python_platform
.map(TargetTriple::platform)
.as_ref()

View File

@ -35,7 +35,7 @@ use uv_resolver::{
DependencyMode, ExcludeNewer, FlatIndex, OptionsBuilder, PrereleaseMode, PylockToml,
PythonRequirement, ResolutionMode, ResolverEnvironment,
};
use uv_torch::{TorchMode, TorchStrategy};
use uv_torch::{TorchMode, TorchSource, TorchStrategy};
use uv_types::HashStrategy;
use uv_warnings::{warn_user, warn_user_once};
use uv_workspace::WorkspaceCache;
@ -365,8 +365,16 @@ pub(crate) async fn pip_install(
// Determine the PyTorch backend.
let torch_backend = torch_backend
.map(|mode| {
let source = if uv_auth::PyxTokenStore::from_settings()
.is_ok_and(|store| store.has_credentials())
{
TorchSource::Pyx
} else {
TorchSource::default()
};
TorchStrategy::from_mode(
mode,
source,
python_platform
.map(TargetTriple::platform)
.as_ref()

View File

@ -33,7 +33,7 @@ use uv_resolver::{
DependencyMode, ExcludeNewer, FlatIndex, OptionsBuilder, PrereleaseMode, PylockToml,
PythonRequirement, ResolutionMode, ResolverEnvironment,
};
use uv_torch::{TorchMode, TorchStrategy};
use uv_torch::{TorchMode, TorchSource, TorchStrategy};
use uv_types::HashStrategy;
use uv_warnings::{warn_user, warn_user_once};
use uv_workspace::WorkspaceCache;
@ -289,8 +289,16 @@ pub(crate) async fn pip_sync(
// Determine the PyTorch backend.
let torch_backend = torch_backend
.map(|mode| {
let source = if uv_auth::PyxTokenStore::from_settings()
.is_ok_and(|store| store.has_credentials())
{
TorchSource::Pyx
} else {
TorchSource::default()
};
TorchStrategy::from_mode(
mode,
source,
python_platform
.map(TargetTriple::platform)
.as_ref()