diff --git a/crates/uv-torch/src/backend.rs b/crates/uv-torch/src/backend.rs index 7437744d5..2bb8374fa 100644 --- a/crates/uv-torch/src/backend.rs +++ b/crates/uv-torch/src/backend.rs @@ -310,7 +310,10 @@ impl TorchStrategy { TorchSource::PyTorch => { matches!( package_name.as_str(), - "torch" + "pytorch-triton" + | "pytorch-triton-rocm" + | "pytorch-triton-xpu" + | "torch" | "torcharrow" | "torchaudio" | "torchcsprng" @@ -320,9 +323,6 @@ impl TorchStrategy { | "torchtext" | "torchvision" | "triton" - | "pytorch-triton" - | "pytorch-triton-rocm" - | "pytorch-triton-xpu" ) } TorchSource::Pyx => { @@ -334,12 +334,14 @@ impl TorchStrategy { | "megablocks" | "natten" | "pyg-lib" + | "pytorch-triton" + | "pytorch-triton-rocm" + | "pytorch-triton-xpu" + | "torch" | "torch-cluster" | "torch-scatter" | "torch-sparse" | "torch-spline-conv" - | "vllm" - | "torch" | "torcharrow" | "torchaudio" | "torchcsprng" @@ -349,9 +351,7 @@ impl TorchStrategy { | "torchtext" | "torchvision" | "triton" - | "pytorch-triton" - | "pytorch-triton-rocm" - | "pytorch-triton-xpu" + | "vllm" ) } } @@ -365,12 +365,11 @@ impl TorchStrategy { pub fn has_system_dependency(&self, package_name: &PackageName) -> bool { matches!( package_name.as_str(), - "flash-attn" + "deepspeed" + | "flash-attn" | "flash-attn-3" | "megablocks" | "natten" - | "deepspeed" - | "vllm" | "torch" | "torcharrow" | "torchaudio" @@ -379,6 +378,7 @@ impl TorchStrategy { | "torchdistx" | "torchtext" | "torchvision" + | "vllm" ) }