From bb73edb03bfc90a32440bcf8c8ffc9ac125cb37a Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Fri, 19 Jul 2024 17:56:09 -0400 Subject: [PATCH] Respect local versions for all user requirements (#5232) ## Summary This fixes a few bugs introduced by https://github.com/astral-sh/uv/pull/5104. I previously thought we could track conflicting locals the same way we track conflicting URLs in forks, but it turns out that ends up being very tricky. URL forks work because we prioritize directly URL requirements. We can't prioritize locals in the same way without conflicting with the URL prioritization (this may be possible but it's not trivial), so we run into issues where a correct resolution depends on the order in which dependencies are traversed. Instead, we track local versions across all forks in `Locals`. When applying a local version, we apply all locals with markers that intersect with the current fork. This way we end up applying some local versions without creating a fork. For example, given: ``` // pyproject.toml dependencies = [ "torch==2.0.0+cu118 ; platform_machine == 'x86_64'", ] // requirements.in torch==2.0.0 . ``` We choose `2.0.0+cu118` in all cases. However, if a disjoint fork is created based on local versions, the resolver will choose the most compatible local when it narrows to a specific fork. Thus we correctly respect local versions when forking: ``` // pyproject.toml dependencies = [ "torch==2.0.0+cu118 ; platform_machine == 'x86_64'", "torch==2.0.0+cpu ; platform_machine != 'x86_64'" ] // requirements.in torch==2.0.0 . ``` We should also be able to use a similar strategy for https://github.com/astral-sh/uv/pull/5150. ## Test Plan This fixes https://github.com/astral-sh/uv/issues/5220 locally for me, as well as a few other bugs that were not reported yet. --- .../uv-resolver/src/pubgrub/dependencies.rs | 49 +- crates/uv-resolver/src/resolver/locals.rs | 90 ++- crates/uv-resolver/src/resolver/mod.rs | 94 +-- crates/uv/tests/pip_compile.rs | 580 +++++++++++++++++- 4 files changed, 718 insertions(+), 95 deletions(-) diff --git a/crates/uv-resolver/src/pubgrub/dependencies.rs b/crates/uv-resolver/src/pubgrub/dependencies.rs index 08dd39e28..97a1a1e4a 100644 --- a/crates/uv-resolver/src/pubgrub/dependencies.rs +++ b/crates/uv-resolver/src/pubgrub/dependencies.rs @@ -12,37 +12,37 @@ use pypi_types::{ use uv_normalize::{ExtraName, PackageName}; use crate::pubgrub::{PubGrubPackage, PubGrubPackageInner}; -use crate::resolver::ForkLocals; use crate::{PubGrubSpecifier, ResolveError}; #[derive(Clone, Debug)] pub(crate) struct PubGrubDependency { pub(crate) package: PubGrubPackage, pub(crate) version: Range, + + /// The original version specifiers from the requirement. + pub(crate) specifier: Option, + /// This field is set if the [`Requirement`] had a URL. We still use a URL from [`Urls`] /// even if this field is None where there is an override with a URL or there is a different /// requirement or constraint for the same package that has a URL. pub(crate) url: Option, - /// The local version for this requirement, if specified. - pub(crate) local: Option, } impl PubGrubDependency { pub(crate) fn from_requirement<'a>( requirement: &'a Requirement, source_name: Option<&'a PackageName>, - fork_locals: &'a ForkLocals, ) -> impl Iterator> + 'a { // Add the package, plus any extra variants. iter::once(None) .chain(requirement.extras.clone().into_iter().map(Some)) - .map(|extra| PubGrubRequirement::from_requirement(requirement, extra, fork_locals)) + .map(|extra| PubGrubRequirement::from_requirement(requirement, extra)) .filter_map_ok(move |requirement| { let PubGrubRequirement { package, version, + specifier, url, - local, } = requirement; match &*package { PubGrubPackageInner::Package { name, .. } => { @@ -55,15 +55,15 @@ impl PubGrubDependency { Some(PubGrubDependency { package: package.clone(), version: version.clone(), + specifier, url, - local, }) } PubGrubPackageInner::Marker { .. } => Some(PubGrubDependency { package: package.clone(), version: version.clone(), + specifier, url, - local, }), PubGrubPackageInner::Extra { name, .. } => { debug_assert!( @@ -73,8 +73,8 @@ impl PubGrubDependency { Some(PubGrubDependency { package: package.clone(), version: version.clone(), + specifier, url: None, - local: None, }) } _ => None, @@ -88,8 +88,8 @@ impl PubGrubDependency { pub(crate) struct PubGrubRequirement { pub(crate) package: PubGrubPackage, pub(crate) version: Range, + pub(crate) specifier: Option, pub(crate) url: Option, - pub(crate) local: Option, } impl PubGrubRequirement { @@ -98,11 +98,10 @@ impl PubGrubRequirement { pub(crate) fn from_requirement( requirement: &Requirement, extra: Option, - fork_locals: &ForkLocals, ) -> Result { let (verbatim_url, parsed_url) = match &requirement.source { RequirementSource::Registry { specifier, .. } => { - return Self::from_registry_requirement(specifier, extra, requirement, fork_locals); + return Self::from_registry_requirement(specifier, extra, requirement); } RequirementSource::Url { subdirectory, @@ -165,11 +164,11 @@ impl PubGrubRequirement { requirement.marker.clone(), ), version: Range::full(), + specifier: None, url: Some(VerbatimParsedUrl { parsed_url, verbatim: verbatim_url.clone(), }), - local: None, }) } @@ -177,26 +176,8 @@ impl PubGrubRequirement { specifier: &VersionSpecifiers, extra: Option, requirement: &Requirement, - fork_locals: &ForkLocals, ) -> Result { - // If the specifier is an exact version and the user requested a local version for this - // fork that's more precise than the specifier, use the local version instead. - let version = if let Some(local) = fork_locals.get(&requirement.name) { - specifier - .iter() - .map(|specifier| { - ForkLocals::map(local, specifier) - .map_err(ResolveError::InvalidVersion) - .and_then(|specifier| { - Ok(PubGrubSpecifier::from_pep440_specifier(&specifier)?) - }) - }) - .fold_ok(Range::full(), |range, specifier| { - range.intersection(&specifier.into()) - })? - } else { - PubGrubSpecifier::from_pep440_specifiers(specifier)?.into() - }; + let version = PubGrubSpecifier::from_pep440_specifiers(specifier)?.into(); let requirement = Self { package: PubGrubPackage::from_package( @@ -204,9 +185,9 @@ impl PubGrubRequirement { extra, requirement.marker.clone(), ), - version, + specifier: Some(specifier.clone()), url: None, - local: None, + version, }; Ok(requirement) diff --git a/crates/uv-resolver/src/resolver/locals.rs b/crates/uv-resolver/src/resolver/locals.rs index c078118b1..7e1a09a76 100644 --- a/crates/uv-resolver/src/resolver/locals.rs +++ b/crates/uv-resolver/src/resolver/locals.rs @@ -3,24 +3,74 @@ use std::str::FromStr; use distribution_filename::{SourceDistFilename, WheelFilename}; use distribution_types::RemoteSource; use pep440_rs::{Operator, Version, VersionSpecifier, VersionSpecifierBuildError}; -use pep508_rs::PackageName; +use pep508_rs::{MarkerEnvironment, MarkerTree, PackageName}; use pypi_types::RequirementSource; use rustc_hash::FxHashMap; -/// A map of package names to their associated, required local versions in a given fork. -#[derive(Debug, Default, Clone)] -pub(crate) struct ForkLocals(FxHashMap); +use crate::{marker::is_disjoint, DependencyMode, Manifest, ResolverMarkers}; -impl ForkLocals { - /// Insert the local [`Version`] to which a package is pinned for this fork. - pub(crate) fn insert(&mut self, package_name: PackageName, local: Version) { - assert!(local.is_local()); - self.0.insert(package_name, local); +/// A map of package names to their associated, required local versions across all forks. +#[derive(Debug, Default, Clone)] +pub(crate) struct Locals(FxHashMap, Version)>>); + +impl Locals { + /// Determine the set of permitted local versions in the [`Manifest`]. + pub(crate) fn from_manifest( + manifest: &Manifest, + markers: Option<&MarkerEnvironment>, + dependencies: DependencyMode, + ) -> Self { + let mut required: FxHashMap> = FxHashMap::default(); + + // Add all direct requirements and constraints. There's no need to look for conflicts, + // since conflicts will be enforced by the solver. + for requirement in manifest.requirements(markers, dependencies) { + if let Some(local) = from_source(&requirement.source) { + required + .entry(requirement.name.clone()) + .or_default() + .push((requirement.marker.clone(), local)); + } + } + + Self(required) } - /// Return the local [`Version`] to which a package is pinned in this fork, if any. - pub(crate) fn get(&self, package_name: &PackageName) -> Option<&Version> { - self.0.get(package_name) + /// Return a list of local versions that are compatible with a package in the given fork. + pub(crate) fn get( + &self, + package_name: &PackageName, + markers: &ResolverMarkers, + ) -> Vec<&Version> { + let Some(locals) = self.0.get(package_name) else { + return Vec::new(); + }; + + match markers { + // If we are solving for a specific environment we already filtered + // compatible requirements `from_manifest`. + ResolverMarkers::SpecificEnvironment(_) => { + locals.first().map(|(_, local)| local).into_iter().collect() + } + + // Return all locals that were requested with markers that are compatible + // with the current fork. + // + // Compatibility implies that the markers are not disjoint. The resolver will + // choose the most compatible local when it narrows to the specific fork. + ResolverMarkers::Fork(fork) => locals + .iter() + .filter(|(marker, _)| { + !marker + .as_ref() + .is_some_and(|marker| is_disjoint(fork, marker)) + }) + .map(|(_, local)| local) + .collect(), + + // If we haven't forked yet, all locals are potentially compatible. + ResolverMarkers::Universal => locals.iter().map(|(_, local)| local).collect(), + } } /// Given a specifier that may include the version _without_ a local segment, return a specifier @@ -190,7 +240,7 @@ mod tests { use pypi_types::ParsedUrl; use pypi_types::RequirementSource; - use super::{from_source, ForkLocals}; + use super::{from_source, Locals}; #[test] fn extract_locals() -> Result<()> { @@ -251,7 +301,7 @@ mod tests { let specifier = VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0")?)?; assert_eq!( - ForkLocals::map(&local, &specifier)?, + Locals::map(&local, &specifier)?, VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+local")?)? ); @@ -260,7 +310,7 @@ mod tests { let specifier = VersionSpecifier::from_version(Operator::NotEqual, Version::from_str("1.0.0")?)?; assert_eq!( - ForkLocals::map(&local, &specifier)?, + Locals::map(&local, &specifier)?, VersionSpecifier::from_version(Operator::NotEqual, Version::from_str("1.0.0+local")?)? ); @@ -269,7 +319,7 @@ mod tests { let specifier = VersionSpecifier::from_version(Operator::LessThanEqual, Version::from_str("1.0.0")?)?; assert_eq!( - ForkLocals::map(&local, &specifier)?, + Locals::map(&local, &specifier)?, VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+local")?)? ); @@ -278,7 +328,7 @@ mod tests { let specifier = VersionSpecifier::from_version(Operator::GreaterThan, Version::from_str("1.0.0")?)?; assert_eq!( - ForkLocals::map(&local, &specifier)?, + Locals::map(&local, &specifier)?, VersionSpecifier::from_version(Operator::GreaterThan, Version::from_str("1.0.0")?)? ); @@ -287,7 +337,7 @@ mod tests { let specifier = VersionSpecifier::from_version(Operator::ExactEqual, Version::from_str("1.0.0")?)?; assert_eq!( - ForkLocals::map(&local, &specifier)?, + Locals::map(&local, &specifier)?, VersionSpecifier::from_version(Operator::ExactEqual, Version::from_str("1.0.0")?)? ); @@ -296,7 +346,7 @@ mod tests { let specifier = VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+local")?)?; assert_eq!( - ForkLocals::map(&local, &specifier)?, + Locals::map(&local, &specifier)?, VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+local")?)? ); @@ -305,7 +355,7 @@ mod tests { let specifier = VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+other")?)?; assert_eq!( - ForkLocals::map(&local, &specifier)?, + Locals::map(&local, &specifier)?, VersionSpecifier::from_version(Operator::Equal, Version::from_str("1.0.0+other")?)? ); diff --git a/crates/uv-resolver/src/resolver/mod.rs b/crates/uv-resolver/src/resolver/mod.rs index e8dff8de3..97f6c5bfe 100644 --- a/crates/uv-resolver/src/resolver/mod.rs +++ b/crates/uv-resolver/src/resolver/mod.rs @@ -26,7 +26,7 @@ use distribution_types::{ IncompatibleWheel, IndexLocations, InstalledDist, PythonRequirementKind, RemoteSource, ResolvedDist, ResolvedDistRef, SourceDist, VersionOrUrlRef, }; -pub(crate) use locals::ForkLocals; +pub(crate) use locals::Locals; use pep440_rs::{Version, MIN_VERSION}; use pep508_rs::MarkerTree; use platform_tags::Tags; @@ -96,6 +96,7 @@ struct ResolverState { git: GitResolver, exclusions: Exclusions, urls: Urls, + locals: Locals, dependency_mode: DependencyMode, hasher: HashStrategy, markers: ResolverMarkers, @@ -215,6 +216,11 @@ impl git, options.dependency_mode, )?, + locals: Locals::from_manifest( + &manifest, + markers.marker_environment(), + options.dependency_mode, + ), project: manifest.project, requirements: manifest.requirements, constraints: manifest.constraints, @@ -309,7 +315,6 @@ impl ResolverState ResolverState ResolverState ResolverState ResolverState ResolverState, ) -> Result { - let result = self.get_dependencies( - package, - version, - fork_urls, - fork_locals, - markers, - requires_python, - ); + let result = self.get_dependencies(package, version, fork_urls, markers, requires_python); match markers { ResolverMarkers::SpecificEnvironment(_) => result.map(|deps| match deps { Dependencies::Available(deps) => ForkedDependencies::Unforked(deps), @@ -1128,7 +1126,6 @@ impl ResolverState, ) -> Result { @@ -1148,14 +1145,7 @@ impl ResolverState, _>>()? } PubGrubPackageInner::Package { @@ -1283,7 +1273,7 @@ impl ResolverState, _>>()?; @@ -1302,8 +1292,8 @@ impl ResolverState ResolverState ResolverState ResolverState, version: &Version, urls: &Urls, - dependencies: Vec, + locals: &Locals, + mut dependencies: Vec, git: &GitResolver, resolution_strategy: &ResolutionStrategy, ) -> Result<(), ResolveError> { - for dependency in &dependencies { + for dependency in &mut dependencies { let PubGrubDependency { package, version, + specifier, url, - local, } = dependency; let mut has_url = false; @@ -2057,11 +2046,36 @@ impl ForkState { has_url = true; }; - // `PubGrubDependency` also gives us a local version if specified by the user. - // Keep track of which local version we will be using in this fork for transitive - // dependencies. - if let Some(local) = local { - self.fork_locals.insert(name.clone(), local.clone()); + // If the specifier is an exact version and the user requested a local version for this + // fork that's more precise than the specifier, use the local version instead. + if let Some(specifier) = specifier { + let locals = locals.get(name, &self.markers); + + // Prioritize local versions over the original version range. + if !locals.is_empty() { + *version = Range::empty(); + } + + // It's possible that there are multiple matching local versions requested with + // different marker expressions. All of these are potentially compatible until we + // narrow to a specific fork. + for local in locals { + let local = specifier + .iter() + .map(|specifier| { + Locals::map(local, specifier) + .map_err(ResolveError::InvalidVersion) + .and_then(|specifier| { + Ok(PubGrubSpecifier::from_pep440_specifier(&specifier)?) + }) + }) + .fold_ok(Range::full(), |range, specifier| { + range.intersection(&specifier.into()) + })?; + + // Add the local version. + *version = version.union(&local); + } } } @@ -2096,8 +2110,8 @@ impl ForkState { let PubGrubDependency { package, version, + specifier: _, url: _, - local: _, } = dependency; (package, version) }), diff --git a/crates/uv/tests/pip_compile.rs b/crates/uv/tests/pip_compile.rs index 683da1296..14a5b4f6f 100644 --- a/crates/uv/tests/pip_compile.rs +++ b/crates/uv/tests/pip_compile.rs @@ -6708,6 +6708,7 @@ fn universal_multi_version() -> Result<()> { Ok(()) } +// Requested distinct local versions with disjoint markers. #[test] fn universal_disjoint_locals() -> Result<()> { let context = TestContext::new("3.12"); @@ -6764,6 +6765,8 @@ fn universal_disjoint_locals() -> Result<()> { Ok(()) } +// Requested distinct local versions with disjoint markers of a package +// that is also present as a transitive dependency. #[test] fn universal_transitive_disjoint_locals() -> Result<()> { let context = TestContext::new("3.12"); @@ -6776,7 +6779,7 @@ fn universal_transitive_disjoint_locals() -> Result<()> { torchvision==0.15.1 "})?; - // The marker expressions on the output here are incorrect due to https://github.com/astral-sh/uv/issues/5086, + // Some marker expressions on the output here are missing due to https://github.com/astral-sh/uv/issues/5086, // but the local versions are still respected correctly. uv_snapshot!(context.filters(), windows_filters=false, context.pip_compile() .arg("requirements.in") @@ -6842,6 +6845,581 @@ fn universal_transitive_disjoint_locals() -> Result<()> { Ok(()) } +/// Prefer local versions for dependencies of path requirements. +#[test] +fn universal_local_path_requirement() -> Result<()> { + let context = TestContext::new("3.12"); + + let pyproject_toml = context.temp_dir.child("pyproject.toml"); + pyproject_toml.write_str(indoc! {r#" + [project] + name = "example" + version = "0.0.0" + dependencies = [ + "torch==2.0.0+cu118" + ] + requires-python = ">=3.11" + "#})?; + + let requirements_in = context.temp_dir.child("requirements.in"); + requirements_in.write_str(indoc! {" + torch==2.0.0 + . + "})?; + + uv_snapshot!(context.pip_compile() + .arg("requirements.in") + .arg("--universal") + .arg("--find-links") + .arg("https://download.pytorch.org/whl/torch_stable.html"), @r###" + success: true + exit_code: 0 + ----- stdout ----- + # This file was autogenerated by uv via the following command: + # uv pip compile --cache-dir [CACHE_DIR] requirements.in --universal + cmake==3.28.4 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via triton + . + # via -r requirements.in + filelock==3.13.1 + # via + # torch + # triton + jinja2==3.1.3 + # via torch + lit==18.1.2 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via triton + markupsafe==2.1.5 + # via jinja2 + mpmath==1.3.0 + # via sympy + networkx==3.2.1 + # via torch + sympy==1.12 + # via torch + torch==2.0.0+cu118 + # via + # -r requirements.in + # example + # triton + triton==2.0.0 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via torch + typing-extensions==4.10.0 + # via torch + + ----- stderr ----- + Resolved 12 packages in [TIME] + "### + ); + + Ok(()) +} + +/// If a dependency requests a local version with an overlapping marker expression, +/// we should prefer the local in all cases. +#[test] +fn universal_overlapping_local_requirement() -> Result<()> { + let context = TestContext::new("3.12"); + + let pyproject_toml = context.temp_dir.child("pyproject.toml"); + pyproject_toml.write_str(indoc! {r#" + [project] + name = "example" + version = "0.0.0" + dependencies = [ + "torch==2.0.0+cu118 ; platform_machine == 'x86_64'" + ] + requires-python = ">=3.11" + "#})?; + + let requirements_in = context.temp_dir.child("requirements.in"); + requirements_in.write_str(indoc! {" + torch==2.0.0 + . + "})?; + + uv_snapshot!(context.pip_compile() + .arg("requirements.in") + .arg("--universal") + .arg("--find-links") + .arg("https://download.pytorch.org/whl/torch_stable.html"), @r###" + success: true + exit_code: 0 + ----- stdout ----- + # This file was autogenerated by uv via the following command: + # uv pip compile --cache-dir [CACHE_DIR] requirements.in --universal + cmake==3.28.4 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via triton + . + # via -r requirements.in + filelock==3.13.1 + # via + # torch + # triton + jinja2==3.1.3 + # via torch + lit==18.1.2 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via triton + markupsafe==2.1.5 + # via jinja2 + mpmath==1.3.0 + # via sympy + networkx==3.2.1 + # via torch + sympy==1.12 + # via torch + torch==2.0.0+cu118 + # via + # -r requirements.in + # example + # triton + triton==2.0.0 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via torch + typing-extensions==4.10.0 + # via torch + + ----- stderr ----- + Resolved 12 packages in [TIME] + "### + ); + + Ok(()) +} + +/// If a dependency requests distinct local versions with disjoint marker expressions, +/// we should fork the root requirement. +#[test] +fn universal_disjoint_local_requirement() -> Result<()> { + let context = TestContext::new("3.12"); + + let pyproject_toml = context.temp_dir.child("pyproject.toml"); + pyproject_toml.write_str(indoc! {r#" + [project] + name = "example" + version = "0.0.0" + dependencies = [ + "torch==2.0.0+cu118 ; platform_machine == 'x86_64'", + "torch==2.0.0+cpu ; platform_machine != 'x86_64'" + ] + requires-python = ">=3.11" + "#})?; + + let requirements_in = context.temp_dir.child("requirements.in"); + requirements_in.write_str(indoc! {" + torch==2.0.0 + . + "})?; + + // Some marker expressions on the output here are missing due to https://github.com/astral-sh/uv/issues/5086, + // but the local versions are still respected correctly. + uv_snapshot!(context.pip_compile() + .arg("requirements.in") + .arg("--universal") + .arg("--find-links") + .arg("https://download.pytorch.org/whl/torch_stable.html"), @r###" + success: true + exit_code: 0 + ----- stdout ----- + # This file was autogenerated by uv via the following command: + # uv pip compile --cache-dir [CACHE_DIR] requirements.in --universal + cmake==3.28.4 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via triton + . + # via -r requirements.in + filelock==3.13.1 + # via + # torch + # triton + jinja2==3.1.3 + # via torch + lit==18.1.2 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via triton + markupsafe==2.1.5 + # via jinja2 + mpmath==1.3.0 + # via sympy + networkx==3.2.1 + # via torch + sympy==1.12 + # via torch + torch==2.0.0+cpu + # via + # -r requirements.in + # example + torch==2.0.0+cu118 + # via + # -r requirements.in + # example + # triton + triton==2.0.0 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via torch + typing-extensions==4.10.0 + # via torch + + ----- stderr ----- + Resolved 13 packages in [TIME] + "### + ); + + Ok(()) +} + +/// If a dependency requests distinct local versions and non-local versions with disjoint marker +/// expressions, we should fork the root requirement. +#[test] +fn universal_disjoint_base_or_local_requirement() -> Result<()> { + let context = TestContext::new("3.12"); + + let pyproject_toml = context.temp_dir.child("pyproject.toml"); + pyproject_toml.write_str(indoc! {r#" + [project] + name = "example" + version = "0.0.0" + dependencies = [ + "torch==2.0.0; python_version < '3.10'", + "torch==2.0.0+cu118 ; python_version >= '3.10' and python_version <= '3.12'", + "torch==2.0.0+cpu ; python_version > '3.12'" + ] + requires-python = ">=3.11" + "#})?; + + let requirements_in = context.temp_dir.child("requirements.in"); + requirements_in.write_str(indoc! {" + torch==2.0.0 + . + "})?; + + // Some marker expressions on the output here are missing due to https://github.com/astral-sh/uv/issues/5086, + // but the local versions are still respected correctly. + uv_snapshot!(context.pip_compile() + .arg("requirements.in") + .arg("--universal") + .arg("--find-links") + .arg("https://download.pytorch.org/whl/torch_stable.html"), @r###" + success: true + exit_code: 0 + ----- stdout ----- + # This file was autogenerated by uv via the following command: + # uv pip compile --cache-dir [CACHE_DIR] requirements.in --universal + cmake==3.28.4 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via triton + . + # via -r requirements.in + filelock==3.13.1 + # via + # torch + # triton + jinja2==3.1.3 + # via torch + lit==18.1.2 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via triton + markupsafe==2.1.5 + # via jinja2 + mpmath==1.3.0 + # via sympy + networkx==3.2.1 + # via torch + sympy==1.12 + # via torch + torch==2.0.0+cpu + # via + # -r requirements.in + # example + torch==2.0.0+cu118 + # via + # -r requirements.in + # example + # triton + triton==2.0.0 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via torch + typing-extensions==4.10.0 + # via torch + + ----- stderr ----- + Resolved 13 packages in [TIME] + "### + ); + + Ok(()) +} + +/// If a dependency requests a local version with an overlapping marker expression +/// that form a nested fork, we should prefer the local in both children of the outer +/// fork. +#[test] +fn universal_nested_overlapping_local_requirement() -> Result<()> { + let context = TestContext::new("3.12"); + + let pyproject_toml = context.temp_dir.child("pyproject.toml"); + pyproject_toml.write_str(indoc! {r#" + [project] + name = "example" + version = "0.0.0" + dependencies = [ + "torch==2.0.0+cu118 ; platform_machine == 'x86_64' and os_name == 'Linux'" + ] + requires-python = ">=3.11" + "#})?; + + let requirements_in = context.temp_dir.child("requirements.in"); + requirements_in.write_str(indoc! {" + torch==2.0.0 ; platform_machine == 'x86_64' + torch==2.3.0 ; platform_machine != 'x86_64' + . + "})?; + + uv_snapshot!(context.pip_compile() + .arg("requirements.in") + .arg("--universal") + .arg("--find-links") + .arg("https://download.pytorch.org/whl/torch_stable.html"), @r###" + success: true + exit_code: 0 + ----- stdout ----- + # This file was autogenerated by uv via the following command: + # uv pip compile --cache-dir [CACHE_DIR] requirements.in --universal + cmake==3.28.4 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via triton + . + # via -r requirements.in + filelock==3.13.1 + # via + # torch + # triton + fsspec==2024.3.1 ; platform_machine != 'x86_64' + # via torch + intel-openmp==2021.4.0 ; platform_machine != 'x86_64' and platform_system == 'Windows' + # via mkl + jinja2==3.1.3 + # via torch + lit==18.1.2 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via triton + markupsafe==2.1.5 + # via jinja2 + mkl==2021.4.0 ; platform_machine != 'x86_64' and platform_system == 'Windows' + # via torch + mpmath==1.3.0 + # via sympy + networkx==3.2.1 + # via torch + sympy==1.12 + # via torch + tbb==2021.11.0 ; platform_machine != 'x86_64' and platform_system == 'Windows' + # via mkl + torch==2.3.0 ; platform_machine != 'x86_64' + # via -r requirements.in + torch==2.0.0+cu118 ; platform_machine == 'x86_64' + # via + # -r requirements.in + # example + # triton + triton==2.0.0 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via torch + typing-extensions==4.10.0 + # via torch + + ----- stderr ----- + Resolved 17 packages in [TIME] + "### + ); + + // A similar case, except the nested marker is now on the path requirement. + let requirements_in = context.temp_dir.child("requirements.in"); + requirements_in.write_str(indoc! {" + torch==2.0.0 ; platform_machine == 'x86_64' + torch==2.3.0 ; platform_machine != 'x86_64' + . ; os_name == 'Linux' + "})?; + + let pyproject_toml = context.temp_dir.child("pyproject.toml"); + pyproject_toml.write_str(indoc! {r#" + [project] + name = "example" + version = "0.0.0" + dependencies = [ + "torch==2.0.0+cu118 ; platform_machine == 'x86_64'", + ] + requires-python = ">=3.11" + "#})?; + + uv_snapshot!(context.pip_compile() + .arg("requirements.in") + .arg("--universal") + .arg("--find-links") + .arg("https://download.pytorch.org/whl/torch_stable.html"), @r###" + success: true + exit_code: 0 + ----- stdout ----- + # This file was autogenerated by uv via the following command: + # uv pip compile --cache-dir [CACHE_DIR] requirements.in --universal + cmake==3.28.4 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via triton + . ; os_name == 'Linux' + # via -r requirements.in + filelock==3.13.1 + # via + # torch + # triton + fsspec==2024.3.1 ; platform_machine != 'x86_64' + # via torch + intel-openmp==2021.4.0 ; platform_machine != 'x86_64' and platform_system == 'Windows' + # via mkl + jinja2==3.1.3 + # via torch + lit==18.1.2 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via triton + markupsafe==2.1.5 + # via jinja2 + mkl==2021.4.0 ; platform_machine != 'x86_64' and platform_system == 'Windows' + # via torch + mpmath==1.3.0 + # via sympy + networkx==3.2.1 + # via torch + sympy==1.12 + # via torch + tbb==2021.11.0 ; platform_machine != 'x86_64' and platform_system == 'Windows' + # via mkl + torch==2.3.0 ; platform_machine != 'x86_64' + # via -r requirements.in + torch==2.0.0+cu118 ; platform_machine == 'x86_64' + # via + # -r requirements.in + # example + # triton + triton==2.0.0 ; platform_machine == 'x86_64' and platform_system == 'Linux' + # via torch + typing-extensions==4.10.0 + # via torch + + ----- stderr ----- + Resolved 17 packages in [TIME] + "### + ); + + Ok(()) +} + +/// If a dependency requests distinct local versions with disjoint marker expressions +/// that form a nested fork, we should create a nested fork. +#[test] +fn universal_nested_disjoint_local_requirement() -> Result<()> { + let context = TestContext::new("3.12"); + + let pyproject_toml = context.temp_dir.child("pyproject.toml"); + pyproject_toml.write_str(indoc! {r#" + [project] + name = "example" + version = "0.0.0" + dependencies = [ + "torch==2.0.0+cu118 ; platform_machine == 'x86_64'", + "torch==2.0.0+cpu ; platform_machine != 'x86_64'" + ] + requires-python = ">=3.11" + "#})?; + + let requirements_in = context.temp_dir.child("requirements.in"); + requirements_in.write_str(indoc! {" + torch==2.0.0 ; os_name == 'Linux' + torch==2.3.0 ; os_name != 'Linux' + . ; os_name == 'Linux' + "})?; + + // Some marker expressions on the output here are missing due to https://github.com/astral-sh/uv/issues/5086, + // but the local versions are still respected correctly. + uv_snapshot!(context.pip_compile() + .arg("requirements.in") + .arg("--universal") + .arg("--find-links") + .arg("https://download.pytorch.org/whl/torch_stable.html"), @r###" + success: true + exit_code: 0 + ----- stdout ----- + # This file was autogenerated by uv via the following command: + # uv pip compile --cache-dir [CACHE_DIR] requirements.in --universal + cmake==3.28.4 ; os_name == 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux' + # via triton + . ; os_name == 'Linux' + # via -r requirements.in + filelock==3.13.1 + # via + # torch + # triton + fsspec==2024.3.1 ; os_name != 'Linux' + # via torch + intel-openmp==2021.4.0 ; os_name != 'Linux' and platform_system == 'Windows' + # via mkl + jinja2==3.1.3 + # via torch + lit==18.1.2 ; os_name == 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux' + # via triton + markupsafe==2.1.5 + # via jinja2 + mkl==2021.4.0 ; os_name != 'Linux' and platform_system == 'Windows' + # via torch + mpmath==1.3.0 + # via sympy + networkx==3.2.1 + # via torch + nvidia-cublas-cu12==12.1.3.1 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux' + # via + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 + # torch + nvidia-cuda-cupti-cu12==12.1.105 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux' + # via torch + nvidia-cuda-nvrtc-cu12==12.1.105 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux' + # via torch + nvidia-cuda-runtime-cu12==12.1.105 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux' + # via torch + nvidia-cudnn-cu12==8.9.2.26 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux' + # via torch + nvidia-cufft-cu12==11.0.2.54 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux' + # via torch + nvidia-curand-cu12==10.3.2.106 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux' + # via torch + nvidia-cusolver-cu12==11.4.5.107 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux' + # via torch + nvidia-cusparse-cu12==12.1.0.106 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux' + # via + # nvidia-cusolver-cu12 + # torch + nvidia-nccl-cu12==2.20.5 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux' + # via torch + nvidia-nvjitlink-cu12==12.4.99 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux' + # via + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 + nvidia-nvtx-cu12==12.1.105 ; os_name != 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux' + # via torch + sympy==1.12 + # via torch + tbb==2021.11.0 ; os_name != 'Linux' and platform_system == 'Windows' + # via mkl + torch==2.0.0+cu118 ; os_name == 'Linux' + # via + # -r requirements.in + # example + # triton + torch==2.3.0 ; os_name != 'Linux' + # via -r requirements.in + torch==2.0.0+cpu ; os_name == 'Linux' + # via + # -r requirements.in + # example + triton==2.0.0 ; os_name == 'Linux' and platform_machine == 'x86_64' and platform_system == 'Linux' + # via torch + typing-extensions==4.10.0 + # via torch + + ----- stderr ----- + Resolved 30 packages in [TIME] + "### + ); + + Ok(()) +} + /// Perform a universal resolution that requires narrowing the supported Python range in one of the /// fork branches. ///