From 2cdbf9e547ebdbcede70141cc26aa59f1a724ecc Mon Sep 17 00:00:00 2001
From: Charlie Marsh
Date: Mon, 1 Dec 2025 20:27:20 -0500
Subject: [PATCH] Add ROCm 6.4 to `--torch-backend=auto` (#16919)
## Summary
Closes https://github.com/astral-sh/uv/issues/16917.
---
crates/uv-torch/src/backend.rs | 33 ++++++++++++++++++++++++++++++++-
docs/reference/cli.md | 3 +++
uv.schema.json | 5 +++++
3 files changed, 40 insertions(+), 1 deletion(-)
diff --git a/crates/uv-torch/src/backend.rs b/crates/uv-torch/src/backend.rs
index 773fd619b..7437744d5 100644
--- a/crates/uv-torch/src/backend.rs
+++ b/crates/uv-torch/src/backend.rs
@@ -113,6 +113,10 @@ pub enum TorchMode {
Cu90,
/// Use the PyTorch index for CUDA 8.0.
Cu80,
+ /// Use the PyTorch index for ROCm 6.4.
+ #[serde(rename = "rocm6.4")]
+ #[cfg_attr(feature = "clap", clap(name = "rocm6.4"))]
+ Rocm64,
/// Use the PyTorch index for ROCm 6.3.
#[serde(rename = "rocm6.3")]
#[cfg_attr(feature = "clap", clap(name = "rocm6.3"))]
@@ -272,6 +276,7 @@ impl TorchStrategy {
TorchMode::Cu91 => TorchBackend::Cu91,
TorchMode::Cu90 => TorchBackend::Cu90,
TorchMode::Cu80 => TorchBackend::Cu80,
+ TorchMode::Rocm64 => TorchBackend::Rocm64,
TorchMode::Rocm63 => TorchBackend::Rocm63,
TorchMode::Rocm624 => TorchBackend::Rocm624,
TorchMode::Rocm62 => TorchBackend::Rocm62,
@@ -516,6 +521,7 @@ pub enum TorchBackend {
Cu91,
Cu90,
Cu80,
+ Rocm64,
Rocm63,
Rocm624,
Rocm62,
@@ -647,6 +653,10 @@ impl TorchBackend {
TorchSource::PyTorch => &PYTORCH_CU80_INDEX_URL,
TorchSource::Pyx => &PYX_CU80_INDEX_URL,
},
+ Self::Rocm64 => match source {
+ TorchSource::PyTorch => &PYTORCH_ROCM64_INDEX_URL,
+ TorchSource::Pyx => &PYX_ROCM64_INDEX_URL,
+ },
Self::Rocm63 => match source {
TorchSource::PyTorch => &PYTORCH_ROCM63_INDEX_URL,
TorchSource::Pyx => &PYX_ROCM63_INDEX_URL,
@@ -774,6 +784,7 @@ impl TorchBackend {
Self::Cu91 => Some(Version::new([9, 1])),
Self::Cu90 => Some(Version::new([9, 0])),
Self::Cu80 => Some(Version::new([8, 0])),
+ Self::Rocm64 => None,
Self::Rocm63 => None,
Self::Rocm624 => None,
Self::Rocm62 => None,
@@ -824,6 +835,7 @@ impl TorchBackend {
Self::Cu91 => None,
Self::Cu90 => None,
Self::Cu80 => None,
+ Self::Rocm64 => Some(Version::new([6, 4])),
Self::Rocm63 => Some(Version::new([6, 3])),
Self::Rocm624 => Some(Version::new([6, 2, 4])),
Self::Rocm62 => Some(Version::new([6, 2])),
@@ -877,6 +889,7 @@ impl FromStr for TorchBackend {
"cu91" => Ok(Self::Cu91),
"cu90" => Ok(Self::Cu90),
"cu80" => Ok(Self::Cu80),
+ "rocm6.4" => Ok(Self::Rocm64),
"rocm6.3" => Ok(Self::Rocm63),
"rocm6.2.4" => Ok(Self::Rocm624),
"rocm6.2" => Ok(Self::Rocm62),
@@ -991,9 +1004,21 @@ static WINDOWS_CUDA_VERSIONS: LazyLock<[(TorchBackend, Version); 26]> = LazyLock
///
/// AMD also provides a compatibility matrix: ;
/// however, this list includes a broader array of GPUs than those in the matrix.
-static LINUX_AMD_GPU_DRIVERS: LazyLock<[(TorchBackend, AmdGpuArchitecture); 44]> =
+static LINUX_AMD_GPU_DRIVERS: LazyLock<[(TorchBackend, AmdGpuArchitecture); 55]> =
LazyLock::new(|| {
[
+ // ROCm 6.4
+ (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx900),
+ (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx906),
+ (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx908),
+ (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx90a),
+ (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx942),
+ (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1030),
+ (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1100),
+ (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1101),
+ (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1102),
+ (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1200),
+ (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1201),
// ROCm 6.3
(TorchBackend::Rocm63, AmdGpuArchitecture::Gfx900),
(TorchBackend::Rocm63, AmdGpuArchitecture::Gfx906),
@@ -1100,6 +1125,8 @@ static PYTORCH_CU90_INDEX_URL: LazyLock =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu90").unwrap());
static PYTORCH_CU80_INDEX_URL: LazyLock =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu80").unwrap());
+static PYTORCH_ROCM64_INDEX_URL: LazyLock =
+ LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.4").unwrap());
static PYTORCH_ROCM63_INDEX_URL: LazyLock =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.3").unwrap());
static PYTORCH_ROCM624_INDEX_URL: LazyLock =
@@ -1248,6 +1275,10 @@ static PYX_CU80_INDEX_URL: LazyLock = LazyLock::new(|| {
let api_base_url = &*PYX_API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu80")).unwrap()
});
+static PYX_ROCM64_INDEX_URL: LazyLock = LazyLock::new(|| {
+ let api_base_url = &*PYX_API_BASE_URL;
+ IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.4")).unwrap()
+});
static PYX_ROCM63_INDEX_URL: LazyLock = LazyLock::new(|| {
let api_base_url = &*PYX_API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.3")).unwrap()
diff --git a/docs/reference/cli.md b/docs/reference/cli.md
index 330a9ec94..e32df391e 100644
--- a/docs/reference/cli.md
+++ b/docs/reference/cli.md
@@ -4301,6 +4301,7 @@ by --python-version.
cu91: Use the PyTorch index for CUDA 9.1
cu90: Use the PyTorch index for CUDA 9.0
cu80: Use the PyTorch index for CUDA 8.0
+rocm6.4: Use the PyTorch index for ROCm 6.4
rocm6.3: Use the PyTorch index for ROCm 6.3
rocm6.2.4: Use the PyTorch index for ROCm 6.2.4
rocm6.2: Use the PyTorch index for ROCm 6.2
@@ -4586,6 +4587,7 @@ be used with caution, as it can modify the system Python installation.
cu91: Use the PyTorch index for CUDA 9.1
cu90: Use the PyTorch index for CUDA 9.0
cu80: Use the PyTorch index for CUDA 8.0
+rocm6.4: Use the PyTorch index for ROCm 6.4
rocm6.3: Use the PyTorch index for ROCm 6.3
rocm6.2.4: Use the PyTorch index for ROCm 6.2.4
rocm6.2: Use the PyTorch index for ROCm 6.2
@@ -4899,6 +4901,7 @@ should be used with caution, as it can modify the system Python installation.cu91: Use the PyTorch index for CUDA 9.1
cu90: Use the PyTorch index for CUDA 9.0
cu80: Use the PyTorch index for CUDA 8.0
+rocm6.4: Use the PyTorch index for ROCm 6.4
rocm6.3: Use the PyTorch index for ROCm 6.3
rocm6.2.4: Use the PyTorch index for ROCm 6.2.4
rocm6.2: Use the PyTorch index for ROCm 6.2
diff --git a/uv.schema.json b/uv.schema.json
index 7cd16427a..6fa38f71e 100644
--- a/uv.schema.json
+++ b/uv.schema.json
@@ -2647,6 +2647,11 @@
"type": "string",
"const": "cu80"
},
+ {
+ "description": "Use the PyTorch index for ROCm 6.4.",
+ "type": "string",
+ "const": "rocm6.4"
+ },
{
"description": "Use the PyTorch index for ROCm 6.3.",
"type": "string",