diff --git a/crates/uv-torch/src/accelerator.rs b/crates/uv-torch/src/accelerator.rs
index 3165bd4c5..696adc9a1 100644
--- a/crates/uv-torch/src/accelerator.rs
+++ b/crates/uv-torch/src/accelerator.rs
@@ -1,3 +1,4 @@
+use std::path::Path;
use std::str::FromStr;
use tracing::debug;
@@ -13,6 +14,8 @@ pub enum AcceleratorError {
Version(#[from] uv_pep440::VersionParseError),
#[error(transparent)]
Utf8(#[from] std::string::FromUtf8Error),
+ #[error(transparent)]
+ ParseInt(#[from] std::num::ParseIntError),
#[error("Unknown AMD GPU architecture: {0}")]
UnknownAmdGpuArchitecture(String),
}
@@ -30,6 +33,10 @@ pub enum Accelerator {
Amd {
gpu_architecture: AmdGpuArchitecture,
},
+ /// The Intel GPU (XPU).
+ ///
+ /// Currently, Intel GPUs do not depend on a driver or toolkit version at this level.
+ Xpu,
}
impl std::fmt::Display for Accelerator {
@@ -37,21 +44,28 @@ impl std::fmt::Display for Accelerator {
match self {
Self::Cuda { driver_version } => write!(f, "CUDA {driver_version}"),
Self::Amd { gpu_architecture } => write!(f, "AMD {gpu_architecture}"),
+ Self::Xpu => write!(f, "Intel GPU (XPU)"),
}
}
}
impl Accelerator {
- /// Detect the CUDA driver version from the system.
+ /// Detect the GPU driver and/or architecture version from the system.
///
/// Query, in order:
/// 1. The `UV_CUDA_DRIVER_VERSION` environment variable.
/// 2. The `UV_AMD_GPU_ARCHITECTURE` environment variable.
- /// 2. `/sys/module/nvidia/version`, which contains the driver version (e.g., `550.144.03`).
- /// 3. `/proc/driver/nvidia/version`, which contains the driver version among other information.
- /// 4. `nvidia-smi --query-gpu=driver_version --format=csv,noheader`.
- /// 5. `rocm_agent_enumerator`, which lists the AMD GPU architectures.
+ /// 3. `/sys/module/nvidia/version`, which contains the driver version (e.g., `550.144.03`).
+ /// 4. `/proc/driver/nvidia/version`, which contains the driver version among other information.
+ /// 5. `nvidia-smi --query-gpu=driver_version --format=csv,noheader`.
+ /// 6. `rocm_agent_enumerator`, which lists the AMD GPU architectures.
+ /// 7. `/sys/bus/pci/devices`, filtering for the Intel GPU via PCI.
pub fn detect() -> Result