From fa53de92232dfa19386873b60dbabc775e5a70a8 Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Tue, 23 Apr 2024 14:07:51 -0400 Subject: [PATCH] Avoid Removing Quotes From Requirement Markers (#3214) ## Summary Avoid removing quotes from markers, e.g. `numpy (>=1.19) ; python_version >= "3.7"` should not be rewritten. Fixes https://github.com/astral-sh/uv/issues/2551. This PR also makes fixups a bit more flexible internally for fixes that aren't simple to implement with a pure regex replacement, like this one. https://github.com/astral-sh/uv/pull/1529 fixed a similar problem but the current regex is still not smart enough to avoid all markers completely (like `python_version`). ## Test Plan Added a few unit tests. --- crates/pypi-types/src/lenient_requirement.rs | 92 ++++++++++++++++---- 1 file changed, 74 insertions(+), 18 deletions(-) diff --git a/crates/pypi-types/src/lenient_requirement.rs b/crates/pypi-types/src/lenient_requirement.rs index b5c823037..69f72fb14 100644 --- a/crates/pypi-types/src/lenient_requirement.rs +++ b/crates/pypi-types/src/lenient_requirement.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::str::FromStr; use once_cell::sync::Lazy; @@ -24,46 +25,73 @@ static GREATER_THAN_DEV: Lazy = Lazy::new(|| Regex::new(r">dev").unwrap() /// Ex) `>=9.0.0a1.0` static TRAILING_ZERO: Lazy = Lazy::new(|| Regex::new(r"(\d+(\.\d)*(a|b|rc|post|dev)\d+)\.0").unwrap()); -/// Ex) `>= '2.7'`, `>=3.6'` -static STRAY_QUOTES: Lazy = Lazy::new(|| Regex::new(r#"['"]([*\d])|([*\d])['"]"#).unwrap()); -/// Regex to match the invalid specifier, replacement to fix it and message about was wrong and -/// fixed -static FIXUPS: &[(&Lazy, &str, &str)] = &[ +// Search and replace functions that fix invalid specifiers. +type FixUp = for<'a> fn(&'a str) -> Cow<'a, str>; + +/// A list of fixups with a corresponding message about what was fixed. +static FIXUPS: &[(FixUp, &str)] = &[ // Given `>=7.2.0<8.0.0`, rewrite to `>=7.2.0,<8.0.0`. - (&MISSING_COMMA, r"$1,$2", "inserting missing comma"), + ( + |input| MISSING_COMMA.replace_all(input, r"$1,$2"), + "inserting missing comma", + ), // Given `!=~5.0,>=4.12`, rewrite to `!=5.0.*,>=4.12`. ( - &NOT_EQUAL_TILDE, - r"!=${1}.*", + |input| NOT_EQUAL_TILDE.replace_all(input, r"!=${1}.*"), "replacing invalid tilde with wildcard", ), // Given `>=1.9.*`, rewrite to `>=1.9`. ( - &INVALID_TRAILING_DOT_STAR, - r"${1}${2}", + |input| INVALID_TRAILING_DOT_STAR.replace_all(input, r"${1}${2}"), "removing star after comparison operator other than equal and not equal", ), // Given `!=3.0*`, rewrite to `!=3.0.*`. - (&MISSING_DOT, r"${1}.*", "inserting missing dot"), + ( + |input| MISSING_DOT.replace_all(input, r"${1}.*"), + "inserting missing dot", + ), // Given `>=3.6,`, rewrite to `>=3.6` - (&TRAILING_COMMA, r"${1}", "removing trailing comma"), + ( + |input| TRAILING_COMMA.replace_all(input, r"${1}"), + "removing trailing comma", + ), // Given `>dev`, rewrite to `>0.0.0dev` - (&GREATER_THAN_DEV, r">0.0.0dev", "assuming 0.0.0dev"), + ( + |input| GREATER_THAN_DEV.replace_all(input, r">0.0.0dev"), + "assuming 0.0.0dev", + ), // Given `>=9.0.0a1.0`, rewrite to `>=9.0.0a1` - (&TRAILING_ZERO, r"${1}", "removing trailing zero"), - // Given `>= 2.7'`, rewrite to `>= 2.7` - (&STRAY_QUOTES, r"$1$2", "removing stray quotes"), + ( + |input| TRAILING_ZERO.replace_all(input, r"${1}"), + "removing trailing zero", + ), + (remove_stray_quotes, "removing stray quotes"), ]; +// Given `>= 2.7'`, rewrite to `>= 2.7` +fn remove_stray_quotes(input: &str) -> Cow<'_, str> { + /// Ex) `'>= 2.7'`, `>=3.6'` + static STRAY_QUOTES: Lazy = Lazy::new(|| Regex::new(r#"['"]"#).unwrap()); + + // make sure not to touch markers, which can have quotes (e.g. `python_version >= '3.7'`) + match input.find(';') { + Some(markers) => { + let requirement = STRAY_QUOTES.replace_all(&input[..markers], ""); + format!("{}{}", requirement, &input[markers..]).into() + } + None => STRAY_QUOTES.replace_all(input, ""), + } +} + fn parse_with_fixups>(input: &str, type_name: &str) -> Result { match T::from_str(input) { Ok(requirement) => Ok(requirement), Err(err) => { let mut patched_input = input.to_string(); let mut messages = Vec::new(); - for (matcher, replacement, message) in FIXUPS { - let patched = matcher.replace_all(patched_input.as_ref(), *replacement); + for (fixup, message) in FIXUPS { + let patched = fixup(patched_input.as_ref()); if patched != patched_input { messages.push(*message); @@ -362,4 +390,32 @@ mod tests { let expected: VersionSpecifiers = VersionSpecifiers::from_str(">=9a1").unwrap(); assert_eq!(actual, expected); } + + /// + #[test] + fn stray_quote_preserve_marker() { + let actual: Requirement = + LenientRequirement::from_str("numpy >=1.19; python_version >= \"3.7\"") + .unwrap() + .into(); + let expected: Requirement = + Requirement::from_str("numpy >=1.19; python_version >= \"3.7\"").unwrap(); + assert_eq!(actual, expected); + + let actual: Requirement = + LenientRequirement::from_str("numpy \">=1.19\"; python_version >= \"3.7\"") + .unwrap() + .into(); + let expected: Requirement = + Requirement::from_str("numpy >=1.19; python_version >= \"3.7\"").unwrap(); + assert_eq!(actual, expected); + + let actual: Requirement = + LenientRequirement::from_str("'numpy' >=1.19\"; python_version >= \"3.7\"") + .unwrap() + .into(); + let expected: Requirement = + Requirement::from_str("numpy >=1.19; python_version >= \"3.7\"").unwrap(); + assert_eq!(actual, expected); + } }