diff --git a/Cargo.lock b/Cargo.lock index f6fd46fbfe..76c011607a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2322,6 +2322,7 @@ dependencies = [ "bitflags 2.4.0", "clap", "countme", + "indoc", "insta", "itertools 0.11.0", "memchr", diff --git a/crates/ruff_python_formatter/Cargo.toml b/crates/ruff_python_formatter/Cargo.toml index 639e350d1f..cdda7c5881 100644 --- a/crates/ruff_python_formatter/Cargo.toml +++ b/crates/ruff_python_formatter/Cargo.toml @@ -43,6 +43,7 @@ insta = { workspace = true, features = ["glob"] } serde = { workspace = true } serde_json = { workspace = true } similar = { workspace = true } +indoc = "2.0.4" [[test]] name = "ruff_python_formatter_fixtures" diff --git a/crates/ruff_python_formatter/src/lib.rs b/crates/ruff_python_formatter/src/lib.rs index c8af1ca607..32237d92a2 100644 --- a/crates/ruff_python_formatter/src/lib.rs +++ b/crates/ruff_python_formatter/src/lib.rs @@ -5,14 +5,14 @@ use tracing::{warn, Level}; use ruff_formatter::prelude::*; use ruff_formatter::{format, FormatError, Formatted, PrintError, Printed, SourceCode}; -use ruff_python_ast::node::{AnyNodeRef, AstNode}; +use ruff_python_ast::node::AstNode; use ruff_python_ast::{ Mod, Stmt, StmtClassDef, StmtFor, StmtFunctionDef, StmtIf, StmtWhile, StmtWith, }; use ruff_python_index::tokens_and_ranges; use ruff_python_parser::lexer::LexicalError; use ruff_python_parser::{parse_ok_tokens, Mode, ParseError}; -use ruff_python_trivia::CommentRanges; +use ruff_python_trivia::{is_python_whitespace, CommentRanges}; use ruff_source_file::Locator; use ruff_text_size::{Ranged, TextLen, TextRange, TextSize}; @@ -36,6 +36,7 @@ mod options; pub(crate) mod other; pub(crate) mod pattern; mod prelude; +mod range_formatting; mod settings; pub(crate) mod statement; pub(crate) mod type_param; @@ -228,6 +229,34 @@ pub fn format_module_ast<'a>( Ok(formatted) } +/// Is range inside the body of a node, if we consider the whitespace surrounding the suite as part +/// of the body? +/// +/// TODO: Handle leading comments on the first statement +fn range_in_body(suite: &[Stmt], range: TextRange, source: &str) -> bool { + let suite_start = suite.first().unwrap().start(); + let suite_end = suite.last().unwrap().end(); + + if range.start() < suite_start + // Extend the range include all whitespace prior to the first statement + && !source[TextRange::new(range.start(), suite_start)] + .chars() + .all(|c| is_python_whitespace(c)) + { + return false; + } + if range.end() > suite_end + // Extend the range include all whitespace after to the last statement + && !source[TextRange::new(suite_end,range.end())] + .chars() + .all(|c| is_python_whitespace(c)) + { + return false; + } + + true +} + pub fn format_module_range<'a>( module: &'a Mod, comment_ranges: &'a CommentRanges, @@ -262,20 +291,24 @@ pub fn format_module_range<'a>( // ``` // TODO: If it goes beyond the end of the last stmt or before start, do we need to format // the parent? - // TODO: Change suite formatting so we can use a slice instead - let mut parent_body = &module_inner.body; - let mut in_range: Vec; + let mut parent_body: &[Stmt] = module_inner.body.as_slice(); + let mut in_range; + // TODO: Allow partial inclusions, e.g. + // ```python + // not_formatted = 0 + // start = 1 + // if cond_formatted: + // last_formatted = 2 + // not_formatted_anymore = 3 + // ``` + // prob a slice and an optional trailing arg let in_range = loop { - in_range = parent_body - .into_iter() - // TODO: check whether these bounds need equality - .skip_while(|child| range.start() > child.end()) - .take_while(|child| child.start() < range.end()) - .cloned() - .collect(); + let start = parent_body.partition_point(|child| child.end() < range.start()); + let end = parent_body.partition_point(|child| child.start() < range.end()); + in_range = &parent_body[start..end]; - let [single_stmt] = in_range.as_slice() else { + let [single_stmt] = in_range else { break in_range; }; @@ -287,9 +320,7 @@ pub fn format_module_range<'a>( | Stmt::ClassDef(StmtClassDef { body, .. }) => { // We need to format the header or a trailing comment // TODO: ignore trivia - if range.start() < body.first().unwrap().start() - || range.end() > body.last().unwrap().end() - { + if range_in_body(body, range, source) { break in_range; } else { parent_body = &body; @@ -300,22 +331,20 @@ pub fn format_module_range<'a>( elif_else_clauses, .. }) => { - if range.start() < body.first().unwrap().start() - || range.end() - > elif_else_clauses - .last() - .map(|clause| clause.body.last().unwrap().end()) - .unwrap_or(body.lsa) - { + let if_all_end = TextRange::new( + range.start(), + elif_else_clauses + .last() + .map(|clause| clause.body.last().unwrap().end()) + .unwrap_or(body.last().unwrap().end()), + ); + if !range_in_body(body, if_all_end, source) { break in_range; } else if let Some(body) = iter::once(body) - .chain(elif_else_clauses.iter().map(|clause| clause.body)) - .find(|body| { - body.first().unwrap().start() <= range.start() - && range.end() <= body.last().unwrap().end() - }) + .chain(elif_else_clauses.iter().map(|clause| &clause.body)) + .find(|body| range_in_body(body, range, source)) { - in_range = body.clone(); + parent_body = &body; } else { break in_range; } @@ -339,7 +368,8 @@ pub fn format_module_range<'a>( let formatted: Formatted = format!( PyFormatContext::new(options.clone(), locator.contents(), comments), - [in_range.format().with_options(SuiteKind::TopLevel)] + // TODO: Make suite formatting accept slices + [in_range.to_vec().format().with_options(SuiteKind::TopLevel)] )?; //println!("{}", formatted.document().display(SourceCode::new(source))); // TODO: Make the printer use the buffer instead diff --git a/crates/ruff_python_formatter/src/range_formatting.rs b/crates/ruff_python_formatter/src/range_formatting.rs new file mode 100644 index 0000000000..cf52d934a7 --- /dev/null +++ b/crates/ruff_python_formatter/src/range_formatting.rs @@ -0,0 +1,226 @@ +#[cfg(test)] +mod tests { + use crate::{format_module_source_range, LspRowColumn, PyFormatOptions}; + use indoc::indoc; + use insta::assert_snapshot; + + fn format(source: &str, start: (usize, usize), end: (usize, usize)) -> String { + format_module_source_range( + source, + PyFormatOptions::default(), + Some(LspRowColumn { + row: start.0, + col: start.1, + }), + Some(LspRowColumn { + row: end.0, + col: end.1, + }), + ) + .unwrap() + } + + #[test] + fn test_top_level() { + assert_snapshot!(format(indoc! {r#" + a = [1,] + b = [1,] + c = [1,] + d = [1,] + "#}, (1, 3), (2, 5)), @r###" + a = [1,] + b = [ + 1, + ] + c = [ + 1, + ] + d = [1,] + "###); + } + + #[test] + fn test_easy_nested() { + assert_snapshot!(format(indoc! {r#" + a = [1,] + for i in range( 1 ): + b = [1,] + c = [1,] + d = [1,] + e = [1,] + "#}, (3, 3), (3, 5)), @r###" + a = [1,] + for i in range(1): + b = [ + 1, + ] + c = [ + 1, + ] + d = [ + 1, + ] + + e = [1,] + "###); + } + + #[test] + fn test_if() { + let source = indoc! {r#" + import random + if random.random() < 0.5: + a = [1,] + b = [1,] + elif random.random() < 0.75: + c = [1,] + d = [1,] + else: + e = [1,] + f = [1,] + g = [1,] + "#}; + + assert_snapshot!(format(source, (3, 0), (3, 10)), @r###" + import random + if random.random() < 0.5: + a = [ + 1, + ] + b = [ + 1, + ] + elif random.random() < 0.75: + c = [ + 1, + ] + d = [ + 1, + ] + else: + e = [ + 1, + ] + f = [ + 1, + ] + + g = [1,] + "###); + assert_snapshot!(format(source, (6, 0), (6, 10)), @r###" + import random + if random.random() < 0.5: + a = [ + 1, + ] + b = [ + 1, + ] + elif random.random() < 0.75: + c = [ + 1, + ] + d = [ + 1, + ] + else: + e = [ + 1, + ] + f = [ + 1, + ] + + g = [1,] + "###); + assert_snapshot!(format(source, (9, 0), (9, 10)), @r###" + import random + if random.random() < 0.5: + a = [ + 1, + ] + b = [ + 1, + ] + elif random.random() < 0.75: + c = [ + 1, + ] + d = [ + 1, + ] + else: + e = [ + 1, + ] + f = [ + 1, + ] + + g = [1,] + "###); + assert_snapshot!(format(source, (3, 0), (6, 10)), @r###" + import random + if random.random() < 0.5: + a = [ + 1, + ] + b = [ + 1, + ] + elif random.random() < 0.75: + c = [ + 1, + ] + d = [ + 1, + ] + else: + e = [ + 1, + ] + f = [ + 1, + ] + + g = [1,] + "###); + } + + // TODO + #[test] + fn test_trailing_comment() { + assert_snapshot!(format(indoc! {r#" + if True: + a = [1,] + # trailing comment + "#}, (1, 3), (2, 5)), @r###" + if True: + a = [ + 1, + ] + + # trailing comment + "###); + } + + // TODO + #[test] + fn test_alternative_indent() { + assert_snapshot!(format(indoc! {r#" + if True: + a = [1,] + b = [1,] + c = [1,] + "#}, (1, 3), (2, 5)), @r###" + if True: + a = [ + 1, + ] + b = [ + 1, + ] + c = [1,] + "###); + } +}