From 2cdbf9e547ebdbcede70141cc26aa59f1a724ecc Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Mon, 1 Dec 2025 20:27:20 -0500 Subject: [PATCH] Add ROCm 6.4 to `--torch-backend=auto` (#16919) ## Summary Closes https://github.com/astral-sh/uv/issues/16917. --- crates/uv-torch/src/backend.rs | 33 ++++++++++++++++++++++++++++++++- docs/reference/cli.md | 3 +++ uv.schema.json | 5 +++++ 3 files changed, 40 insertions(+), 1 deletion(-) diff --git a/crates/uv-torch/src/backend.rs b/crates/uv-torch/src/backend.rs index 773fd619b..7437744d5 100644 --- a/crates/uv-torch/src/backend.rs +++ b/crates/uv-torch/src/backend.rs @@ -113,6 +113,10 @@ pub enum TorchMode { Cu90, /// Use the PyTorch index for CUDA 8.0. Cu80, + /// Use the PyTorch index for ROCm 6.4. + #[serde(rename = "rocm6.4")] + #[cfg_attr(feature = "clap", clap(name = "rocm6.4"))] + Rocm64, /// Use the PyTorch index for ROCm 6.3. #[serde(rename = "rocm6.3")] #[cfg_attr(feature = "clap", clap(name = "rocm6.3"))] @@ -272,6 +276,7 @@ impl TorchStrategy { TorchMode::Cu91 => TorchBackend::Cu91, TorchMode::Cu90 => TorchBackend::Cu90, TorchMode::Cu80 => TorchBackend::Cu80, + TorchMode::Rocm64 => TorchBackend::Rocm64, TorchMode::Rocm63 => TorchBackend::Rocm63, TorchMode::Rocm624 => TorchBackend::Rocm624, TorchMode::Rocm62 => TorchBackend::Rocm62, @@ -516,6 +521,7 @@ pub enum TorchBackend { Cu91, Cu90, Cu80, + Rocm64, Rocm63, Rocm624, Rocm62, @@ -647,6 +653,10 @@ impl TorchBackend { TorchSource::PyTorch => &PYTORCH_CU80_INDEX_URL, TorchSource::Pyx => &PYX_CU80_INDEX_URL, }, + Self::Rocm64 => match source { + TorchSource::PyTorch => &PYTORCH_ROCM64_INDEX_URL, + TorchSource::Pyx => &PYX_ROCM64_INDEX_URL, + }, Self::Rocm63 => match source { TorchSource::PyTorch => &PYTORCH_ROCM63_INDEX_URL, TorchSource::Pyx => &PYX_ROCM63_INDEX_URL, @@ -774,6 +784,7 @@ impl TorchBackend { Self::Cu91 => Some(Version::new([9, 1])), Self::Cu90 => Some(Version::new([9, 0])), Self::Cu80 => Some(Version::new([8, 0])), + Self::Rocm64 => None, Self::Rocm63 => None, Self::Rocm624 => None, Self::Rocm62 => None, @@ -824,6 +835,7 @@ impl TorchBackend { Self::Cu91 => None, Self::Cu90 => None, Self::Cu80 => None, + Self::Rocm64 => Some(Version::new([6, 4])), Self::Rocm63 => Some(Version::new([6, 3])), Self::Rocm624 => Some(Version::new([6, 2, 4])), Self::Rocm62 => Some(Version::new([6, 2])), @@ -877,6 +889,7 @@ impl FromStr for TorchBackend { "cu91" => Ok(Self::Cu91), "cu90" => Ok(Self::Cu90), "cu80" => Ok(Self::Cu80), + "rocm6.4" => Ok(Self::Rocm64), "rocm6.3" => Ok(Self::Rocm63), "rocm6.2.4" => Ok(Self::Rocm624), "rocm6.2" => Ok(Self::Rocm62), @@ -991,9 +1004,21 @@ static WINDOWS_CUDA_VERSIONS: LazyLock<[(TorchBackend, Version); 26]> = LazyLock /// /// AMD also provides a compatibility matrix: ; /// however, this list includes a broader array of GPUs than those in the matrix. -static LINUX_AMD_GPU_DRIVERS: LazyLock<[(TorchBackend, AmdGpuArchitecture); 44]> = +static LINUX_AMD_GPU_DRIVERS: LazyLock<[(TorchBackend, AmdGpuArchitecture); 55]> = LazyLock::new(|| { [ + // ROCm 6.4 + (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx900), + (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx906), + (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx908), + (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx90a), + (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx942), + (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1030), + (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1100), + (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1101), + (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1102), + (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1200), + (TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1201), // ROCm 6.3 (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx900), (TorchBackend::Rocm63, AmdGpuArchitecture::Gfx906), @@ -1100,6 +1125,8 @@ static PYTORCH_CU90_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu90").unwrap()); static PYTORCH_CU80_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu80").unwrap()); +static PYTORCH_ROCM64_INDEX_URL: LazyLock = + LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.4").unwrap()); static PYTORCH_ROCM63_INDEX_URL: LazyLock = LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.3").unwrap()); static PYTORCH_ROCM624_INDEX_URL: LazyLock = @@ -1248,6 +1275,10 @@ static PYX_CU80_INDEX_URL: LazyLock = LazyLock::new(|| { let api_base_url = &*PYX_API_BASE_URL; IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu80")).unwrap() }); +static PYX_ROCM64_INDEX_URL: LazyLock = LazyLock::new(|| { + let api_base_url = &*PYX_API_BASE_URL; + IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.4")).unwrap() +}); static PYX_ROCM63_INDEX_URL: LazyLock = LazyLock::new(|| { let api_base_url = &*PYX_API_BASE_URL; IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.3")).unwrap() diff --git a/docs/reference/cli.md b/docs/reference/cli.md index 330a9ec94..e32df391e 100644 --- a/docs/reference/cli.md +++ b/docs/reference/cli.md @@ -4301,6 +4301,7 @@ by --python-version.

  • cu91: Use the PyTorch index for CUDA 9.1
  • cu90: Use the PyTorch index for CUDA 9.0
  • cu80: Use the PyTorch index for CUDA 8.0
  • +
  • rocm6.4: Use the PyTorch index for ROCm 6.4
  • rocm6.3: Use the PyTorch index for ROCm 6.3
  • rocm6.2.4: Use the PyTorch index for ROCm 6.2.4
  • rocm6.2: Use the PyTorch index for ROCm 6.2
  • @@ -4586,6 +4587,7 @@ be used with caution, as it can modify the system Python installation.

  • cu91: Use the PyTorch index for CUDA 9.1
  • cu90: Use the PyTorch index for CUDA 9.0
  • cu80: Use the PyTorch index for CUDA 8.0
  • +
  • rocm6.4: Use the PyTorch index for ROCm 6.4
  • rocm6.3: Use the PyTorch index for ROCm 6.3
  • rocm6.2.4: Use the PyTorch index for ROCm 6.2.4
  • rocm6.2: Use the PyTorch index for ROCm 6.2
  • @@ -4899,6 +4901,7 @@ should be used with caution, as it can modify the system Python installation.

    cu91: Use the PyTorch index for CUDA 9.1
  • cu90: Use the PyTorch index for CUDA 9.0
  • cu80: Use the PyTorch index for CUDA 8.0
  • +
  • rocm6.4: Use the PyTorch index for ROCm 6.4
  • rocm6.3: Use the PyTorch index for ROCm 6.3
  • rocm6.2.4: Use the PyTorch index for ROCm 6.2.4
  • rocm6.2: Use the PyTorch index for ROCm 6.2
  • diff --git a/uv.schema.json b/uv.schema.json index 7cd16427a..6fa38f71e 100644 --- a/uv.schema.json +++ b/uv.schema.json @@ -2647,6 +2647,11 @@ "type": "string", "const": "cu80" }, + { + "description": "Use the PyTorch index for ROCm 6.4.", + "type": "string", + "const": "rocm6.4" + }, { "description": "Use the PyTorch index for ROCm 6.3.", "type": "string",