mirror of https://github.com/astral-sh/uv
Add ROCm 6.4 to `--torch-backend=auto` (#16919)
## Summary Closes https://github.com/astral-sh/uv/issues/16917.
This commit is contained in:
parent
3347e196bb
commit
2cdbf9e547
|
|
@ -113,6 +113,10 @@ pub enum TorchMode {
|
|||
Cu90,
|
||||
/// Use the PyTorch index for CUDA 8.0.
|
||||
Cu80,
|
||||
/// Use the PyTorch index for ROCm 6.4.
|
||||
#[serde(rename = "rocm6.4")]
|
||||
#[cfg_attr(feature = "clap", clap(name = "rocm6.4"))]
|
||||
Rocm64,
|
||||
/// Use the PyTorch index for ROCm 6.3.
|
||||
#[serde(rename = "rocm6.3")]
|
||||
#[cfg_attr(feature = "clap", clap(name = "rocm6.3"))]
|
||||
|
|
@ -272,6 +276,7 @@ impl TorchStrategy {
|
|||
TorchMode::Cu91 => TorchBackend::Cu91,
|
||||
TorchMode::Cu90 => TorchBackend::Cu90,
|
||||
TorchMode::Cu80 => TorchBackend::Cu80,
|
||||
TorchMode::Rocm64 => TorchBackend::Rocm64,
|
||||
TorchMode::Rocm63 => TorchBackend::Rocm63,
|
||||
TorchMode::Rocm624 => TorchBackend::Rocm624,
|
||||
TorchMode::Rocm62 => TorchBackend::Rocm62,
|
||||
|
|
@ -516,6 +521,7 @@ pub enum TorchBackend {
|
|||
Cu91,
|
||||
Cu90,
|
||||
Cu80,
|
||||
Rocm64,
|
||||
Rocm63,
|
||||
Rocm624,
|
||||
Rocm62,
|
||||
|
|
@ -647,6 +653,10 @@ impl TorchBackend {
|
|||
TorchSource::PyTorch => &PYTORCH_CU80_INDEX_URL,
|
||||
TorchSource::Pyx => &PYX_CU80_INDEX_URL,
|
||||
},
|
||||
Self::Rocm64 => match source {
|
||||
TorchSource::PyTorch => &PYTORCH_ROCM64_INDEX_URL,
|
||||
TorchSource::Pyx => &PYX_ROCM64_INDEX_URL,
|
||||
},
|
||||
Self::Rocm63 => match source {
|
||||
TorchSource::PyTorch => &PYTORCH_ROCM63_INDEX_URL,
|
||||
TorchSource::Pyx => &PYX_ROCM63_INDEX_URL,
|
||||
|
|
@ -774,6 +784,7 @@ impl TorchBackend {
|
|||
Self::Cu91 => Some(Version::new([9, 1])),
|
||||
Self::Cu90 => Some(Version::new([9, 0])),
|
||||
Self::Cu80 => Some(Version::new([8, 0])),
|
||||
Self::Rocm64 => None,
|
||||
Self::Rocm63 => None,
|
||||
Self::Rocm624 => None,
|
||||
Self::Rocm62 => None,
|
||||
|
|
@ -824,6 +835,7 @@ impl TorchBackend {
|
|||
Self::Cu91 => None,
|
||||
Self::Cu90 => None,
|
||||
Self::Cu80 => None,
|
||||
Self::Rocm64 => Some(Version::new([6, 4])),
|
||||
Self::Rocm63 => Some(Version::new([6, 3])),
|
||||
Self::Rocm624 => Some(Version::new([6, 2, 4])),
|
||||
Self::Rocm62 => Some(Version::new([6, 2])),
|
||||
|
|
@ -877,6 +889,7 @@ impl FromStr for TorchBackend {
|
|||
"cu91" => Ok(Self::Cu91),
|
||||
"cu90" => Ok(Self::Cu90),
|
||||
"cu80" => Ok(Self::Cu80),
|
||||
"rocm6.4" => Ok(Self::Rocm64),
|
||||
"rocm6.3" => Ok(Self::Rocm63),
|
||||
"rocm6.2.4" => Ok(Self::Rocm624),
|
||||
"rocm6.2" => Ok(Self::Rocm62),
|
||||
|
|
@ -991,9 +1004,21 @@ static WINDOWS_CUDA_VERSIONS: LazyLock<[(TorchBackend, Version); 26]> = LazyLock
|
|||
///
|
||||
/// AMD also provides a compatibility matrix: <https://rocm.docs.amd.com/en/latest/compatibility/compatibility-matrix.html>;
|
||||
/// however, this list includes a broader array of GPUs than those in the matrix.
|
||||
static LINUX_AMD_GPU_DRIVERS: LazyLock<[(TorchBackend, AmdGpuArchitecture); 44]> =
|
||||
static LINUX_AMD_GPU_DRIVERS: LazyLock<[(TorchBackend, AmdGpuArchitecture); 55]> =
|
||||
LazyLock::new(|| {
|
||||
[
|
||||
// ROCm 6.4
|
||||
(TorchBackend::Rocm64, AmdGpuArchitecture::Gfx900),
|
||||
(TorchBackend::Rocm64, AmdGpuArchitecture::Gfx906),
|
||||
(TorchBackend::Rocm64, AmdGpuArchitecture::Gfx908),
|
||||
(TorchBackend::Rocm64, AmdGpuArchitecture::Gfx90a),
|
||||
(TorchBackend::Rocm64, AmdGpuArchitecture::Gfx942),
|
||||
(TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1030),
|
||||
(TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1100),
|
||||
(TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1101),
|
||||
(TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1102),
|
||||
(TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1200),
|
||||
(TorchBackend::Rocm64, AmdGpuArchitecture::Gfx1201),
|
||||
// ROCm 6.3
|
||||
(TorchBackend::Rocm63, AmdGpuArchitecture::Gfx900),
|
||||
(TorchBackend::Rocm63, AmdGpuArchitecture::Gfx906),
|
||||
|
|
@ -1100,6 +1125,8 @@ static PYTORCH_CU90_INDEX_URL: LazyLock<IndexUrl> =
|
|||
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu90").unwrap());
|
||||
static PYTORCH_CU80_INDEX_URL: LazyLock<IndexUrl> =
|
||||
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu80").unwrap());
|
||||
static PYTORCH_ROCM64_INDEX_URL: LazyLock<IndexUrl> =
|
||||
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.4").unwrap());
|
||||
static PYTORCH_ROCM63_INDEX_URL: LazyLock<IndexUrl> =
|
||||
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.3").unwrap());
|
||||
static PYTORCH_ROCM624_INDEX_URL: LazyLock<IndexUrl> =
|
||||
|
|
@ -1248,6 +1275,10 @@ static PYX_CU80_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
|
|||
let api_base_url = &*PYX_API_BASE_URL;
|
||||
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu80")).unwrap()
|
||||
});
|
||||
static PYX_ROCM64_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
|
||||
let api_base_url = &*PYX_API_BASE_URL;
|
||||
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.4")).unwrap()
|
||||
});
|
||||
static PYX_ROCM63_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
|
||||
let api_base_url = &*PYX_API_BASE_URL;
|
||||
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.3")).unwrap()
|
||||
|
|
|
|||
|
|
@ -4301,6 +4301,7 @@ by <code>--python-version</code>.</p>
|
|||
<li><code>cu91</code>: Use the PyTorch index for CUDA 9.1</li>
|
||||
<li><code>cu90</code>: Use the PyTorch index for CUDA 9.0</li>
|
||||
<li><code>cu80</code>: Use the PyTorch index for CUDA 8.0</li>
|
||||
<li><code>rocm6.4</code>: Use the PyTorch index for ROCm 6.4</li>
|
||||
<li><code>rocm6.3</code>: Use the PyTorch index for ROCm 6.3</li>
|
||||
<li><code>rocm6.2.4</code>: Use the PyTorch index for ROCm 6.2.4</li>
|
||||
<li><code>rocm6.2</code>: Use the PyTorch index for ROCm 6.2</li>
|
||||
|
|
@ -4586,6 +4587,7 @@ be used with caution, as it can modify the system Python installation.</p>
|
|||
<li><code>cu91</code>: Use the PyTorch index for CUDA 9.1</li>
|
||||
<li><code>cu90</code>: Use the PyTorch index for CUDA 9.0</li>
|
||||
<li><code>cu80</code>: Use the PyTorch index for CUDA 8.0</li>
|
||||
<li><code>rocm6.4</code>: Use the PyTorch index for ROCm 6.4</li>
|
||||
<li><code>rocm6.3</code>: Use the PyTorch index for ROCm 6.3</li>
|
||||
<li><code>rocm6.2.4</code>: Use the PyTorch index for ROCm 6.2.4</li>
|
||||
<li><code>rocm6.2</code>: Use the PyTorch index for ROCm 6.2</li>
|
||||
|
|
@ -4899,6 +4901,7 @@ should be used with caution, as it can modify the system Python installation.</p
|
|||
<li><code>cu91</code>: Use the PyTorch index for CUDA 9.1</li>
|
||||
<li><code>cu90</code>: Use the PyTorch index for CUDA 9.0</li>
|
||||
<li><code>cu80</code>: Use the PyTorch index for CUDA 8.0</li>
|
||||
<li><code>rocm6.4</code>: Use the PyTorch index for ROCm 6.4</li>
|
||||
<li><code>rocm6.3</code>: Use the PyTorch index for ROCm 6.3</li>
|
||||
<li><code>rocm6.2.4</code>: Use the PyTorch index for ROCm 6.2.4</li>
|
||||
<li><code>rocm6.2</code>: Use the PyTorch index for ROCm 6.2</li>
|
||||
|
|
|
|||
|
|
@ -2647,6 +2647,11 @@
|
|||
"type": "string",
|
||||
"const": "cu80"
|
||||
},
|
||||
{
|
||||
"description": "Use the PyTorch index for ROCm 6.4.",
|
||||
"type": "string",
|
||||
"const": "rocm6.4"
|
||||
},
|
||||
{
|
||||
"description": "Use the PyTorch index for ROCm 6.3.",
|
||||
"type": "string",
|
||||
|
|
|
|||
Loading…
Reference in New Issue