diff --git a/crates/ruff_python_formatter/resources/test/fixtures/ruff/statement/match.py b/crates/ruff_python_formatter/resources/test/fixtures/ruff/statement/match.py index e1c8d5baf4..08b0840fd9 100644 --- a/crates/ruff_python_formatter/resources/test/fixtures/ruff/statement/match.py +++ b/crates/ruff_python_formatter/resources/test/fixtures/ruff/statement/match.py @@ -55,3 +55,56 @@ def foo(): match inside_func: # comment case "bar": pass + + +match newlines: + + # case 1 leading comment + + + case "top level case comment with newlines": # case dangling comment + # pass leading comment + pass + # pass trailing comment + + + # case 2 leading comment + + + + case "case comment with newlines" if foo == 2: # second + pass + + case "one", "newline" if (foo := 1): # third + pass + + + case "two newlines": + pass + + + + case "three newlines": + pass + case _: + pass + + +match long_lines: + case "this is a long line for if condition" if aaaaaaaaahhhhhhhh == 1 and bbbbbbaaaaaaaaaaa == 2: # comment + pass + + case "this is a long line for if condition with parentheses" if (aaaaaaaaahhhhhhhh == 1 and bbbbbbaaaaaaaaaaa == 2): # comment + pass + + case "named expressions aren't special" if foo := 1: + pass + + case "named expressions aren't that special" if (foo := 1): + pass + + case "but with already broken long lines" if ( + aaaaaaahhhhhhhhhhh == 1 and + bbbbbbbbaaaaaahhhh == 2 + ): # another comment + pass diff --git a/crates/ruff_python_formatter/src/comments/placement.rs b/crates/ruff_python_formatter/src/comments/placement.rs index 81035bf7dd..d55ec6c91b 100644 --- a/crates/ruff_python_formatter/src/comments/placement.rs +++ b/crates/ruff_python_formatter/src/comments/placement.rs @@ -213,6 +213,7 @@ fn is_first_statement_in_body(statement: AnyNodeRef, has_body: AnyNodeRef) -> bo | AnyNodeRef::ExceptHandlerExceptHandler(ast::ExceptHandlerExceptHandler { body, .. }) + | AnyNodeRef::MatchCase(ast::MatchCase { body, .. }) | AnyNodeRef::StmtFunctionDef(ast::StmtFunctionDef { body, .. }) | AnyNodeRef::StmtClassDef(ast::StmtClassDef { body, .. }) => { are_same_optional(statement, body.first()) diff --git a/crates/ruff_python_formatter/src/other/match_case.rs b/crates/ruff_python_formatter/src/other/match_case.rs index f2dc80f93a..bd3050af0e 100644 --- a/crates/ruff_python_formatter/src/other/match_case.rs +++ b/crates/ruff_python_formatter/src/other/match_case.rs @@ -1,8 +1,7 @@ use ruff_formatter::{write, Buffer, FormatResult}; use ruff_python_ast::MatchCase; -use crate::expression::maybe_parenthesize_expression; -use crate::expression::parentheses::Parenthesize; +use crate::comments::trailing_comments; use crate::not_yet_implemented_custom_text; use crate::prelude::*; use crate::{FormatNodeRule, PyFormatter}; @@ -19,6 +18,9 @@ impl FormatNodeRule for FormatMatchCase { body, } = item; + let comments = f.context().comments().clone(); + let dangling_item_comments = comments.dangling_comments(item); + write!( f, [ @@ -39,17 +41,21 @@ impl FormatNodeRule for FormatMatchCase { )?; if let Some(guard) = guard { - write!( - f, - [ - space(), - text("if"), - space(), - maybe_parenthesize_expression(guard, item, Parenthesize::IfBreaks) - ] - )?; + write!(f, [space(), text("if"), space(), guard.format()])?; } - write!(f, [text(":"), block_indent(&body.format())]) + write!( + f, + [ + text(":"), + trailing_comments(dangling_item_comments), + block_indent(&body.format()) + ] + ) + } + + fn fmt_dangling_comments(&self, _node: &MatchCase, _f: &mut PyFormatter) -> FormatResult<()> { + // Handled as part of `fmt_fields` + Ok(()) } } diff --git a/crates/ruff_python_formatter/src/pattern/mod.rs b/crates/ruff_python_formatter/src/pattern/mod.rs index 992f90a49d..9099772a4d 100644 --- a/crates/ruff_python_formatter/src/pattern/mod.rs +++ b/crates/ruff_python_formatter/src/pattern/mod.rs @@ -1,3 +1,8 @@ +use ruff_formatter::{FormatOwnedWithRule, FormatRefWithRule}; +use ruff_python_ast::Pattern; + +use crate::prelude::*; + pub(crate) mod pattern_match_as; pub(crate) mod pattern_match_class; pub(crate) mod pattern_match_mapping; @@ -6,3 +11,37 @@ pub(crate) mod pattern_match_sequence; pub(crate) mod pattern_match_singleton; pub(crate) mod pattern_match_star; pub(crate) mod pattern_match_value; + +#[derive(Default)] +pub struct FormatPattern; + +impl FormatRule> for FormatPattern { + fn fmt(&self, item: &Pattern, f: &mut PyFormatter) -> FormatResult<()> { + match item { + Pattern::MatchValue(p) => p.format().fmt(f), + Pattern::MatchSingleton(p) => p.format().fmt(f), + Pattern::MatchSequence(p) => p.format().fmt(f), + Pattern::MatchMapping(p) => p.format().fmt(f), + Pattern::MatchClass(p) => p.format().fmt(f), + Pattern::MatchStar(p) => p.format().fmt(f), + Pattern::MatchAs(p) => p.format().fmt(f), + Pattern::MatchOr(p) => p.format().fmt(f), + } + } +} + +impl<'ast> AsFormat> for Pattern { + type Format<'a> = FormatRefWithRule<'a, Pattern, FormatPattern, PyFormatContext<'ast>>; + + fn format(&self) -> Self::Format<'_> { + FormatRefWithRule::new(self, FormatPattern) + } +} + +impl<'ast> IntoFormat> for Pattern { + type Format = FormatOwnedWithRule>; + + fn into_format(self) -> Self::Format { + FormatOwnedWithRule::new(self, FormatPattern) + } +} diff --git a/crates/ruff_python_formatter/src/statement/stmt_match.rs b/crates/ruff_python_formatter/src/statement/stmt_match.rs index 6971950c90..d56b739590 100644 --- a/crates/ruff_python_formatter/src/statement/stmt_match.rs +++ b/crates/ruff_python_formatter/src/statement/stmt_match.rs @@ -1,7 +1,8 @@ -use ruff_formatter::{write, Buffer, FormatResult}; +use ruff_formatter::{format_args, write, Buffer, FormatResult}; use ruff_python_ast::StmtMatch; -use crate::comments::trailing_comments; +use crate::comments::{leading_alternate_branch_comments, trailing_comments}; +use crate::context::{NodeLevel, WithNodeLevel}; use crate::expression::maybe_parenthesize_expression; use crate::expression::parentheses::Parenthesize; use crate::prelude::*; @@ -35,8 +36,29 @@ impl FormatNodeRule for FormatStmtMatch { ] )?; - for case in cases { - write!(f, [block_indent(&case.format())])?; + let mut cases_iter = cases.iter(); + let Some(first) = cases_iter.next() else { + return Ok(()); + }; + + // The new level is for the `case` nodes. + let mut f = WithNodeLevel::new(NodeLevel::CompoundStatement, f); + + write!(f, [block_indent(&first.format())])?; + let mut last_case = first; + + for case in cases_iter { + write!( + f, + [block_indent(&format_args!( + &leading_alternate_branch_comments( + comments.leading_comments(case), + last_case.body.last(), + ), + &case.format() + ))] + )?; + last_case = case; } Ok(()) diff --git a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@py_310__pattern_matching_complex.py.snap b/crates/ruff_python_formatter/tests/snapshots/black_compatibility@py_310__pattern_matching_complex.py.snap index e498853715..4f33ce6e7c 100644 --- a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@py_310__pattern_matching_complex.py.snap +++ b/crates/ruff_python_formatter/tests/snapshots/black_compatibility@py_310__pattern_matching_complex.py.snap @@ -265,7 +265,7 @@ match x: + case NOT_YET_IMPLEMENTED_Pattern: y = 0 - case [1, 0] if (x := x[:0]): -+ case NOT_YET_IMPLEMENTED_Pattern if x := x[:0]: ++ case NOT_YET_IMPLEMENTED_Pattern if (x := x[:0]): y = 1 - case [1, 0]: + case NOT_YET_IMPLEMENTED_Pattern: @@ -431,7 +431,7 @@ match (0, 1, 2): match x: case NOT_YET_IMPLEMENTED_Pattern: y = 0 - case NOT_YET_IMPLEMENTED_Pattern if x := x[:0]: + case NOT_YET_IMPLEMENTED_Pattern if (x := x[:0]): y = 1 case NOT_YET_IMPLEMENTED_Pattern: y = 2 diff --git a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@py_310__pattern_matching_extras.py.snap b/crates/ruff_python_formatter/tests/snapshots/black_compatibility@py_310__pattern_matching_extras.py.snap index f107f563a0..0f09fd2007 100644 --- a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@py_310__pattern_matching_extras.py.snap +++ b/crates/ruff_python_formatter/tests/snapshots/black_compatibility@py_310__pattern_matching_extras.py.snap @@ -205,7 +205,7 @@ match bar1: assert "map" == b -@@ -59,61 +62,47 @@ +@@ -59,61 +62,51 @@ ), case, ): @@ -217,11 +217,11 @@ match bar1: - ): + case NOT_YET_IMPLEMENTED_Pattern: pass -- + - case [a as match]: + case NOT_YET_IMPLEMENTED_Pattern: pass -- + - case case: + case NOT_YET_IMPLEMENTED_Pattern: pass @@ -255,11 +255,11 @@ match bar1: - case 1 as a: + case NOT_YET_IMPLEMENTED_Pattern: pass -- + - case 2 as b, 3 as c: + case NOT_YET_IMPLEMENTED_Pattern: pass -- + - case 4 as d, (5 as e), (6 | 7 as g), *h: + case NOT_YET_IMPLEMENTED_Pattern: pass @@ -351,8 +351,10 @@ match match( ): case NOT_YET_IMPLEMENTED_Pattern: pass + case NOT_YET_IMPLEMENTED_Pattern: pass + case NOT_YET_IMPLEMENTED_Pattern: pass @@ -377,8 +379,10 @@ match something: match something: case NOT_YET_IMPLEMENTED_Pattern: pass + case NOT_YET_IMPLEMENTED_Pattern: pass + case NOT_YET_IMPLEMENTED_Pattern: pass diff --git a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@py_310__pattern_matching_simple.py.snap b/crates/ruff_python_formatter/tests/snapshots/black_compatibility@py_310__pattern_matching_simple.py.snap index c5fea0c542..5226bfc31a 100644 --- a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@py_310__pattern_matching_simple.py.snap +++ b/crates/ruff_python_formatter/tests/snapshots/black_compatibility@py_310__pattern_matching_simple.py.snap @@ -203,7 +203,7 @@ def where_is(point): match event.get(): - case Click((x, y), button=Button.LEFT): # This is a left click -+ case NOT_YET_IMPLEMENTED_Pattern: ++ case NOT_YET_IMPLEMENTED_Pattern: # This is a left click handle_click_at(x, y) - case Click(): + case NOT_YET_IMPLEMENTED_Pattern: @@ -306,7 +306,7 @@ match event.get(): raise ValueError(f"Unrecognized event: {other_event}") match event.get(): - case NOT_YET_IMPLEMENTED_Pattern: + case NOT_YET_IMPLEMENTED_Pattern: # This is a left click handle_click_at(x, y) case NOT_YET_IMPLEMENTED_Pattern: pass # ignore other clicks diff --git a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@py_310__remove_newline_after_match.py.snap b/crates/ruff_python_formatter/tests/snapshots/black_compatibility@py_310__remove_newline_after_match.py.snap index 2f1f0ea0b9..8bcbe70db7 100644 --- a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@py_310__remove_newline_after_match.py.snap +++ b/crates/ruff_python_formatter/tests/snapshots/black_compatibility@py_310__remove_newline_after_match.py.snap @@ -31,21 +31,21 @@ def http_status(status): ```diff --- Black +++ Ruff -@@ -1,13 +1,10 @@ +@@ -1,13 +1,13 @@ def http_status(status): match status: - case 400: + case NOT_YET_IMPLEMENTED_Pattern: return "Bad request" -- + - case 401: + case NOT_YET_IMPLEMENTED_Pattern: return "Unauthorized" -- + - case 403: + case NOT_YET_IMPLEMENTED_Pattern: return "Forbidden" -- + - case 404: + case NOT_YET_IMPLEMENTED_Pattern: return "Not found" @@ -58,10 +58,13 @@ def http_status(status): match status: case NOT_YET_IMPLEMENTED_Pattern: return "Bad request" + case NOT_YET_IMPLEMENTED_Pattern: return "Unauthorized" + case NOT_YET_IMPLEMENTED_Pattern: return "Forbidden" + case NOT_YET_IMPLEMENTED_Pattern: return "Not found" ``` diff --git a/crates/ruff_python_formatter/tests/snapshots/format@statement__match.py.snap b/crates/ruff_python_formatter/tests/snapshots/format@statement__match.py.snap index f794cbc12a..55f553ad20 100644 --- a/crates/ruff_python_formatter/tests/snapshots/format@statement__match.py.snap +++ b/crates/ruff_python_formatter/tests/snapshots/format@statement__match.py.snap @@ -61,6 +61,59 @@ def foo(): match inside_func: # comment case "bar": pass + + +match newlines: + + # case 1 leading comment + + + case "top level case comment with newlines": # case dangling comment + # pass leading comment + pass + # pass trailing comment + + + # case 2 leading comment + + + + case "case comment with newlines" if foo == 2: # second + pass + + case "one", "newline" if (foo := 1): # third + pass + + + case "two newlines": + pass + + + + case "three newlines": + pass + case _: + pass + + +match long_lines: + case "this is a long line for if condition" if aaaaaaaaahhhhhhhh == 1 and bbbbbbaaaaaaaaaaa == 2: # comment + pass + + case "this is a long line for if condition with parentheses" if (aaaaaaaaahhhhhhhh == 1 and bbbbbbaaaaaaaaaaa == 2): # comment + pass + + case "named expressions aren't special" if foo := 1: + pass + + case "named expressions aren't that special" if (foo := 1): + pass + + case "but with already broken long lines" if ( + aaaaaaahhhhhhhhhhh == 1 and + bbbbbbbbaaaaaahhhh == 2 + ): # another comment + pass ``` ## Output @@ -124,6 +177,52 @@ def foo(): match inside_func: # comment case NOT_YET_IMPLEMENTED_Pattern: pass + + +match newlines: + # case 1 leading comment + + case NOT_YET_IMPLEMENTED_Pattern: # case dangling comment + # pass leading comment + pass + # pass trailing comment + + # case 2 leading comment + + case NOT_YET_IMPLEMENTED_Pattern if foo == 2: # second + pass + + case NOT_YET_IMPLEMENTED_Pattern if (foo := 1): # third + pass + + case NOT_YET_IMPLEMENTED_Pattern: + pass + + case NOT_YET_IMPLEMENTED_Pattern: + pass + case NOT_YET_IMPLEMENTED_Pattern: + pass + + +match long_lines: + case NOT_YET_IMPLEMENTED_Pattern if aaaaaaaaahhhhhhhh == 1 and bbbbbbaaaaaaaaaaa == 2: # comment + pass + + case NOT_YET_IMPLEMENTED_Pattern if ( + aaaaaaaaahhhhhhhh == 1 and bbbbbbaaaaaaaaaaa == 2 + ): # comment + pass + + case NOT_YET_IMPLEMENTED_Pattern if foo := 1: + pass + + case NOT_YET_IMPLEMENTED_Pattern if (foo := 1): + pass + + case NOT_YET_IMPLEMENTED_Pattern if ( + aaaaaaahhhhhhhhhhh == 1 and bbbbbbbbaaaaaahhhh == 2 + ): # another comment + pass ```