diff --git a/crates/pep440-rs/src/version_specifier.rs b/crates/pep440-rs/src/version_specifier.rs index 4d6f94886..87d22eb67 100644 --- a/crates/pep440-rs/src/version_specifier.rs +++ b/crates/pep440-rs/src/version_specifier.rs @@ -1,7 +1,9 @@ #[cfg(feature = "pyo3")] use std::hash::{Hash, Hasher}; + +use std::cmp::Ordering; use std::ops::Bound; -use std::{cmp::Ordering, str::FromStr}; +use std::str::FromStr; #[cfg(feature = "pyo3")] use pyo3::{ diff --git a/crates/pep508-rs/src/marker.rs b/crates/pep508-rs/src/marker.rs index 379eff709..8ffa28015 100644 --- a/crates/pep508-rs/src/marker.rs +++ b/crates/pep508-rs/src/marker.rs @@ -1881,30 +1881,6 @@ impl MarkerTree { exprs.push(tree); } } - - /// Normalizes this marker tree such that all conjunctions and disjunctions - /// are sorted. - /// - /// This is useful in cases where creating conjunctions or disjunctions - /// might occur in a non-deterministic order. This routine will erase the - /// distinction created by such a construction. - pub fn normalize(&mut self) { - match *self { - MarkerTree::Expression(_) => {} - MarkerTree::And(ref mut trees) | MarkerTree::Or(ref mut trees) => { - // This is kind of cheesy, because we're doing a recursive call - // followed by a sort, and that sort is also recursive (due to - // the corresponding Ord impl being recursive). - // - // We should consider refactoring `MarkerTree` to a "smart - // constructor" design that normalizes them by construction. - for tree in &mut *trees { - tree.normalize(); - } - trees.sort(); - } - } - } } impl Display for MarkerTree { diff --git a/crates/uv-resolver/src/lock.rs b/crates/uv-resolver/src/lock.rs index ebfa6398c..e11d3303c 100644 --- a/crates/uv-resolver/src/lock.rs +++ b/crates/uv-resolver/src/lock.rs @@ -511,7 +511,7 @@ impl Distribution { // Markers can be combined in an unpredictable order, so normalize them // such that the lock file output is consistent and deterministic. if let Some(ref mut marker) = marker { - marker.normalize(); + crate::marker::normalize(marker); } let sdist = SourceDist::from_annotated_dist(annotated_dist)?; let wheels = Wheel::from_annotated_dist(annotated_dist)?; diff --git a/crates/uv-resolver/src/marker.rs b/crates/uv-resolver/src/marker.rs index 27c1fa74b..94fe1db2f 100644 --- a/crates/uv-resolver/src/marker.rs +++ b/crates/uv-resolver/src/marker.rs @@ -1,5 +1,7 @@ #![allow(clippy::enum_glob_use)] +use std::collections::HashMap; +use std::mem; use std::ops::Bound::{self, *}; use std::ops::RangeBounds; @@ -10,6 +12,7 @@ use pep508_rs::{ }; use crate::pubgrub::PubGrubSpecifier; +use pubgrub::range::Range as PubGrubRange; /// Returns `true` if there is no environment in which both marker trees can both apply, i.e. /// the expression `first and second` is always false. @@ -79,6 +82,111 @@ fn string_is_disjoint(this: &MarkerExpression, other: &MarkerExpression) -> bool true } +/// Normalizes this marker tree. +/// +/// This function does a number of operations to normalize a marker tree recursively: +/// - Sort all nested expressions. +/// - Simplify expressions. This includes combining overlapping version ranges and removing duplicate +/// expressions at the same level of precedence. For example, `(a == 'a' and a == 'a') or b == 'b'` can +/// be reduced, but `a == 'a' and (a == 'a' or b == 'b')` cannot. +/// - Normalize the order of version expressions to the form ` ` +/// (i.e. not the reverse). +/// +/// This is useful in cases where creating conjunctions or disjunctions might occur in a non-deterministic +/// order. This routine will attempt to erase the distinction created by such a construction. +pub(crate) fn normalize(tree: &mut MarkerTree) { + match tree { + MarkerTree::And(trees) | MarkerTree::Or(trees) => { + let mut reduced = Vec::new(); + let mut versions: HashMap<_, Vec<_>> = HashMap::new(); + + for mut tree in mem::take(trees) { + // Simplify nested expressions as much as possible first. + normalize(&mut tree); + + // Extract expressions we may be able to simplify more. + if let MarkerTree::Expression(ref expr) = tree { + if let Some((key, range)) = keyed_range(expr) { + versions.entry(key.clone()).or_default().push(range); + continue; + } + } + + reduced.push(tree); + } + + match tree { + MarkerTree::And(_) => { + simplify_ranges(&mut reduced, versions, |ranges| { + ranges + .iter() + .fold(PubGrubRange::full(), |acc, range| acc.intersection(range)) + }); + + reduced.dedup(); + reduced.sort(); + + *tree = match reduced.len() { + 1 => reduced.remove(0), + _ => MarkerTree::And(reduced), + }; + } + MarkerTree::Or(_) => { + simplify_ranges(&mut reduced, versions, |ranges| { + ranges + .iter() + .fold(PubGrubRange::empty(), |acc, range| acc.union(range)) + }); + + reduced.dedup(); + reduced.sort(); + + *tree = match reduced.len() { + 1 => reduced.remove(0), + _ => MarkerTree::Or(reduced), + }; + } + MarkerTree::Expression(_) => unreachable!(), + } + } + MarkerTree::Expression(_) => {} + } +} + +// Simplify version expressions. +fn simplify_ranges( + reduced: &mut Vec, + versions: HashMap>>, + combine: impl Fn(&Vec>) -> PubGrubRange, +) { + for (key, ranges) in versions { + let simplified = combine(&ranges); + + // If this is a meaningless expressions with no valid intersection, add back + // the original ranges. + if simplified.is_empty() { + for specifier in ranges + .iter() + .flat_map(PubGrubRange::iter) + .flat_map(VersionSpecifier::from_bounds) + { + reduced.push(MarkerTree::Expression(MarkerExpression::Version { + specifier, + key: key.clone(), + })); + } + } + + // Add back the simplified segments. + for specifier in simplified.iter().flat_map(VersionSpecifier::from_bounds) { + reduced.push(MarkerTree::Expression(MarkerExpression::Version { + key: key.clone(), + specifier, + })); + } + } +} + /// Extracts the key, value, and string from a string expression, reversing the operator if necessary. fn extract_string_expression( expr: &MarkerExpression, @@ -145,12 +253,12 @@ fn extra_is_disjoint(operator: &ExtraOperator, name: &ExtraName, other: &MarkerE /// Returns `true` if this version expression does not intersect with the given expression. fn version_is_disjoint(this: &MarkerExpression, other: &MarkerExpression) -> bool { - let Some((key, range)) = keyed_range(this).unwrap() else { + let Some((key, range)) = keyed_range(this) else { return false; }; // if this is not a version expression it may intersect - let Ok(Some((key2, range2))) = keyed_range(other) else { + let Some((key2, range2)) = keyed_range(other) else { return false; }; @@ -164,9 +272,7 @@ fn version_is_disjoint(this: &MarkerExpression, other: &MarkerExpression) -> boo } /// Returns the key and version range for a version expression. -fn keyed_range( - expr: &MarkerExpression, -) -> Result)>, ()> { +fn keyed_range(expr: &MarkerExpression) -> Option<(&MarkerValueVersion, PubGrubRange)> { let (key, specifier) = match expr { MarkerExpression::Version { key, specifier } => (key, specifier.clone()), MarkerExpression::VersionInverted { @@ -178,19 +284,19 @@ fn keyed_range( // a version specifier let operator = reverse_operator(*operator); let Ok(specifier) = VersionSpecifier::from_version(operator, version.clone()) else { - return Ok(None); + return None; }; (key, specifier) } - _ => return Err(()), + _ => return None, }; let Ok(pubgrub_specifier) = PubGrubSpecifier::try_from(&specifier) else { - return Ok(None); + return None; }; - Ok(Some((key, pubgrub_specifier.into()))) + Some((key, pubgrub_specifier.into())) } /// Reverses a binary operator. @@ -223,14 +329,85 @@ mod tests { use super::*; - fn is_disjoint(one: impl AsRef, two: impl AsRef) -> bool { - let one = MarkerTree::parse_reporter(one.as_ref(), &mut TracingReporter).unwrap(); - let two = MarkerTree::parse_reporter(two.as_ref(), &mut TracingReporter).unwrap(); - super::is_disjoint(&one, &two) && super::is_disjoint(&two, &one) + #[test] + fn simplify() { + assert_marker_equal( + "python_version == '3.1' or python_version == '3.1'", + "python_version == '3.1'", + ); + + assert_marker_equal( + "python_version < '3.17' or python_version < '3.18'", + "python_version < '3.18'", + ); + + assert_marker_equal( + "python_version > '3.17' or python_version > '3.18' or python_version > '3.12'", + "python_version > '3.12'", + ); + + // a quirk of how pubgrub works, but this is considered part of normalization + assert_marker_equal( + "python_version > '3.17.post4' or python_version > '3.18.post4'", + "python_version >= '3.17.post5'", + ); + + assert_marker_equal( + "python_version < '3.17' and python_version < '3.18'", + "python_version < '3.17'", + ); + + assert_marker_equal( + "python_version <= '3.18' and python_version == '3.18'", + "python_version == '3.18'", + ); + + assert_marker_equal( + "python_version <= '3.18' or python_version == '3.18'", + "python_version <= '3.18'", + ); + + assert_marker_equal( + "python_version <= '3.15' or (python_version <= '3.17' and python_version < '3.16')", + "python_version < '3.16'", + ); + + assert_marker_equal( + "(python_version > '3.17' or python_version > '3.16') and python_version > '3.15'", + "python_version > '3.16'", + ); + + assert_marker_equal( + "(python_version > '3.17' or python_version > '3.16') and python_version > '3.15' and implementation_version == '1'", + "implementation_version == '1' and python_version > '3.16'", + ); + + assert_marker_equal( + "('3.17' < python_version or '3.16' < python_version) and '3.15' < python_version and implementation_version == '1'", + "implementation_version == '1' and python_version > '3.16'", + ); + + assert_marker_equal("extra == 'a' or extra == 'a'", "extra == 'a'"); + assert_marker_equal( + "extra == 'a' and extra == 'a' or extra == 'b'", + "extra == 'a' or extra == 'b'", + ); + + // bogus expressions are retained but still normalized + assert_marker_equal( + "python_version < '3.17' and '3.18' == python_version", + "python_version == '3.18' and python_version < '3.17'", + ); + + // cannot simplify nested complex expressions + assert_marker_equal( + "extra == 'a' and (extra == 'a' or extra == 'b')", + "extra == 'a' and (extra == 'a' or extra == 'b')", + ); } #[test] - fn extra() { + fn extra_disjointness() { assert!(!is_disjoint("extra == 'a'", "python_version == '1'")); assert!(!is_disjoint("extra == 'a'", "extra == 'a'")); @@ -243,7 +420,7 @@ mod tests { } #[test] - fn arbitrary() { + fn arbitrary_disjointness() { assert!(is_disjoint( "python_version == 'Linux'", "python_version == '3.7.1'" @@ -251,13 +428,13 @@ mod tests { } #[test] - fn version() { + fn version_disjointness() { assert!(!is_disjoint( "os_name == 'Linux'", "python_version == '3.7.1'" )); - test_version_bounds("python_version"); + test_version_bounds_disjointness("python_version"); assert!(!is_disjoint( "python_version == '3.7.*'", @@ -266,7 +443,7 @@ mod tests { } #[test] - fn string() { + fn string_disjointness() { assert!(!is_disjoint( "os_name == 'Linux'", "platform_version == '3.7.1'" @@ -277,7 +454,7 @@ mod tests { )); // basic version bounds checking should still work with lexicographical comparisons - test_version_bounds("platform_version"); + test_version_bounds_disjointness("platform_version"); assert!(is_disjoint("os_name == 'Linux'", "os_name == 'OSX'")); assert!(is_disjoint("os_name <= 'Linux'", "os_name == 'OSX'")); @@ -303,7 +480,7 @@ mod tests { } #[test] - fn combined() { + fn combined_disjointness() { assert!(!is_disjoint( "os_name == 'a' and platform_version == '1'", "os_name == 'a'" @@ -327,7 +504,7 @@ mod tests { )); } - fn test_version_bounds(version: &str) { + fn test_version_bounds_disjointness(version: &str) { assert!(!is_disjoint( format!("{version} > '2.7.0'"), format!("{version} == '3.6.0'") @@ -372,4 +549,17 @@ mod tests { format!("{version} != '3.7.0'") )); } + + fn is_disjoint(one: impl AsRef, two: impl AsRef) -> bool { + let one = MarkerTree::parse_reporter(one.as_ref(), &mut TracingReporter).unwrap(); + let two = MarkerTree::parse_reporter(two.as_ref(), &mut TracingReporter).unwrap(); + super::is_disjoint(&one, &two) && super::is_disjoint(&two, &one) + } + + fn assert_marker_equal(one: impl AsRef, two: impl AsRef) { + let mut tree1 = MarkerTree::parse_reporter(one.as_ref(), &mut TracingReporter).unwrap(); + super::normalize(&mut tree1); + let tree2 = MarkerTree::parse_reporter(two.as_ref(), &mut TracingReporter).unwrap(); + assert_eq!(tree1.to_string(), tree2.to_string()); + } } diff --git a/crates/uv-resolver/src/pubgrub/specifier.rs b/crates/uv-resolver/src/pubgrub/specifier.rs index 52bbaa705..3a5883a92 100644 --- a/crates/uv-resolver/src/pubgrub/specifier.rs +++ b/crates/uv-resolver/src/pubgrub/specifier.rs @@ -22,6 +22,12 @@ impl PubGrubSpecifier { } } +impl From> for PubGrubSpecifier { + fn from(range: Range) -> Self { + PubGrubSpecifier(range) + } +} + impl From for Range { /// Convert a PubGrub specifier to a range of versions. fn from(specifier: PubGrubSpecifier) -> Self { diff --git a/crates/uv/tests/lock.rs b/crates/uv/tests/lock.rs index 47ea2e848..49a1e74b2 100644 --- a/crates/uv/tests/lock.rs +++ b/crates/uv/tests/lock.rs @@ -788,7 +788,7 @@ fn lock_dependency_extra() -> Result<()> { name = "importlib-metadata" version = "7.1.0" source = "registry+https://pypi.org/simple" - marker = "python_version < '3.8' or python_version < '3.10'" + marker = "python_version < '3.10'" sdist = { url = "https://files.pythonhosted.org/packages/a0/fc/c4e6078d21fc4fa56300a241b87eae76766aa380a23fc450fc85bb7bf547/importlib_metadata-7.1.0.tar.gz", hash = "sha256:b78938b926ee8d5f020fc4772d487045805a55ddbad2ecf21c6d60938dc7fcd2", size = 52120 } wheels = [{ url = "https://files.pythonhosted.org/packages/2d/0a/679461c511447ffaf176567d5c496d1de27cbe34a87df6677d7171b2fbd4/importlib_metadata-7.1.0-py3-none-any.whl", hash = "sha256:30962b96c0c223483ed6cc7280e7f0199feb01a0e40cfae4d4450fc6fab1f570", size = 24409 }] @@ -1473,7 +1473,7 @@ fn lock_requires_python() -> Result<()> { name = "typing-extensions" version = "4.7.1" source = "registry+https://pypi.org/simple" - marker = "python_version < '3.8' or python_version < '3.11'" + marker = "python_version < '3.11'" sdist = { url = "https://files.pythonhosted.org/packages/3c/8b/0111dd7d6c1478bf83baa1cab85c686426c7a6274119aceb2bd9d35395ad/typing_extensions-4.7.1.tar.gz", hash = "sha256:b75ddc264f0ba5615db7ba217daeb99701ad295353c45f9e95963337ceeeffb2", size = 72876 } wheels = [{ url = "https://files.pythonhosted.org/packages/ec/6b/63cc3df74987c36fe26157ee12e09e8f9db4de771e0f3404263117e75b95/typing_extensions-4.7.1-py3-none-any.whl", hash = "sha256:440d5dd3af93b060174bf433bccd69b0babc3b15b1a8dca43789fd7f61514b36", size = 33232 }] @@ -1625,7 +1625,7 @@ fn lock_requires_python() -> Result<()> { name = "typing-extensions" version = "4.7.1" source = "registry+https://pypi.org/simple" - marker = "python_version < '3.8' or python_version < '3.11'" + marker = "python_version < '3.11'" sdist = { url = "https://files.pythonhosted.org/packages/3c/8b/0111dd7d6c1478bf83baa1cab85c686426c7a6274119aceb2bd9d35395ad/typing_extensions-4.7.1.tar.gz", hash = "sha256:b75ddc264f0ba5615db7ba217daeb99701ad295353c45f9e95963337ceeeffb2", size = 72876 } wheels = [{ url = "https://files.pythonhosted.org/packages/ec/6b/63cc3df74987c36fe26157ee12e09e8f9db4de771e0f3404263117e75b95/typing_extensions-4.7.1-py3-none-any.whl", hash = "sha256:440d5dd3af93b060174bf433bccd69b0babc3b15b1a8dca43789fd7f61514b36", size = 33232 }] @@ -1782,7 +1782,7 @@ fn lock_requires_python() -> Result<()> { name = "typing-extensions" version = "4.10.0" source = "registry+https://pypi.org/simple" - marker = "python_version < '3.8' or python_version < '3.11'" + marker = "python_version < '3.11'" sdist = { url = "https://files.pythonhosted.org/packages/16/3a/0d26ce356c7465a19c9ea8814b960f8a36c3b0d07c323176620b7b483e44/typing_extensions-4.10.0.tar.gz", hash = "sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb", size = 77558 } wheels = [{ url = "https://files.pythonhosted.org/packages/f9/de/dc04a3ea60b22624b51c703a84bbe0184abcd1d0b9bc8074b5d6b7ab90bb/typing_extensions-4.10.0-py3-none-any.whl", hash = "sha256:69b1a937c3a517342112fb4c6df7e72fc39a38e7891a5730ed4985b5214b5475", size = 33926 }]