mirror of https://github.com/astral-sh/uv
97 lines
2.8 KiB
Rust
97 lines
2.8 KiB
Rust
use std::str::FromStr;
|
|
|
|
use pubgrub::Ranges;
|
|
|
|
use uv_normalize::PackageName;
|
|
use uv_pep440::Version;
|
|
use uv_redacted::DisplaySafeUrl;
|
|
use uv_torch::TorchBackend;
|
|
|
|
use crate::pubgrub::{PubGrubDependency, PubGrubPackage, PubGrubPackageInner};
|
|
|
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
pub(super) struct SystemDependency {
|
|
/// The name of the system dependency (e.g., `cuda`).
|
|
name: PackageName,
|
|
/// The version of the system dependency (e.g., `12.4`).
|
|
version: Version,
|
|
}
|
|
|
|
impl SystemDependency {
|
|
/// Extract a [`SystemDependency`] from an index URL.
|
|
///
|
|
/// For example, given `https://download.pytorch.org/whl/cu124`, returns CUDA 12.4.
|
|
pub(super) fn from_index(index: &DisplaySafeUrl) -> Option<Self> {
|
|
let backend = TorchBackend::from_index(index)?;
|
|
if let Some(cuda_version) = backend.cuda_version() {
|
|
Some(Self {
|
|
name: PackageName::from_str("cuda").unwrap(),
|
|
version: cuda_version,
|
|
})
|
|
} else {
|
|
backend.rocm_version().map(|rocm_version| Self {
|
|
name: PackageName::from_str("rocm").unwrap(),
|
|
version: rocm_version,
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
impl std::fmt::Display for SystemDependency {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
write!(f, "{}@{}", self.name, self.version)
|
|
}
|
|
}
|
|
|
|
impl From<SystemDependency> for PubGrubDependency {
|
|
fn from(value: SystemDependency) -> Self {
|
|
Self {
|
|
package: PubGrubPackage::from(PubGrubPackageInner::System(value.name)),
|
|
version: Ranges::singleton(value.version),
|
|
parent: None,
|
|
url: None,
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use uv_normalize::PackageName;
    use uv_pep440::Version;
    use uv_redacted::DisplaySafeUrl;

    use crate::resolver::system::SystemDependency;

    /// Non-PyTorch indexes imply no system dependency.
    #[test]
    fn pypi() {
        let url = DisplaySafeUrl::parse("https://pypi.org/simple").unwrap();
        assert_eq!(SystemDependency::from_index(&url), None);
    }

    /// A CUDA PyTorch index implies a `cuda` dependency at the encoded version.
    #[test]
    fn pytorch_cuda_12_4() {
        let url = DisplaySafeUrl::parse("https://download.pytorch.org/whl/cu124").unwrap();
        assert_eq!(
            SystemDependency::from_index(&url),
            Some(SystemDependency {
                name: PackageName::from_str("cuda").unwrap(),
                version: Version::new([12, 4]),
            })
        );
    }

    /// A ROCm PyTorch index implies a `rocm` dependency (covers the non-CUDA branch).
    #[test]
    fn pytorch_rocm_6_3() {
        let url = DisplaySafeUrl::parse("https://download.pytorch.org/whl/rocm6.3").unwrap();
        assert_eq!(
            SystemDependency::from_index(&url),
            Some(SystemDependency {
                name: PackageName::from_str("rocm").unwrap(),
                version: Version::new([6, 3]),
            })
        );
    }

    /// The CPU PyTorch index implies no system dependency.
    #[test]
    fn pytorch_cpu() {
        let url = DisplaySafeUrl::parse("https://download.pytorch.org/whl/cpu").unwrap();
        assert_eq!(SystemDependency::from_index(&url), None);
    }

    /// The XPU PyTorch index implies no system dependency.
    #[test]
    fn pytorch_xpu() {
        let url = DisplaySafeUrl::parse("https://download.pytorch.org/whl/xpu").unwrap();
        assert_eq!(SystemDependency::from_index(&url), None);
    }

    /// `Display` renders the dependency as `name@version`.
    #[test]
    fn display() {
        let dependency = SystemDependency {
            name: PackageName::from_str("cuda").unwrap(),
            version: Version::new([12, 4]),
        };
        assert_eq!(dependency.to_string(), "cuda@12.4");
    }
}
|