Add ROCm 6.4 to `--torch-backend=auto` (#16919)

## Summary

Closes https://github.com/astral-sh/uv/issues/16917.
This commit is contained in:
Charlie Marsh 2025-12-01 20:27:20 -05:00 committed by GitHub
parent 3347e196bb
commit 2cdbf9e547
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 40 additions and 1 deletions

View File

@ -113,6 +113,10 @@ pub enum TorchMode {
Cu90, Cu90,
/// Use the PyTorch index for CUDA 8.0. /// Use the PyTorch index for CUDA 8.0.
Cu80, Cu80,
/// Use the PyTorch index for ROCm 6.4.
#[serde(rename = "rocm6.4")]
#[cfg_attr(feature = "clap", clap(name = "rocm6.4"))]
Rocm64,
/// Use the PyTorch index for ROCm 6.3. /// Use the PyTorch index for ROCm 6.3.
#[serde(rename = "rocm6.3")] #[serde(rename = "rocm6.3")]
#[cfg_attr(feature = "clap", clap(name = "rocm6.3"))] #[cfg_attr(feature = "clap", clap(name = "rocm6.3"))]
@ -272,6 +276,7 @@ impl TorchStrategy {
TorchMode::Cu91 => TorchBackend::Cu91, TorchMode::Cu91 => TorchBackend::Cu91,
TorchMode::Cu90 => TorchBackend::Cu90, TorchMode::Cu90 => TorchBackend::Cu90,
TorchMode::Cu80 => TorchBackend::Cu80, TorchMode::Cu80 => TorchBackend::Cu80,
TorchMode::Rocm64 => TorchBackend::Rocm64,
TorchMode::Rocm63 => TorchBackend::Rocm63, TorchMode::Rocm63 => TorchBackend::Rocm63,
TorchMode::Rocm624 => TorchBackend::Rocm624, TorchMode::Rocm624 => TorchBackend::Rocm624,
TorchMode::Rocm62 => TorchBackend::Rocm62, TorchMode::Rocm62 => TorchBackend::Rocm62,
@ -516,6 +521,7 @@ pub enum TorchBackend {
Cu91, Cu91,
Cu90, Cu90,
Cu80, Cu80,
Rocm64,
Rocm63, Rocm63,
Rocm624, Rocm624,
Rocm62, Rocm62,
@ -647,6 +653,10 @@ impl TorchBackend {
TorchSource::PyTorch => &PYTORCH_CU80_INDEX_URL, TorchSource::PyTorch => &PYTORCH_CU80_INDEX_URL,
TorchSource::Pyx => &PYX_CU80_INDEX_URL, TorchSource::Pyx => &PYX_CU80_INDEX_URL,
}, },
Self::Rocm64 => match source {
TorchSource::PyTorch => &PYTORCH_ROCM64_INDEX_URL,
TorchSource::Pyx => &PYX_ROCM64_INDEX_URL,
},
Self::Rocm63 => match source { Self::Rocm63 => match source {
TorchSource::PyTorch => &PYTORCH_ROCM63_INDEX_URL, TorchSource::PyTorch => &PYTORCH_ROCM63_INDEX_URL,
TorchSource::Pyx => &PYX_ROCM63_INDEX_URL, TorchSource::Pyx => &PYX_ROCM63_INDEX_URL,
@ -774,6 +784,7 @@ impl TorchBackend {
Self::Cu91 => Some(Version::new([9, 1])), Self::Cu91 => Some(Version::new([9, 1])),
Self::Cu90 => Some(Version::new([9, 0])), Self::Cu90 => Some(Version::new([9, 0])),
Self::Cu80 => Some(Version::new([8, 0])), Self::Cu80 => Some(Version::new([8, 0])),
Self::Rocm64 => None,
Self::Rocm63 => None, Self::Rocm63 => None,
Self::Rocm624 => None, Self::Rocm624 => None,
Self::Rocm62 => None, Self::Rocm62 => None,
@ -824,6 +835,7 @@ impl TorchBackend {
Self::Cu91 => None, Self::Cu91 => None,
Self::Cu90 => None, Self::Cu90 => None,
Self::Cu80 => None, Self::Cu80 => None,
Self::Rocm64 => Some(Version::new([6, 4])),
Self::Rocm63 => Some(Version::new([6, 3])), Self::Rocm63 => Some(Version::new([6, 3])),
Self::Rocm624 => Some(Version::new([6, 2, 4])), Self::Rocm624 => Some(Version::new([6, 2, 4])),
Self::Rocm62 => Some(Version::new([6, 2])), Self::Rocm62 => Some(Version::new([6, 2])),
@ -877,6 +889,7 @@ impl FromStr for TorchBackend {
"cu91" => Ok(Self::Cu91), "cu91" => Ok(Self::Cu91),
"cu90" => Ok(Self::Cu90), "cu90" => Ok(Self::Cu90),
"cu80" => Ok(Self::Cu80), "cu80" => Ok(Self::Cu80),
"rocm6.4" => Ok(Self::Rocm64),
"rocm6.3" => Ok(Self::Rocm63), "rocm6.3" => Ok(Self::Rocm63),
"rocm6.2.4" => Ok(Self::Rocm624), "rocm6.2.4" => Ok(Self::Rocm624),
"rocm6.2" => Ok(Self::Rocm62), "rocm6.2" => Ok(Self::Rocm62),
@ -991,9 +1004,21 @@ static WINDOWS_CUDA_VERSIONS: LazyLock<[(TorchBackend, Version); 26]> = LazyLock
/// ///
/// AMD also provides a compatibility matrix: <https://rocm.docs.amd.com/en/latest/compatibility/compatibility-matrix.html>; /// AMD also provides a compatibility matrix: <https://rocm.docs.amd.com/en/latest/compatibility/compatibility-matrix.html>;
/// however, this list includes a broader array of GPUs than those in the matrix. /// however, this list includes a broader array of GPUs than those in the matrix.
static LINUX_AMD_GPU_DRIVERS: LazyLock<[(TorchBackend, AmdGpuArchitecture); 44]> = static LINUX_AMD_GPU_DRIVERS: LazyLock<[(TorchBackend, AmdGpuArchitecture); 55]> =
LazyLock::new(|| { LazyLock::new(|| {
[ [
// ROCm 6.4
(TorchBackend::Rocm64, AmdGpuArchitecture::Gfx900),
(TorchBackend::Rocm64, AmdGpuArchitecture::Gfx906),
(TorchBackend::Rocm64, AmdGpuArchitecture::Gfx908),
(TorchBackend::Rocm64, AmdGpuArchitecture::Gfx90a),
(TorchBackend::Rocm64, AmdGpuArchitecture::Gfx942),
(TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1030),
(TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1100),
(TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1101),
(TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1102),
(TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1200),
(TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1201),
// ROCm 6.3 // ROCm 6.3
(TorchBackend::Rocm63, AmdGpuArchitecture::Gfx900), (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx900),
(TorchBackend::Rocm63, AmdGpuArchitecture::Gfx906), (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx906),
@ -1100,6 +1125,8 @@ static PYTORCH_CU90_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu90").unwrap()); LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu90").unwrap());
static PYTORCH_CU80_INDEX_URL: LazyLock<IndexUrl> = static PYTORCH_CU80_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu80").unwrap()); LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu80").unwrap());
static PYTORCH_ROCM64_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.4").unwrap());
static PYTORCH_ROCM63_INDEX_URL: LazyLock<IndexUrl> = static PYTORCH_ROCM63_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.3").unwrap()); LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.3").unwrap());
static PYTORCH_ROCM624_INDEX_URL: LazyLock<IndexUrl> = static PYTORCH_ROCM624_INDEX_URL: LazyLock<IndexUrl> =
@ -1248,6 +1275,10 @@ static PYX_CU80_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*PYX_API_BASE_URL; let api_base_url = &*PYX_API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu80")).unwrap() IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu80")).unwrap()
}); });
static PYX_ROCM64_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*PYX_API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.4")).unwrap()
});
static PYX_ROCM63_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| { static PYX_ROCM63_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*PYX_API_BASE_URL; let api_base_url = &*PYX_API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.3")).unwrap() IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.3")).unwrap()

View File

@ -4301,6 +4301,7 @@ by <code>--python-version</code>.</p>
<li><code>cu91</code>: Use the PyTorch index for CUDA 9.1</li> <li><code>cu91</code>: Use the PyTorch index for CUDA 9.1</li>
<li><code>cu90</code>: Use the PyTorch index for CUDA 9.0</li> <li><code>cu90</code>: Use the PyTorch index for CUDA 9.0</li>
<li><code>cu80</code>: Use the PyTorch index for CUDA 8.0</li> <li><code>cu80</code>: Use the PyTorch index for CUDA 8.0</li>
<li><code>rocm6.4</code>: Use the PyTorch index for ROCm 6.4</li>
<li><code>rocm6.3</code>: Use the PyTorch index for ROCm 6.3</li> <li><code>rocm6.3</code>: Use the PyTorch index for ROCm 6.3</li>
<li><code>rocm6.2.4</code>: Use the PyTorch index for ROCm 6.2.4</li> <li><code>rocm6.2.4</code>: Use the PyTorch index for ROCm 6.2.4</li>
<li><code>rocm6.2</code>: Use the PyTorch index for ROCm 6.2</li> <li><code>rocm6.2</code>: Use the PyTorch index for ROCm 6.2</li>
@ -4586,6 +4587,7 @@ be used with caution, as it can modify the system Python installation.</p>
<li><code>cu91</code>: Use the PyTorch index for CUDA 9.1</li> <li><code>cu91</code>: Use the PyTorch index for CUDA 9.1</li>
<li><code>cu90</code>: Use the PyTorch index for CUDA 9.0</li> <li><code>cu90</code>: Use the PyTorch index for CUDA 9.0</li>
<li><code>cu80</code>: Use the PyTorch index for CUDA 8.0</li> <li><code>cu80</code>: Use the PyTorch index for CUDA 8.0</li>
<li><code>rocm6.4</code>: Use the PyTorch index for ROCm 6.4</li>
<li><code>rocm6.3</code>: Use the PyTorch index for ROCm 6.3</li> <li><code>rocm6.3</code>: Use the PyTorch index for ROCm 6.3</li>
<li><code>rocm6.2.4</code>: Use the PyTorch index for ROCm 6.2.4</li> <li><code>rocm6.2.4</code>: Use the PyTorch index for ROCm 6.2.4</li>
<li><code>rocm6.2</code>: Use the PyTorch index for ROCm 6.2</li> <li><code>rocm6.2</code>: Use the PyTorch index for ROCm 6.2</li>
@ -4899,6 +4901,7 @@ should be used with caution, as it can modify the system Python installation.</p
<li><code>cu91</code>: Use the PyTorch index for CUDA 9.1</li> <li><code>cu91</code>: Use the PyTorch index for CUDA 9.1</li>
<li><code>cu90</code>: Use the PyTorch index for CUDA 9.0</li> <li><code>cu90</code>: Use the PyTorch index for CUDA 9.0</li>
<li><code>cu80</code>: Use the PyTorch index for CUDA 8.0</li> <li><code>cu80</code>: Use the PyTorch index for CUDA 8.0</li>
<li><code>rocm6.4</code>: Use the PyTorch index for ROCm 6.4</li>
<li><code>rocm6.3</code>: Use the PyTorch index for ROCm 6.3</li> <li><code>rocm6.3</code>: Use the PyTorch index for ROCm 6.3</li>
<li><code>rocm6.2.4</code>: Use the PyTorch index for ROCm 6.2.4</li> <li><code>rocm6.2.4</code>: Use the PyTorch index for ROCm 6.2.4</li>
<li><code>rocm6.2</code>: Use the PyTorch index for ROCm 6.2</li> <li><code>rocm6.2</code>: Use the PyTorch index for ROCm 6.2</li>

5
uv.schema.json generated
View File

@ -2647,6 +2647,11 @@
"type": "string", "type": "string",
"const": "cu80" "const": "cu80"
}, },
{
"description": "Use the PyTorch index for ROCm 6.4.",
"type": "string",
"const": "rocm6.4"
},
{ {
"description": "Use the PyTorch index for ROCm 6.3.", "description": "Use the PyTorch index for ROCm 6.3.",
"type": "string", "type": "string",