Progress bars for `uv publish` (#7613)

This commit is contained in:
konsti 2024-09-24 17:55:33 +02:00 committed by GitHub
parent 1995d20298
commit c053dc84f4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 122 additions and 9 deletions

1
Cargo.lock generated
View File

@ -5026,6 +5026,7 @@ dependencies = [
"sha2",
"thiserror",
"tokio",
"tokio-util",
"tracing",
"url",
"uv-client",

View File

@ -1,7 +1,11 @@
use fs2::FileExt;
use std::fmt::Display;
use std::io;
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::task::{Context, Poll};
use tempfile::NamedTempFile;
use tokio::io::{AsyncRead, ReadBuf};
use tracing::{debug, error, info, trace, warn};
pub use crate::path::*;
@ -387,3 +391,32 @@ impl Drop for LockedFile {
}
}
}
/// An asynchronous reader that reports progress as bytes are read.
pub struct ProgressReader<Reader: AsyncRead + Unpin, Callback: Fn(usize) + Unpin> {
reader: Reader,
callback: Callback,
}
impl<Reader: AsyncRead + Unpin, Callback: Fn(usize) + Unpin> ProgressReader<Reader, Callback> {
/// Create a new [`ProgressReader`] that wraps another reader.
pub fn new(reader: Reader, callback: Callback) -> Self {
Self { reader, callback }
}
}
impl<Reader: AsyncRead + Unpin, Callback: Fn(usize) + Unpin> AsyncRead
for ProgressReader<Reader, Callback>
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.as_mut().reader)
.poll_read(cx, buf)
.map_ok(|()| {
(self.callback)(buf.filled().len());
})
}
}

View File

@ -31,6 +31,7 @@ serde_json = { workspace = true }
sha2 = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true }
tokio-util = { workspace = true , features = ["io"] }
tracing = { workspace = true }
url = { workspace = true }

View File

@ -15,13 +15,15 @@ use serde::Deserialize;
use sha2::{Digest, Sha256};
use std::io::BufReader;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::{fmt, io};
use thiserror::Error;
use tokio::io::AsyncReadExt;
use tokio_util::io::ReaderStream;
use tracing::{debug, enabled, trace, Level};
use url::Url;
use uv_client::BaseClient;
use uv_fs::Simplified;
use uv_fs::{ProgressReader, Simplified};
use uv_metadata::read_metadata_async_seek;
#[derive(Error, Debug)]
@ -79,6 +81,13 @@ pub enum PublishSendError {
RedirectError(Url),
}
pub trait Reporter: Send + Sync + 'static {
fn on_progress(&self, name: &str, id: usize);
fn on_download_start(&self, name: &str, size: Option<u64>) -> usize;
fn on_download_progress(&self, id: usize, inc: u64);
fn on_download_complete(&self);
}
impl PublishSendError {
/// Extract `code` from the PyPI json error response, if any.
///
@ -212,6 +221,7 @@ pub async fn upload(
client: &BaseClient,
username: Option<&str>,
password: Option<&str>,
reporter: Arc<impl Reporter>,
) -> Result<bool, PublishError> {
let form_metadata = form_metadata(file, filename)
.await
@ -224,6 +234,7 @@ pub async fn upload(
username,
password,
form_metadata,
reporter,
)
.await
.map_err(|err| PublishError::PublishPrepare(file.to_path_buf(), Box::new(err)))?;
@ -396,18 +407,23 @@ async fn build_request(
username: Option<&str>,
password: Option<&str>,
form_metadata: Vec<(&'static str, String)>,
reporter: Arc<impl Reporter>,
) -> Result<RequestBuilder, PublishPrepareError> {
let mut form = reqwest::multipart::Form::new();
for (key, value) in form_metadata {
form = form.text(key, value);
}
let file: tokio::fs::File = fs_err::tokio::File::open(file).await?.into();
let file_reader = Body::from(file);
form = form.part(
"content",
Part::stream(file_reader).file_name(filename.to_string()),
);
let file = fs_err::tokio::File::open(file).await?;
let idx = reporter.on_download_start(&filename.to_string(), Some(file.metadata().await?.len()));
let reader = ProgressReader::new(file, move |read| {
reporter.on_download_progress(idx, read as u64);
});
// Stream wrapping puts a static lifetime requirement on the reader (so the request doesn't have
// a lifetime) -> callback needs to be static -> reporter reference needs to be Arc'd.
let file_reader = Body::wrap_stream(ReaderStream::new(reader));
let part = Part::stream(file_reader).file_name(filename.to_string());
form = form.part("content", part);
let url = if let Some(username) = username {
if password.is_none() {
@ -525,14 +541,26 @@ async fn handle_response(registry: &Url, response: Response) -> Result<bool, Pub
#[cfg(test)]
mod tests {
use crate::{build_request, form_metadata};
use crate::{build_request, form_metadata, Reporter};
use distribution_filename::DistFilename;
use insta::{assert_debug_snapshot, assert_snapshot};
use itertools::Itertools;
use std::path::PathBuf;
use std::sync::Arc;
use url::Url;
use uv_client::BaseClientBuilder;
struct DummyReporter;
impl Reporter for DummyReporter {
fn on_progress(&self, _name: &str, _id: usize) {}
fn on_download_start(&self, _name: &str, _size: Option<u64>) -> usize {
0
}
fn on_download_progress(&self, _id: usize, _inc: u64) {}
fn on_download_complete(&self) {}
}
/// Snapshot the data we send for an upload request for a source distribution.
#[tokio::test]
async fn upload_request_source_dist() {
@ -602,6 +630,7 @@ mod tests {
Some("ferris"),
Some("F3RR!S"),
form_metadata,
Arc::new(DummyReporter),
)
.await
.unwrap();
@ -744,6 +773,7 @@ mod tests {
Some("ferris"),
Some("F3RR!S"),
form_metadata,
Arc::new(DummyReporter),
)
.await
.unwrap();

View File

@ -1,8 +1,10 @@
use crate::commands::reporters::PublishReporter;
use crate::commands::{human_readable_bytes, ExitStatus};
use crate::printer::Printer;
use anyhow::{bail, Result};
use owo_colors::OwoColorize;
use std::fmt::Write;
use std::sync::Arc;
use tracing::info;
use url::Url;
use uv_client::{BaseClientBuilder, Connectivity};
@ -51,6 +53,7 @@ pub(crate) async fn publish(
"Uploading".bold().green(),
format!("({bytes:.1}{unit})").dimmed()
)?;
let reporter = PublishReporter::single(printer);
let uploaded = upload(
&file,
&filename,
@ -58,6 +61,8 @@ pub(crate) async fn publish(
&client,
username.as_deref(),
password.as_deref(),
// Needs to be an `Arc` because the reqwest `Body` static lifetime requirement
Arc::new(reporter),
)
.await?; // Filename and/or URL are already attached, if applicable.
info!("Upload succeeded");

View File

@ -143,9 +143,10 @@ impl ProgressReporter {
);
if size.is_some() {
// We're using binary bytes to match `human_readable_bytes`.
progress.set_style(
ProgressStyle::with_template(
"{msg:10.dim} {bar:30.green/dim} {decimal_bytes:>7}/{decimal_total_bytes:7}",
"{msg:10.dim} {bar:30.green/dim} {binary_bytes:>7}/{binary_total_bytes:7}",
)
.unwrap()
.progress_chars("--"),
@ -485,6 +486,48 @@ impl uv_python::downloads::Reporter for PythonDownloadReporter {
}
}
#[derive(Debug)]
pub(crate) struct PublishReporter {
reporter: ProgressReporter,
}
impl PublishReporter {
/// Initialize a [`PublishReporter`] for a single upload.
pub(crate) fn single(printer: Printer) -> Self {
Self::new(printer, 1)
}
/// Initialize a [`PublishReporter`] for multiple uploads.
pub(crate) fn new(printer: Printer, length: u64) -> Self {
let multi_progress = MultiProgress::with_draw_target(printer.target());
let root = multi_progress.add(ProgressBar::with_draw_target(
Some(length),
printer.target(),
));
let reporter = ProgressReporter::new(root, multi_progress, printer);
Self { reporter }
}
}
impl uv_publish::Reporter for PublishReporter {
fn on_progress(&self, _name: &str, id: usize) {
self.reporter.on_download_complete(id);
}
fn on_download_start(&self, name: &str, size: Option<u64>) -> usize {
self.reporter.on_download_start(name.to_string(), size)
}
fn on_download_progress(&self, id: usize, inc: u64) {
self.reporter.on_download_progress(id, inc);
}
fn on_download_complete(&self) {
self.reporter.root.set_message("");
self.reporter.root.finish_and_clear();
}
}
/// Like [`std::fmt::Display`], but with colors.
trait ColorDisplay {
fn to_color_string(&self) -> String;