diff --git a/crates/ruff_python_formatter/resources/test/fixtures/ruff/expression/if.py b/crates/ruff_python_formatter/resources/test/fixtures/ruff/expression/if.py index df4c488603..b3fc460e44 100644 --- a/crates/ruff_python_formatter/resources/test/fixtures/ruff/expression/if.py +++ b/crates/ruff_python_formatter/resources/test/fixtures/ruff/expression/if.py @@ -39,3 +39,72 @@ d1 = [ ("b") else # 2 ("c") ] + +e1 = ( + a + if True # 1 + else b + if False # 2 + else c +) + + +# Flattening nested if-expressions. +def something(): + clone._iterable_class = ( + NamedValuesListIterable + if named + else FlatValuesListIterable + if flat + else ValuesListIterable + ) + + +def something(): + clone._iterable_class = ( + (NamedValuesListIterable + if named + else FlatValuesListIterable) + if flat + else ValuesListIterable + ) + + +def something(): + clone._iterable_class = ( + NamedValuesListIterable + if named + else (FlatValuesListIterable + if flat + else ValuesListIterable) + ) + + +def something(): + clone._iterable_class = ( + NamedValuesListIterable + if named + else FlatValuesListIterable(1,) + if flat + else ValuesListIterable + ) + + +def something(): + clone._iterable_class = ( + NamedValuesListIterable + if named + else FlatValuesListIterable + FlatValuesListIterable + FlatValuesListIterable + FlatValuesListIterable + if flat + else ValuesListIterable + ) + + +def something(): + clone._iterable_class = ( + NamedValuesListIterable + if named + else (FlatValuesListIterable + FlatValuesListIterable + FlatValuesListIterable + FlatValuesListIterable + if flat + else ValuesListIterable) + ) diff --git a/crates/ruff_python_formatter/src/expression/expr_if_exp.rs b/crates/ruff_python_formatter/src/expression/expr_if_exp.rs index e774b6d0b3..07236000c1 100644 --- a/crates/ruff_python_formatter/src/expression/expr_if_exp.rs +++ b/crates/ruff_python_formatter/src/expression/expr_if_exp.rs @@ -1,16 +1,46 @@ -use ruff_formatter::{format_args, write}; +use ruff_formatter::{write, FormatRuleWithOptions}; use ruff_python_ast::node::AnyNodeRef; -use ruff_python_ast::ExprIfExp; +use ruff_python_ast::{Expr, ExprIfExp}; use crate::comments::leading_comments; use crate::expression::parentheses::{ - in_parentheses_only_group, in_parentheses_only_soft_line_break_or_space, NeedsParentheses, - OptionalParentheses, + in_parentheses_only_group, in_parentheses_only_soft_line_break_or_space, + is_expression_parenthesized, NeedsParentheses, OptionalParentheses, }; use crate::prelude::*; +#[derive(Default, Copy, Clone)] +pub enum ExprIfExpLayout { + #[default] + Default, + + /// The [`ExprIfExp`] is nested inside another [`ExprIfExp`], so it should not be given a new + /// group. For example, avoid grouping the `else` clause in: + /// ```python + /// clone._iterable_class = ( + /// NamedValuesListIterable + /// if named + /// else FlatValuesListIterable + /// if flat + /// else ValuesListIterable + /// ) + /// ``` + Nested, +} + #[derive(Default)] -pub struct FormatExprIfExp; +pub struct FormatExprIfExp { + layout: ExprIfExpLayout, +} + +impl FormatRuleWithOptions> for FormatExprIfExp { + type Options = ExprIfExpLayout; + + fn with_options(mut self, options: Self::Options) -> Self { + self.layout = options; + self + } +} impl FormatNodeRule for FormatExprIfExp { fn fmt_fields(&self, item: &ExprIfExp, f: &mut PyFormatter) -> FormatResult<()> { @@ -22,25 +52,33 @@ impl FormatNodeRule for FormatExprIfExp { } = item; let comments = f.context().comments().clone(); - // We place `if test` and `else orelse` on a single line, so the `test` and `orelse` leading - // comments go on the line before the `if` or `else` instead of directly ahead `test` or - // `orelse` - write!( - f, - [in_parentheses_only_group(&format_args![ - body.format(), - in_parentheses_only_soft_line_break_or_space(), - leading_comments(comments.leading(test.as_ref())), - text("if"), - space(), - test.format(), - in_parentheses_only_soft_line_break_or_space(), - leading_comments(comments.leading(orelse.as_ref())), - text("else"), - space(), - orelse.format() - ])] - ) + let inner = format_with(|f: &mut PyFormatter| { + // We place `if test` and `else orelse` on a single line, so the `test` and `orelse` leading + // comments go on the line before the `if` or `else` instead of directly ahead `test` or + // `orelse` + write!( + f, + [ + body.format(), + in_parentheses_only_soft_line_break_or_space(), + leading_comments(comments.leading(test.as_ref())), + text("if"), + space(), + test.format(), + in_parentheses_only_soft_line_break_or_space(), + leading_comments(comments.leading(orelse.as_ref())), + text("else"), + space(), + ] + )?; + + FormatOrElse { orelse }.fmt(f) + }); + + match self.layout { + ExprIfExpLayout::Default => in_parentheses_only_group(&inner).fmt(f), + ExprIfExpLayout::Nested => inner.fmt(f), + } } } @@ -53,3 +91,21 @@ impl NeedsParentheses for ExprIfExp { OptionalParentheses::Multiline } } + +#[derive(Debug)] +struct FormatOrElse<'a> { + orelse: &'a Expr, +} + +impl Format> for FormatOrElse<'_> { + fn fmt(&self, f: &mut Formatter>) -> FormatResult<()> { + match self.orelse { + Expr::IfExp(expr) + if !is_expression_parenthesized(expr.into(), f.context().source()) => + { + write!(f, [expr.format().with_options(ExprIfExpLayout::Nested)]) + } + _ => write!(f, [in_parentheses_only_group(&self.orelse.format())]), + } + } +} diff --git a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@conditional_expression.py.snap b/crates/ruff_python_formatter/tests/snapshots/black_compatibility@conditional_expression.py.snap index 2536898905..f7012d70b5 100644 --- a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@conditional_expression.py.snap +++ b/crates/ruff_python_formatter/tests/snapshots/black_compatibility@conditional_expression.py.snap @@ -136,6 +136,15 @@ def something(): for some_boolean_variable in some_iterable ) +@@ -86,5 +78,7 @@ + clone._iterable_class = ( + NamedValuesListIterable + if named +- else FlatValuesListIterable if flat else ValuesListIterable ++ else FlatValuesListIterable ++ if flat ++ else ValuesListIterable + ) ``` ## Ruff Output @@ -221,7 +230,9 @@ def something(): clone._iterable_class = ( NamedValuesListIterable if named - else FlatValuesListIterable if flat else ValuesListIterable + else FlatValuesListIterable + if flat + else ValuesListIterable ) ``` diff --git a/crates/ruff_python_formatter/tests/snapshots/format@expression__if.py.snap b/crates/ruff_python_formatter/tests/snapshots/format@expression__if.py.snap index f5a195c622..7aa5a080c8 100644 --- a/crates/ruff_python_formatter/tests/snapshots/format@expression__if.py.snap +++ b/crates/ruff_python_formatter/tests/snapshots/format@expression__if.py.snap @@ -45,6 +45,75 @@ d1 = [ ("b") else # 2 ("c") ] + +e1 = ( + a + if True # 1 + else b + if False # 2 + else c +) + + +# Flattening nested if-expressions. +def something(): + clone._iterable_class = ( + NamedValuesListIterable + if named + else FlatValuesListIterable + if flat + else ValuesListIterable + ) + + +def something(): + clone._iterable_class = ( + (NamedValuesListIterable + if named + else FlatValuesListIterable) + if flat + else ValuesListIterable + ) + + +def something(): + clone._iterable_class = ( + NamedValuesListIterable + if named + else (FlatValuesListIterable + if flat + else ValuesListIterable) + ) + + +def something(): + clone._iterable_class = ( + NamedValuesListIterable + if named + else FlatValuesListIterable(1,) + if flat + else ValuesListIterable + ) + + +def something(): + clone._iterable_class = ( + NamedValuesListIterable + if named + else FlatValuesListIterable + FlatValuesListIterable + FlatValuesListIterable + FlatValuesListIterable + if flat + else ValuesListIterable + ) + + +def something(): + clone._iterable_class = ( + NamedValuesListIterable + if named + else (FlatValuesListIterable + FlatValuesListIterable + FlatValuesListIterable + FlatValuesListIterable + if flat + else ValuesListIterable) + ) ``` ## Output @@ -96,6 +165,81 @@ d1 = [ # 2 else ("c") ] + +e1 = ( + a + if True # 1 + else b + if False # 2 + else c +) + + +# Flattening nested if-expressions. +def something(): + clone._iterable_class = ( + NamedValuesListIterable + if named + else FlatValuesListIterable + if flat + else ValuesListIterable + ) + + +def something(): + clone._iterable_class = ( + (NamedValuesListIterable if named else FlatValuesListIterable) + if flat + else ValuesListIterable + ) + + +def something(): + clone._iterable_class = ( + NamedValuesListIterable + if named + else (FlatValuesListIterable if flat else ValuesListIterable) + ) + + +def something(): + clone._iterable_class = ( + NamedValuesListIterable + if named + else FlatValuesListIterable( + 1, + ) + if flat + else ValuesListIterable + ) + + +def something(): + clone._iterable_class = ( + NamedValuesListIterable + if named + else FlatValuesListIterable + + FlatValuesListIterable + + FlatValuesListIterable + + FlatValuesListIterable + if flat + else ValuesListIterable + ) + + +def something(): + clone._iterable_class = ( + NamedValuesListIterable + if named + else ( + FlatValuesListIterable + + FlatValuesListIterable + + FlatValuesListIterable + + FlatValuesListIterable + if flat + else ValuesListIterable + ) + ) ```