Support extras in `@` requests for tools (#11335)

## Summary

Closes https://github.com/astral-sh/uv/issues/11321.
This commit is contained in:
Charlie Marsh 2025-02-07 21:07:15 -05:00 committed by GitHub
parent 25e7209a33
commit 12e7abe093
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 185 additions and 29 deletions

View File

@ -133,14 +133,15 @@ pub(crate) async fn install(
.unwrap()
}
// Ex) `ruff@0.6.0`
Target::Version(name, ref version) | Target::FromVersion(_, name, ref version) => {
Target::Version(name, ref extras, ref version)
| Target::FromVersion(_, name, ref extras, ref version) => {
if editable {
bail!("`--editable` is only supported for local packages");
}
Requirement {
name: PackageName::from_str(name)?,
extras: vec![],
extras: extras.clone(),
groups: vec![],
marker: MarkerTree::default(),
source: RequirementSource::Registry {
@ -154,14 +155,14 @@ pub(crate) async fn install(
}
}
// Ex) `ruff@latest`
Target::Latest(name) | Target::FromLatest(_, name) => {
Target::Latest(name, ref extras) | Target::FromLatest(_, name, ref extras) => {
if editable {
bail!("`--editable` is only supported for local packages");
}
Requirement {
name: PackageName::from_str(name)?,
extras: vec![],
extras: extras.clone(),
groups: vec![],
marker: MarkerTree::default(),
source: RequirementSource::Registry {

View File

@ -2,7 +2,7 @@ use std::str::FromStr;
use tracing::debug;
use uv_normalize::PackageName;
use uv_normalize::{ExtraName, PackageName};
use uv_pep440::Version;
mod common;
@ -14,20 +14,20 @@ pub(crate) mod uninstall;
pub(crate) mod update_shell;
pub(crate) mod upgrade;
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum Target<'a> {
/// e.g., `ruff`
Unspecified(&'a str),
/// e.g., `ruff@0.6.0`
Version(&'a str, Version),
Version(&'a str, Vec<ExtraName>, Version),
/// e.g., `ruff@latest`
Latest(&'a str),
Latest(&'a str, Vec<ExtraName>),
/// e.g., `ruff --from ruff>=0.6.0`
From(&'a str, &'a str),
/// e.g., `ruff --from ruff@0.6.0`
FromVersion(&'a str, &'a str, Version),
FromVersion(&'a str, &'a str, Vec<ExtraName>, Version),
/// e.g., `ruff --from ruff@latest`
FromLatest(&'a str, &'a str),
FromLatest(&'a str, &'a str, Vec<ExtraName>),
}
impl<'a> Target<'a> {
@ -45,19 +45,44 @@ impl<'a> Target<'a> {
return Self::From(target, from);
}
// Split into name and extras (e.g., `flask[dotenv]`).
let (name, extras) = match name.split_once('[') {
Some((name, extras)) => {
let Some(extras) = extras.strip_suffix(']') else {
// e.g., ignore `flask[dotenv`.
debug!("Ignoring invalid extras in `--from`");
return Self::From(target, from);
};
(name, extras)
}
None => (name, ""),
};
// e.g., ignore `git+https://github.com/astral-sh/ruff.git@main`
if PackageName::from_str(name).is_err() {
debug!("Ignoring non-package name `{name}` in `--from`");
return Self::From(target, from);
}
// e.g., ignore `ruff[1.0.0]` or any other invalid extra.
let Ok(extras) = extras
.split(',')
.map(str::trim)
.filter(|extra| !extra.is_empty())
.map(ExtraName::from_str)
.collect::<Result<Vec<_>, _>>()
else {
debug!("Ignoring invalid extras `{extras}` in `--from`");
return Self::From(target, from);
};
match version {
// e.g., `ruff@latest`
"latest" => return Self::FromLatest(target, name),
"latest" => return Self::FromLatest(target, name, extras),
// e.g., `ruff@0.6.0`
version => {
if let Ok(version) = Version::from_str(version) {
return Self::FromVersion(target, name, version);
return Self::FromVersion(target, name, extras, version);
}
}
};
@ -78,19 +103,43 @@ impl<'a> Target<'a> {
return Self::Unspecified(target);
}
// Split into name and extras (e.g., `flask[dotenv]`).
let (name, extras) = match name.split_once('[') {
Some((name, extras)) => {
let Some(extras) = extras.strip_suffix(']') else {
// e.g., ignore `flask[dotenv`.
return Self::Unspecified(name);
};
(name, extras)
}
None => (name, ""),
};
// e.g., ignore `git+https://github.com/astral-sh/ruff.git@main`
if PackageName::from_str(name).is_err() {
debug!("Ignoring non-package name `{name}` in command");
return Self::Unspecified(target);
}
// e.g., ignore `ruff[1.0.0]` or any other invalid extra.
let Ok(extras) = extras
.split(',')
.map(str::trim)
.filter(|extra| !extra.is_empty())
.map(ExtraName::from_str)
.collect::<Result<Vec<_>, _>>()
else {
debug!("Ignoring invalid extras `{extras}` in command");
return Self::Unspecified(target);
};
match version {
// e.g., `ruff@latest`
"latest" => return Self::Latest(name),
"latest" => return Self::Latest(name, extras),
// e.g., `ruff@0.6.0`
version => {
if let Ok(version) = Version::from_str(version) {
return Self::Version(name, version);
return Self::Version(name, extras, version);
}
}
};
@ -104,10 +153,10 @@ impl<'a> Target<'a> {
pub(crate) fn executable(&self) -> &str {
match self {
Self::Unspecified(name) => name,
Self::Version(name, _) => name,
Self::Latest(name) => name,
Self::FromVersion(name, _, _) => name,
Self::FromLatest(name, _) => name,
Self::Version(name, _, _) => name,
Self::Latest(name, _) => name,
Self::FromVersion(name, _, _, _) => name,
Self::FromLatest(name, _, _) => name,
Self::From(name, _) => name,
}
}
@ -116,10 +165,10 @@ impl<'a> Target<'a> {
pub(crate) fn is_python(&self) -> bool {
let name = match self {
Self::Unspecified(name) => name,
Self::Version(name, _) => name,
Self::Latest(name) => name,
Self::FromVersion(_, name, _) => name,
Self::FromLatest(_, name) => name,
Self::Version(name, _, _) => name,
Self::Latest(name, _) => name,
Self::FromVersion(_, name, _, _) => name,
Self::FromLatest(_, name, _) => name,
Self::From(_, name) => name,
};
name.eq_ignore_ascii_case("python") || cfg!(windows) && name.eq_ignore_ascii_case("pythonw")
@ -127,6 +176,52 @@ impl<'a> Target<'a> {
/// Returns `true` if the target is `latest`.
fn is_latest(&self) -> bool {
matches!(self, Self::Latest(_) | Self::FromLatest(_, _))
matches!(self, Self::Latest(..) | Self::FromLatest(..))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_target() {
let target = Target::parse("flask", None);
let expected = Target::Unspecified("flask");
assert_eq!(target, expected);
let target = Target::parse("flask@3.0.0", None);
let expected = Target::Version("flask", vec![], Version::new([3, 0, 0]));
assert_eq!(target, expected);
let target = Target::parse("flask@3.0.0", None);
let expected = Target::Version("flask", vec![], Version::new([3, 0, 0]));
assert_eq!(target, expected);
let target = Target::parse("flask@latest", None);
let expected = Target::Latest("flask", vec![]);
assert_eq!(target, expected);
let target = Target::parse("flask[dotenv]@3.0.0", None);
let expected = Target::Version(
"flask",
vec![ExtraName::from_str("dotenv").unwrap()],
Version::new([3, 0, 0]),
);
assert_eq!(target, expected);
let target = Target::parse("flask[dotenv]@latest", None);
let expected = Target::Latest("flask", vec![ExtraName::from_str("dotenv").unwrap()]);
assert_eq!(target, expected);
// Missing a closing `]`.
let target = Target::parse("flask[dotenv", None);
let expected = Target::Unspecified("flask[dotenv");
assert_eq!(target, expected);
// Too many `]`.
let target = Target::parse("flask[dotenv]]", None);
let expected = Target::Unspecified("flask[dotenv]]");
assert_eq!(target, expected);
}
}

View File

@ -460,13 +460,13 @@ async fn get_or_create_environment(
let python_request = if target.is_python() {
let target_request = match target {
Target::Unspecified(_) => None,
Target::Version(_, version) | Target::FromVersion(_, _, version) => {
Target::Version(_, _, version) | Target::FromVersion(_, _, _, version) => {
Some(PythonRequest::Version(
VersionRequest::from_str(&version.to_string()).map_err(anyhow::Error::from)?,
))
}
// TODO(zanieb): Add `PythonRequest::Latest`
Target::Latest(_) | Target::FromLatest(_, _) => {
Target::Latest(_, _) | Target::FromLatest(_, _, _) => {
return Err(anyhow::anyhow!(
"Requesting the 'latest' Python version is not yet supported"
)
@ -531,9 +531,10 @@ async fn get_or_create_environment(
origin: None,
},
// Ex) `ruff@0.6.0`
Target::Version(name, version) | Target::FromVersion(_, name, version) => Requirement {
Target::Version(name, extras, version)
| Target::FromVersion(_, name, extras, version) => Requirement {
name: PackageName::from_str(name)?,
extras: vec![],
extras: extras.clone(),
groups: vec![],
marker: MarkerTree::default(),
source: RequirementSource::Registry {
@ -546,9 +547,9 @@ async fn get_or_create_environment(
origin: None,
},
// Ex) `ruff@latest`
Target::Latest(name) | Target::FromLatest(_, name) => Requirement {
Target::Latest(name, extras) | Target::FromLatest(_, name, extras) => Requirement {
name: PackageName::from_str(name)?,
extras: vec![],
extras: extras.clone(),
groups: vec![],
marker: MarkerTree::default(),
source: RequirementSource::Registry {

View File

@ -1398,6 +1398,65 @@ fn tool_run_latest() {
"###);
}
#[test]
fn tool_run_latest_extra() {
let context = TestContext::new("3.12").with_filtered_exe_suffix();
let tool_dir = context.temp_dir.child("tools");
let bin_dir = context.temp_dir.child("bin");
uv_snapshot!(context.filters(), context.tool_run()
.arg("flask[dotenv]@latest")
.arg("--version")
.env(EnvVars::UV_TOOL_DIR, tool_dir.as_os_str())
.env(EnvVars::XDG_BIN_HOME, bin_dir.as_os_str()), @r###"
success: true
exit_code: 0
----- stdout -----
Python 3.12.[X]
Flask 3.0.2
Werkzeug 3.0.1
----- stderr -----
Resolved 8 packages in [TIME]
Prepared 8 packages in [TIME]
Installed 8 packages in [TIME]
+ blinker==1.7.0
+ click==8.1.7
+ flask==3.0.2
+ itsdangerous==2.1.2
+ jinja2==3.1.3
+ markupsafe==2.1.5
+ python-dotenv==1.0.1
+ werkzeug==3.0.1
"###);
uv_snapshot!(context.filters(), context.tool_run()
.arg("flask[dotenv]@3.0.0")
.arg("--version")
.env(EnvVars::UV_TOOL_DIR, tool_dir.as_os_str())
.env(EnvVars::XDG_BIN_HOME, bin_dir.as_os_str()), @r###"
success: true
exit_code: 0
----- stdout -----
Python 3.12.[X]
Flask 3.0.0
Werkzeug 3.0.1
----- stderr -----
Resolved 8 packages in [TIME]
Prepared 1 package in [TIME]
Installed 8 packages in [TIME]
+ blinker==1.7.0
+ click==8.1.7
+ flask==3.0.0
+ itsdangerous==2.1.2
+ jinja2==3.1.3
+ markupsafe==2.1.5
+ python-dotenv==1.0.1
+ werkzeug==3.0.1
"###);
}
#[test]
fn tool_run_python() {
let context = TestContext::new("3.12").with_filtered_counts();