Add test cases

This commit is contained in:
konstin 2023-09-22 15:50:34 +02:00
parent 53b5121f30
commit de239ace74
4 changed files with 288 additions and 30 deletions

1
Cargo.lock generated
View File

@ -2322,6 +2322,7 @@ dependencies = [
"bitflags 2.4.0", "bitflags 2.4.0",
"clap", "clap",
"countme", "countme",
"indoc",
"insta", "insta",
"itertools 0.11.0", "itertools 0.11.0",
"memchr", "memchr",

View File

@ -43,6 +43,7 @@ insta = { workspace = true, features = ["glob"] }
serde = { workspace = true } serde = { workspace = true }
serde_json = { workspace = true } serde_json = { workspace = true }
similar = { workspace = true } similar = { workspace = true }
indoc = "2.0.4"
[[test]] [[test]]
name = "ruff_python_formatter_fixtures" name = "ruff_python_formatter_fixtures"

View File

@ -5,14 +5,14 @@ use tracing::{warn, Level};
use ruff_formatter::prelude::*; use ruff_formatter::prelude::*;
use ruff_formatter::{format, FormatError, Formatted, PrintError, Printed, SourceCode}; use ruff_formatter::{format, FormatError, Formatted, PrintError, Printed, SourceCode};
use ruff_python_ast::node::{AnyNodeRef, AstNode}; use ruff_python_ast::node::AstNode;
use ruff_python_ast::{ use ruff_python_ast::{
Mod, Stmt, StmtClassDef, StmtFor, StmtFunctionDef, StmtIf, StmtWhile, StmtWith, Mod, Stmt, StmtClassDef, StmtFor, StmtFunctionDef, StmtIf, StmtWhile, StmtWith,
}; };
use ruff_python_index::tokens_and_ranges; use ruff_python_index::tokens_and_ranges;
use ruff_python_parser::lexer::LexicalError; use ruff_python_parser::lexer::LexicalError;
use ruff_python_parser::{parse_ok_tokens, Mode, ParseError}; use ruff_python_parser::{parse_ok_tokens, Mode, ParseError};
use ruff_python_trivia::CommentRanges; use ruff_python_trivia::{is_python_whitespace, CommentRanges};
use ruff_source_file::Locator; use ruff_source_file::Locator;
use ruff_text_size::{Ranged, TextLen, TextRange, TextSize}; use ruff_text_size::{Ranged, TextLen, TextRange, TextSize};
@ -36,6 +36,7 @@ mod options;
pub(crate) mod other; pub(crate) mod other;
pub(crate) mod pattern; pub(crate) mod pattern;
mod prelude; mod prelude;
mod range_formatting;
mod settings; mod settings;
pub(crate) mod statement; pub(crate) mod statement;
pub(crate) mod type_param; pub(crate) mod type_param;
@ -228,6 +229,34 @@ pub fn format_module_ast<'a>(
Ok(formatted) Ok(formatted)
} }
/// Is range inside the body of a node, if we consider the whitespace surrounding the suite as part
/// of the body?
///
/// TODO: Handle leading comments on the first statement
fn range_in_body(suite: &[Stmt], range: TextRange, source: &str) -> bool {
let suite_start = suite.first().unwrap().start();
let suite_end = suite.last().unwrap().end();
if range.start() < suite_start
// Extend the range include all whitespace prior to the first statement
&& !source[TextRange::new(range.start(), suite_start)]
.chars()
.all(|c| is_python_whitespace(c))
{
return false;
}
if range.end() > suite_end
// Extend the range include all whitespace after to the last statement
&& !source[TextRange::new(suite_end,range.end())]
.chars()
.all(|c| is_python_whitespace(c))
{
return false;
}
true
}
pub fn format_module_range<'a>( pub fn format_module_range<'a>(
module: &'a Mod, module: &'a Mod,
comment_ranges: &'a CommentRanges, comment_ranges: &'a CommentRanges,
@ -262,20 +291,24 @@ pub fn format_module_range<'a>(
// ``` // ```
// TODO: If it goes beyond the end of the last stmt or before start, do we need to format // TODO: If it goes beyond the end of the last stmt or before start, do we need to format
// the parent? // the parent?
// TODO: Change suite formatting so we can use a slice instead let mut parent_body: &[Stmt] = module_inner.body.as_slice();
let mut parent_body = &module_inner.body; let mut in_range;
let mut in_range: Vec<Stmt>;
// TODO: Allow partial inclusions, e.g.
// ```python
// not_formatted = 0
// start = 1
// if cond_formatted:
// last_formatted = 2
// not_formatted_anymore = 3
// ```
// prob a slice and an optional trailing arg
let in_range = loop { let in_range = loop {
in_range = parent_body let start = parent_body.partition_point(|child| child.end() < range.start());
.into_iter() let end = parent_body.partition_point(|child| child.start() < range.end());
// TODO: check whether these bounds need equality in_range = &parent_body[start..end];
.skip_while(|child| range.start() > child.end())
.take_while(|child| child.start() < range.end())
.cloned()
.collect();
let [single_stmt] = in_range.as_slice() else { let [single_stmt] = in_range else {
break in_range; break in_range;
}; };
@ -287,9 +320,7 @@ pub fn format_module_range<'a>(
| Stmt::ClassDef(StmtClassDef { body, .. }) => { | Stmt::ClassDef(StmtClassDef { body, .. }) => {
// We need to format the header or a trailing comment // We need to format the header or a trailing comment
// TODO: ignore trivia // TODO: ignore trivia
if range.start() < body.first().unwrap().start() if range_in_body(body, range, source) {
|| range.end() > body.last().unwrap().end()
{
break in_range; break in_range;
} else { } else {
parent_body = &body; parent_body = &body;
@ -300,22 +331,20 @@ pub fn format_module_range<'a>(
elif_else_clauses, elif_else_clauses,
.. ..
}) => { }) => {
if range.start() < body.first().unwrap().start() let if_all_end = TextRange::new(
|| range.end() range.start(),
> elif_else_clauses elif_else_clauses
.last() .last()
.map(|clause| clause.body.last().unwrap().end()) .map(|clause| clause.body.last().unwrap().end())
.unwrap_or(body.lsa) .unwrap_or(body.last().unwrap().end()),
{ );
if !range_in_body(body, if_all_end, source) {
break in_range; break in_range;
} else if let Some(body) = iter::once(body) } else if let Some(body) = iter::once(body)
.chain(elif_else_clauses.iter().map(|clause| clause.body)) .chain(elif_else_clauses.iter().map(|clause| &clause.body))
.find(|body| { .find(|body| range_in_body(body, range, source))
body.first().unwrap().start() <= range.start()
&& range.end() <= body.last().unwrap().end()
})
{ {
in_range = body.clone(); parent_body = &body;
} else { } else {
break in_range; break in_range;
} }
@ -339,7 +368,8 @@ pub fn format_module_range<'a>(
let formatted: Formatted<PyFormatContext> = format!( let formatted: Formatted<PyFormatContext> = format!(
PyFormatContext::new(options.clone(), locator.contents(), comments), PyFormatContext::new(options.clone(), locator.contents(), comments),
[in_range.format().with_options(SuiteKind::TopLevel)] // TODO: Make suite formatting accept slices
[in_range.to_vec().format().with_options(SuiteKind::TopLevel)]
)?; )?;
//println!("{}", formatted.document().display(SourceCode::new(source))); //println!("{}", formatted.document().display(SourceCode::new(source)));
// TODO: Make the printer use the buffer instead // TODO: Make the printer use the buffer instead

View File

@ -0,0 +1,226 @@
#[cfg(test)]
mod tests {
use crate::{format_module_source_range, LspRowColumn, PyFormatOptions};
use indoc::indoc;
use insta::assert_snapshot;
fn format(source: &str, start: (usize, usize), end: (usize, usize)) -> String {
format_module_source_range(
source,
PyFormatOptions::default(),
Some(LspRowColumn {
row: start.0,
col: start.1,
}),
Some(LspRowColumn {
row: end.0,
col: end.1,
}),
)
.unwrap()
}
#[test]
fn test_top_level() {
assert_snapshot!(format(indoc! {r#"
a = [1,]
b = [1,]
c = [1,]
d = [1,]
"#}, (1, 3), (2, 5)), @r###"
a = [1,]
b = [
1,
]
c = [
1,
]
d = [1,]
"###);
}
#[test]
fn test_easy_nested() {
assert_snapshot!(format(indoc! {r#"
a = [1,]
for i in range( 1 ):
b = [1,]
c = [1,]
d = [1,]
e = [1,]
"#}, (3, 3), (3, 5)), @r###"
a = [1,]
for i in range(1):
b = [
1,
]
c = [
1,
]
d = [
1,
]
e = [1,]
"###);
}
#[test]
fn test_if() {
let source = indoc! {r#"
import random
if random.random() < 0.5:
a = [1,]
b = [1,]
elif random.random() < 0.75:
c = [1,]
d = [1,]
else:
e = [1,]
f = [1,]
g = [1,]
"#};
assert_snapshot!(format(source, (3, 0), (3, 10)), @r###"
import random
if random.random() < 0.5:
a = [
1,
]
b = [
1,
]
elif random.random() < 0.75:
c = [
1,
]
d = [
1,
]
else:
e = [
1,
]
f = [
1,
]
g = [1,]
"###);
assert_snapshot!(format(source, (6, 0), (6, 10)), @r###"
import random
if random.random() < 0.5:
a = [
1,
]
b = [
1,
]
elif random.random() < 0.75:
c = [
1,
]
d = [
1,
]
else:
e = [
1,
]
f = [
1,
]
g = [1,]
"###);
assert_snapshot!(format(source, (9, 0), (9, 10)), @r###"
import random
if random.random() < 0.5:
a = [
1,
]
b = [
1,
]
elif random.random() < 0.75:
c = [
1,
]
d = [
1,
]
else:
e = [
1,
]
f = [
1,
]
g = [1,]
"###);
assert_snapshot!(format(source, (3, 0), (6, 10)), @r###"
import random
if random.random() < 0.5:
a = [
1,
]
b = [
1,
]
elif random.random() < 0.75:
c = [
1,
]
d = [
1,
]
else:
e = [
1,
]
f = [
1,
]
g = [1,]
"###);
}
// TODO
#[test]
fn test_trailing_comment() {
assert_snapshot!(format(indoc! {r#"
if True:
a = [1,]
# trailing comment
"#}, (1, 3), (2, 5)), @r###"
if True:
a = [
1,
]
# trailing comment
"###);
}
// TODO
#[test]
fn test_alternative_indent() {
assert_snapshot!(format(indoc! {r#"
if True:
a = [1,]
b = [1,]
c = [1,]
"#}, (1, 3), (2, 5)), @r###"
if True:
a = [
1,
]
b = [
1,
]
c = [1,]
"###);
}
}