From a962a65302f944b406e20f845f741fa7cc2fa6f2 Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Wed, 15 May 2024 13:01:05 -0400 Subject: [PATCH] Add routine for verifying two marker trees are disjoint (#3583) ## Summary Implements https://github.com/astral-sh/uv/issues/3355. --- crates/pep508-rs/src/marker.rs | 14 +- crates/uv-resolver/src/lib.rs | 1 + crates/uv-resolver/src/marker.rs | 372 +++++++++++++++++++++++++++++++ 3 files changed, 380 insertions(+), 7 deletions(-) create mode 100644 crates/uv-resolver/src/marker.rs diff --git a/crates/pep508-rs/src/marker.rs b/crates/pep508-rs/src/marker.rs index fb6d5e65e..1a28d7ea8 100644 --- a/crates/pep508-rs/src/marker.rs +++ b/crates/pep508-rs/src/marker.rs @@ -214,7 +214,7 @@ impl Display for MarkerValue { } /// How to compare key and value, such as by `==`, `>` or `not in` -#[derive(Clone, Debug, Eq, Hash, PartialEq)] +#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)] pub enum MarkerOperator { /// `==` Equal, @@ -238,7 +238,7 @@ pub enum MarkerOperator { impl MarkerOperator { /// Compare two versions, returning None for `in` and `not in` - fn to_pep440_operator(&self) -> Option { + fn to_pep440_operator(self) -> Option { match self { Self::Equal => Some(pep440_rs::Operator::Equal), Self::NotEqual => Some(pep440_rs::Operator::NotEqual), @@ -1017,7 +1017,7 @@ impl MarkerExpression { ); }; - match MarkerExpression::version(key.clone(), operator.clone(), &value, reporter) { + match MarkerExpression::version(key.clone(), operator, &value, reporter) { Some(expr) => expr, None => MarkerExpression::arbitrary( MarkerValue::MarkerEnvVersion(key), @@ -1071,7 +1071,7 @@ impl MarkerExpression { MarkerValue::QuotedString(value) => value, }; - match MarkerExpression::extra(operator.clone(), &value, reporter) { + match MarkerExpression::extra(operator, &value, reporter) { Some(expr) => expr, None => MarkerExpression::arbitrary( MarkerValue::Extra, @@ -1087,7 +1087,7 @@ impl MarkerExpression { MarkerValue::MarkerEnvVersion(key) => { match MarkerExpression::version_inverted( &l_string, - operator.clone(), + operator, key.clone(), reporter, ) { @@ -1107,7 +1107,7 @@ impl MarkerExpression { }, // `'...' == extra` MarkerValue::Extra => { - match MarkerExpression::extra(operator.clone(), &l_string, reporter) { + match MarkerExpression::extra(operator, &l_string, reporter) { Some(expr) => expr, None => MarkerExpression::arbitrary( MarkerValue::QuotedString(l_string), @@ -1267,7 +1267,7 @@ impl MarkerExpression { } }; - match ExtraOperator::from_marker_operator(operator.clone()) { + match ExtraOperator::from_marker_operator(operator) { Some(operator) => Some(MarkerExpression::Extra { operator, name }), None => { reporter.report( diff --git a/crates/uv-resolver/src/lib.rs b/crates/uv-resolver/src/lib.rs index 2e6127f8a..4aca4fb67 100644 --- a/crates/uv-resolver/src/lib.rs +++ b/crates/uv-resolver/src/lib.rs @@ -31,6 +31,7 @@ mod exclusions; mod flat_index; mod lock; mod manifest; +mod marker; mod options; mod pins; mod preferences; diff --git a/crates/uv-resolver/src/marker.rs b/crates/uv-resolver/src/marker.rs new file mode 100644 index 000000000..58cade99d --- /dev/null +++ b/crates/uv-resolver/src/marker.rs @@ -0,0 +1,372 @@ +#![allow(clippy::enum_glob_use)] + +use std::ops::Bound::{self, *}; +use std::ops::RangeBounds; + +use pep440_rs::{Operator, Version, VersionSpecifier}; +use pep508_rs::{ + ExtraName, ExtraOperator, MarkerExpression, MarkerOperator, MarkerTree, MarkerValueString, + MarkerValueVersion, +}; + +use crate::pubgrub::PubGrubSpecifier; + +/// Returns `true` if there is no environment in which both marker trees can both apply, i.e. +/// the expression `first and second` is always false. +#[allow(dead_code)] +pub(crate) fn is_disjoint(first: &MarkerTree, second: &MarkerTree) -> bool { + let (expr1, expr2) = match (first, second) { + (MarkerTree::Expression(expr1), MarkerTree::Expression(expr2)) => (expr1, expr2), + // `And` expressions are disjoint if any clause is disjoint. + (other, MarkerTree::And(exprs)) | (MarkerTree::And(exprs), other) => { + return exprs.iter().any(|tree1| is_disjoint(tree1, other)) + } + // `Or` expressions are disjoint if all clauses are disjoint. + (other, MarkerTree::Or(exprs)) | (MarkerTree::Or(exprs), other) => { + return exprs.iter().all(|tree1| is_disjoint(tree1, other)) + } + }; + + match expr1 { + MarkerExpression::Version { .. } | MarkerExpression::VersionInverted { .. } => { + version_is_disjoint(expr1, expr2) + } + MarkerExpression::String { .. } | MarkerExpression::StringInverted { .. } => { + string_is_disjoint(expr1, expr2) + } + MarkerExpression::Extra { operator, name } => extra_is_disjoint(operator, name, expr2), + MarkerExpression::Arbitrary { .. } => false, + } +} + +/// Returns `true` if this string expression does not intersect with the given expression. +fn string_is_disjoint(this: &MarkerExpression, other: &MarkerExpression) -> bool { + use MarkerOperator::*; + + let (key, operator, value) = extract_string_expression(this).unwrap(); + let Some((key2, operator2, value2)) = extract_string_expression(other) else { + return false; + }; + + // distinct string expressions are not disjoint + if key != key2 { + return false; + } + + match (operator, operator2) { + // the only disjoint expressions involving strict inequality are `key != value` and `key == value` + (NotEqual, Equal) | (Equal, NotEqual) => return value == value2, + (NotEqual, _) | (_, NotEqual) => return false, + // similarly for `in` and `not in` + (In, NotIn) | (NotIn, In) => return value == value2, + (In | NotIn, _) | (_, In | NotIn) => return false, + _ => {} + } + + let bounds = string_bounds(value, operator); + let bounds2 = string_bounds(value2, operator2); + + // make sure the ranges do not intersection + if range_exists::<&str>(&bounds2.start_bound(), &bounds.end_bound()) + && range_exists::<&str>(&bounds.start_bound(), &bounds2.end_bound()) + { + return false; + } + + true +} + +/// Extracts the key, value, and string from a string expression, reversing the operator if necessary. +fn extract_string_expression( + expr: &MarkerExpression, +) -> Option<(&MarkerValueString, MarkerOperator, &str)> { + match expr { + MarkerExpression::String { + key, + operator, + value, + } => Some((key, *operator, value)), + MarkerExpression::StringInverted { + value, + operator, + key, + } => { + // if the expression was inverted, we have to reverse the operator + Some((key, reverse_marker_operator(*operator), value)) + } + _ => None, + } +} + +/// Returns `true` if the range formed by an upper and lower bound is non-empty. +fn range_exists(lower: &Bound, upper: &Bound) -> bool { + match (lower, upper) { + (Included(s), Included(e)) => s <= e, + (Included(s), Excluded(e)) => s < e, + (Excluded(s), Included(e)) => s < e, + (Excluded(s), Excluded(e)) => s < e, + (Unbounded, _) | (_, Unbounded) => true, + } +} + +/// Returns the lower and upper bounds of a string inequality. +/// +/// Panics if called on the `!=`, `in`, or `not in` operators. +fn string_bounds(value: &str, operator: MarkerOperator) -> (Bound<&str>, Bound<&str>) { + use MarkerOperator::*; + match operator { + Equal => (Included(value), Included(value)), + // TODO: not really sure what this means for strings + TildeEqual => (Included(value), Included(value)), + GreaterThan => (Excluded(value), Unbounded), + GreaterEqual => (Included(value), Unbounded), + LessThan => (Unbounded, Excluded(value)), + LessEqual => (Unbounded, Included(value)), + NotEqual | In | NotIn => unreachable!(), + } +} + +/// Returns `true` if this extra expression does not intersect with the given expression. +fn extra_is_disjoint(operator: &ExtraOperator, name: &ExtraName, other: &MarkerExpression) -> bool { + let MarkerExpression::Extra { + operator: operator2, + name: name2, + } = other + else { + return false; + }; + + // extra expressions are only disjoint if they require existence and non-existence of the same extra + operator != operator2 && name == name2 +} + +/// Returns `true` if this version expression does not intersect with the given expression. +fn version_is_disjoint(this: &MarkerExpression, other: &MarkerExpression) -> bool { + let Some((key, range)) = keyed_range(this).unwrap() else { + return false; + }; + + // if this is not a version expression it may intersect + let Ok(Some((key2, range2))) = keyed_range(other) else { + return false; + }; + + // distinct version expressions are not disjoint + if key != key2 { + return false; + } + + // there is no version that is contained in both ranges + range.is_disjoint(&range2) +} + +/// Returns the key and version range for a version expression. +fn keyed_range( + expr: &MarkerExpression, +) -> Result)>, ()> { + let (key, specifier) = match expr { + MarkerExpression::Version { key, specifier } => (key, specifier.clone()), + MarkerExpression::VersionInverted { + version, + operator, + key, + } => { + // if the expression was inverted, we have to reverse the operator before constructing + // a version specifier + let operator = reverse_operator(*operator); + let Ok(specifier) = VersionSpecifier::from_version(operator, version.clone()) else { + return Ok(None); + }; + + (key, specifier) + } + _ => return Err(()), + }; + + let Ok(pubgrub_specifier) = PubGrubSpecifier::try_from(&specifier) else { + return Ok(None); + }; + + Ok(Some((key, pubgrub_specifier.into()))) +} + +/// Reverses a binary operator. +fn reverse_operator(operator: Operator) -> Operator { + use Operator::*; + match operator { + LessThan => GreaterThan, + LessThanEqual => GreaterThanEqual, + GreaterThan => LessThan, + GreaterThanEqual => LessThanEqual, + _ => operator, + } +} + +/// Reverses a marker operator. +fn reverse_marker_operator(operator: MarkerOperator) -> MarkerOperator { + use MarkerOperator::*; + match operator { + LessThan => GreaterThan, + LessEqual => GreaterEqual, + GreaterThan => LessThan, + GreaterEqual => LessEqual, + _ => operator, + } +} + +#[cfg(test)] +mod tests { + use pep508_rs::TracingReporter; + + use super::*; + + fn is_disjoint(one: impl AsRef, two: impl AsRef) -> bool { + let one = MarkerTree::parse_reporter(one.as_ref(), &mut TracingReporter).unwrap(); + let two = MarkerTree::parse_reporter(two.as_ref(), &mut TracingReporter).unwrap(); + super::is_disjoint(&one, &two) && super::is_disjoint(&two, &one) + } + + #[test] + fn extra() { + assert!(!is_disjoint("extra == 'a'", "python_version == '1'")); + + assert!(!is_disjoint("extra == 'a'", "extra == 'a'")); + assert!(!is_disjoint("extra == 'a'", "extra == 'b'")); + assert!(!is_disjoint("extra == 'b'", "extra == 'a'")); + assert!(!is_disjoint("extra == 'b'", "extra != 'a'")); + assert!(!is_disjoint("extra != 'b'", "extra == 'a'")); + assert!(is_disjoint("extra != 'b'", "extra == 'b'")); + assert!(is_disjoint("extra == 'b'", "extra != 'b'")); + } + + #[test] + fn invalid() { + assert!(!is_disjoint( + "python_version == 'Linux'", + "python_version == '3.7.1'" + )); + } + + #[test] + fn version() { + assert!(!is_disjoint( + "os_name == 'Linux'", + "python_version == '3.7.1'" + )); + + test_version_bounds("python_version"); + + assert!(!is_disjoint( + "python_version == '3.7.*'", + "python_version == '3.7.1'" + )); + } + + #[test] + fn string() { + assert!(!is_disjoint( + "os_name == 'Linux'", + "platform_version == '3.7.1'" + )); + assert!(!is_disjoint( + "implementation_version == '3.7.0'", + "python_version == '3.7.1'" + )); + + // basic version bounds checking should still work with lexicographical comparisons + test_version_bounds("platform_version"); + + assert!(is_disjoint("os_name == 'Linux'", "os_name == 'OSX'")); + assert!(is_disjoint("os_name <= 'Linux'", "os_name == 'OSX'")); + + assert!(!is_disjoint( + "os_name in 'OSXLinuxWindows'", + "os_name == 'OSX'" + )); + assert!(!is_disjoint("'OSX' in os_name", "'Linux' in os_name")); + + // complicated `in` intersections are not supported + assert!(!is_disjoint("os_name in 'OSX'", "os_name in 'Linux'")); + assert!(!is_disjoint( + "os_name in 'OSXLinux'", + "os_name == 'Windows'" + )); + + assert!(is_disjoint( + "os_name in 'Windows'", + "os_name not in 'Windows'" + )); + assert!(is_disjoint("'Linux' in os_name", "os_name not in 'Linux'")); + } + + #[test] + fn combined() { + assert!(!is_disjoint( + "os_name == 'a' and platform_version == '1'", + "os_name == 'a'" + )); + assert!(!is_disjoint( + "os_name == 'a' or platform_version == '1'", + "os_name == 'a'" + )); + + assert!(is_disjoint( + "os_name == 'a' and platform_version == '1'", + "os_name == 'a' and platform_version == '2'" + )); + assert!(is_disjoint( + "os_name == 'a' and platform_version == '1'", + "'2' == platform_version and os_name == 'a'" + )); + assert!(!is_disjoint( + "os_name == 'a' or platform_version == '1'", + "os_name == 'a' or platform_version == '2'" + )); + } + + fn test_version_bounds(version: &str) { + assert!(!is_disjoint( + format!("{version} > '2.7.0'"), + format!("{version} == '3.6.0'") + )); + assert!(!is_disjoint( + format!("{version} >= '3.7.0'"), + format!("{version} == '3.7.1'") + )); + assert!(!is_disjoint( + format!("{version} >= '3.7.0'"), + format!("'3.7.1' == {version}") + )); + + assert!(is_disjoint( + format!("{version} >= '3.7.1'"), + format!("{version} == '3.7.0'") + )); + assert!(is_disjoint( + format!("'3.7.1' <= {version}"), + format!("{version} == '3.7.0'") + )); + + assert!(is_disjoint( + format!("{version} < '3.7.0'"), + format!("{version} == '3.7.0'") + )); + assert!(is_disjoint( + format!("'3.7.0' > {version}"), + format!("{version} == '3.7.0'") + )); + assert!(is_disjoint( + format!("{version} < '3.7.0'"), + format!("{version} == '3.7.1'") + )); + + assert!(is_disjoint( + format!("{version} == '3.7.0'"), + format!("{version} == '3.7.1'") + )); + assert!(is_disjoint( + format!("{version} == '3.7.0'"), + format!("{version} != '3.7.0'") + )); + } +}