Directly ask the NVIDIA kernel driver what version it is

This is the same API called by `nvidia-smi`, but calling it ourselves is both
much faster and more robust to differences between installations (e.g.,
`nvidia-smi` not being installed).
This commit is contained in:
Geoffrey Thomas 2025-07-19 06:45:36 +00:00
parent a8bb7be52b
commit 96d4d5b651
3 changed files with 65 additions and 3 deletions

1
Cargo.lock generated
View File

@ -5942,6 +5942,7 @@ dependencies = [
"clap",
"either",
"fs-err 3.1.1",
"nix 0.30.1",
"schemars",
"serde",
"thiserror 2.0.12",

View File

@ -19,6 +19,7 @@ uv-static = { workspace = true }
clap = { workspace = true, optional = true }
either = { workspace = true }
fs-err = { workspace = true }
nix = { workspace = true, features = ["ioctl"] }
schemars = { workspace = true, optional = true }
serde = { workspace = true }
thiserror = { workspace = true }

View File

@ -57,9 +57,10 @@ impl Accelerator {
/// 2. The `UV_AMD_GPU_ARCHITECTURE` environment variable.
/// 3. `/sys/module/nvidia/version`, which contains the driver version (e.g., `550.144.03`).
/// 4. `/proc/driver/nvidia/version`, which contains the driver version among other information.
/// 5. `nvidia-smi --query-gpu=driver_version --format=csv,noheader`.
/// 6. `rocm_agent_enumerator`, which lists the AMD GPU architectures.
/// 7. `/sys/bus/pci/devices`, filtering for the Intel GPU via PCI.
/// 5. `/dev/nvidiactl` via the NVIDIA ioctl interface.
/// 6. `nvidia-smi --query-gpu=driver_version --format=csv,noheader`.
/// 7. `rocm_agent_enumerator`, which lists the AMD GPU architectures.
/// 8. `/sys/bus/pci/devices`, filtering for the Intel GPU via PCI.
pub fn detect() -> Result<Option<Self>, AcceleratorError> {
// Constants used for PCI device detection.
const PCI_BASE_CLASS_MASK: u32 = 0x00ff_0000;
@ -119,6 +120,65 @@ impl Accelerator {
Err(e) => return Err(e.into()),
}
// Query `/dev/nvidiactl`.
if let Ok(nvctl) = fs_err::File::open("/dev/nvidiactl") {
// https://github.com/NVIDIA/open-gpu-kernel-modules implements the same driver API as
// the closed-source one. These structures are based on
// kernel-open/common/inc/nv-ioctl.h and nv-ioctl-numbers.h in that repo.
#[repr(C)]
struct nv_ioctl_rm_api_version {
cmd: u32,
reply: u32,
version_string: [u8; 64],
}
const NV_RM_API_VERSION_CMD_QUERY: u32 = b'2' as _;
const NV_RM_API_VERSION_REPLY_RECOGNIZED: u32 = 1;
const NV_IOCTL_MAGIC: u32 = b'F' as _;
const NV_IOCTL_BASE: u32 = 200;
const NV_ESC_CHECK_VERSION_STR: u32 = NV_IOCTL_BASE + 10;
let mut query = nv_ioctl_rm_api_version {
cmd: NV_RM_API_VERSION_CMD_QUERY,
reply: 0,
version_string: [0; 64],
};
nix::ioctl_readwrite!(
nv_esc_check_version_str,
NV_IOCTL_MAGIC,
NV_ESC_CHECK_VERSION_STR,
nv_ioctl_rm_api_version
);
debug!("Imma firin ma ioctl");
use std::os::fd::AsRawFd;
#[allow(unsafe_code)]
match unsafe { nv_esc_check_version_str(nvctl.as_raw_fd(), &raw mut query) } {
Ok(_) => {
if query.reply == NV_RM_API_VERSION_REPLY_RECOGNIZED {
debug!("Hey, that worked!");
if let Ok(driver_version) =
std::ffi::CStr::from_bytes_until_nul(&query.version_string)
&& let Ok(driver_version) = driver_version.to_str()
&& let Ok(driver_version) = Version::from_str(driver_version)
{
debug!(
"Detected CUDA driver version from `/dev/nvidiactl`: {driver_version}"
);
return Ok(Some(Self::Cuda { driver_version }));
} else {
debug!("Unable to parse string in `/dev/nvidiactl` ioctl response");
}
} else {
debug!("Unexpected reply from `/dev/nvidiactl` ioctl");
}
}
Err(e) => {
debug!("`/dev/nvidiactl` ioctl failed: {e}");
}
}
}
// Query `nvidia-smi`.
if let Ok(output) = std::process::Command::new("nvidia-smi")
.arg("--query-gpu=driver_version")