Add empty lines before nested functions and classes (#6206)

## Summary

This PR ensures that if a function or class is the first statement in a
nested suite that _isn't_ a function or class body, we insert a leading
newline.

For example, given:

```python
def f():
    if True:

        def register_type():
            pass
```

We _want_ to preserve the newline, whereas today, we remove it.

Note that this only applies when the function or class doesn't have any
leading comments.

Closes https://github.com/astral-sh/ruff/issues/6066.
This commit is contained in:
Charlie Marsh 2023-08-01 11:30:59 -04:00 committed by GitHub
parent b68f76f0d9
commit 928ab63a64
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 64 additions and 43 deletions

View File

@ -1,4 +1,4 @@
use crate::statement::suite::SuiteLevel; use crate::statement::suite::SuiteKind;
use crate::{AsFormat, FormatNodeRule, PyFormatter}; use crate::{AsFormat, FormatNodeRule, PyFormatter};
use ruff_formatter::prelude::hard_line_break; use ruff_formatter::prelude::hard_line_break;
use ruff_formatter::{write, Buffer, FormatResult}; use ruff_formatter::{write, Buffer, FormatResult};
@ -13,7 +13,7 @@ impl FormatNodeRule<ModModule> for FormatModModule {
write!( write!(
f, f,
[ [
body.format().with_options(SuiteLevel::TopLevel), body.format().with_options(SuiteKind::TopLevel),
// Trailing newline at the end of the file // Trailing newline at the end of the file
hard_line_break() hard_line_break()
] ]

View File

@ -7,6 +7,7 @@ use ruff_python_trivia::{SimpleTokenKind, SimpleTokenizer};
use crate::comments::trailing_comments; use crate::comments::trailing_comments;
use crate::expression::parentheses::{parenthesized, Parentheses}; use crate::expression::parentheses::{parenthesized, Parentheses};
use crate::prelude::*; use crate::prelude::*;
use crate::statement::suite::SuiteKind;
#[derive(Default)] #[derive(Default)]
pub struct FormatStmtClassDef; pub struct FormatStmtClassDef;
@ -52,7 +53,7 @@ impl FormatNodeRule<StmtClassDef> for FormatStmtClassDef {
[ [
text(":"), text(":"),
trailing_comments(trailing_head_comments), trailing_comments(trailing_head_comments),
block_indent(&body.format()) block_indent(&body.format().with_options(SuiteKind::Class))
] ]
) )
} }

View File

@ -8,6 +8,7 @@ use crate::comments::{leading_comments, trailing_comments};
use crate::expression::parentheses::{optional_parentheses, Parentheses}; use crate::expression::parentheses::{optional_parentheses, Parentheses};
use crate::prelude::*; use crate::prelude::*;
use crate::statement::suite::SuiteKind;
use crate::FormatNodeRule; use crate::FormatNodeRule;
#[derive(Default)] #[derive(Default)]
@ -111,7 +112,7 @@ impl FormatRule<AnyFunctionDefinition<'_>, PyFormatContext<'_>> for FormatAnyFun
[ [
text(":"), text(":"),
trailing_comments(trailing_definition_comments), trailing_comments(trailing_definition_comments),
block_indent(&item.body().format()) block_indent(&item.body().format().with_options(SuiteKind::Function))
] ]
) )
} }

View File

@ -8,38 +8,40 @@ use crate::prelude::*;
/// Level at which the [`Suite`] appears in the source code. /// Level at which the [`Suite`] appears in the source code.
#[derive(Copy, Clone, Debug)] #[derive(Copy, Clone, Debug)]
pub enum SuiteLevel { pub enum SuiteKind {
/// Statements at the module level / top level /// Statements at the module level / top level
TopLevel, TopLevel,
/// Statements in a nested body /// Statements in a function body.
Nested, Function,
}
impl SuiteLevel { /// Statements in a class body.
const fn is_nested(self) -> bool { Class,
matches!(self, SuiteLevel::Nested)
} /// Statements in any other body (e.g., `if` or `while`).
Other,
} }
#[derive(Debug)] #[derive(Debug)]
pub struct FormatSuite { pub struct FormatSuite {
level: SuiteLevel, kind: SuiteKind,
} }
impl Default for FormatSuite { impl Default for FormatSuite {
fn default() -> Self { fn default() -> Self {
FormatSuite { FormatSuite {
level: SuiteLevel::Nested, kind: SuiteKind::Other,
} }
} }
} }
impl FormatRule<Suite, PyFormatContext<'_>> for FormatSuite { impl FormatRule<Suite, PyFormatContext<'_>> for FormatSuite {
fn fmt(&self, statements: &Suite, f: &mut PyFormatter) -> FormatResult<()> { fn fmt(&self, statements: &Suite, f: &mut PyFormatter) -> FormatResult<()> {
let node_level = match self.level { let node_level = match self.kind {
SuiteLevel::TopLevel => NodeLevel::TopLevel, SuiteKind::TopLevel => NodeLevel::TopLevel,
SuiteLevel::Nested => NodeLevel::CompoundStatement, SuiteKind::Function | SuiteKind::Class | SuiteKind::Other => {
NodeLevel::CompoundStatement
}
}; };
let comments = f.context().comments().clone(); let comments = f.context().comments().clone();
@ -51,18 +53,33 @@ impl FormatRule<Suite, PyFormatContext<'_>> for FormatSuite {
}; };
let mut f = WithNodeLevel::new(node_level, f); let mut f = WithNodeLevel::new(node_level, f);
// First entry has never any separator, doesn't matter which one we take.
if matches!(self.kind, SuiteKind::Other)
&& is_class_or_function_definition(first)
&& !comments.has_leading_comments(first)
{
// Add an empty line for any nested functions or classes defined within non-function
// or class compound statements, e.g., this is stable formatting:
// ```python
// if True:
//
// def test():
// ...
// ```
write!(f, [empty_line()])?;
}
write!(f, [first.format()])?; write!(f, [first.format()])?;
let mut last = first; let mut last = first;
for statement in iter { for statement in iter {
if is_class_or_function_definition(last) || is_class_or_function_definition(statement) { if is_class_or_function_definition(last) || is_class_or_function_definition(statement) {
match self.level { match self.kind {
SuiteLevel::TopLevel => { SuiteKind::TopLevel => {
write!(f, [empty_line(), empty_line(), statement.format()])?; write!(f, [empty_line(), empty_line(), statement.format()])?;
} }
SuiteLevel::Nested => { SuiteKind::Function | SuiteKind::Class | SuiteKind::Other => {
write!(f, [empty_line(), statement.format()])?; write!(f, [empty_line(), statement.format()])?;
} }
} }
@ -95,13 +112,12 @@ impl FormatRule<Suite, PyFormatContext<'_>> for FormatSuite {
match lines_before(start, source) { match lines_before(start, source) {
0 | 1 => write!(f, [hard_line_break()])?, 0 | 1 => write!(f, [hard_line_break()])?,
2 => write!(f, [empty_line()])?, 2 => write!(f, [empty_line()])?,
3.. => { 3.. => match self.kind {
if self.level.is_nested() { SuiteKind::TopLevel => write!(f, [empty_line(), empty_line()])?,
SuiteKind::Function | SuiteKind::Class | SuiteKind::Other => {
write!(f, [empty_line()])?; write!(f, [empty_line()])?;
} else {
write!(f, [empty_line(), empty_line()])?;
} }
} },
} }
write!(f, [statement.format()])?; write!(f, [statement.format()])?;
@ -167,10 +183,10 @@ const fn is_import_definition(stmt: &Stmt) -> bool {
} }
impl FormatRuleWithOptions<Suite, PyFormatContext<'_>> for FormatSuite { impl FormatRuleWithOptions<Suite, PyFormatContext<'_>> for FormatSuite {
type Options = SuiteLevel; type Options = SuiteKind;
fn with_options(mut self, options: Self::Options) -> Self { fn with_options(mut self, options: Self::Options) -> Self {
self.level = options; self.kind = options;
self self
} }
} }
@ -199,10 +215,10 @@ mod tests {
use crate::comments::Comments; use crate::comments::Comments;
use crate::prelude::*; use crate::prelude::*;
use crate::statement::suite::SuiteLevel; use crate::statement::suite::SuiteKind;
use crate::PyFormatOptions; use crate::PyFormatOptions;
fn format_suite(level: SuiteLevel) -> String { fn format_suite(level: SuiteKind) -> String {
let source = r#" let source = r#"
a = 10 a = 10
@ -239,7 +255,7 @@ def trailing_func():
#[test] #[test]
fn top_level() { fn top_level() {
let formatted = format_suite(SuiteLevel::TopLevel); let formatted = format_suite(SuiteKind::TopLevel);
assert_eq!( assert_eq!(
formatted, formatted,
@ -274,7 +290,7 @@ def trailing_func():
#[test] #[test]
fn nested_level() { fn nested_level() {
let formatted = format_suite(SuiteLevel::Nested); let formatted = format_suite(SuiteKind::Other);
assert_eq!( assert_eq!(
formatted, formatted,

View File

@ -73,22 +73,13 @@ with hmm_but_this_should_get_two_preceding_newlines():
elif os.name == "nt": elif os.name == "nt":
try: try:
import msvcrt import msvcrt
@@ -45,21 +44,16 @@ @@ -54,12 +53,10 @@
pass
except ImportError:
-
def i_should_be_followed_by_only_one_newline():
pass
elif False:
-
class IHopeYouAreHavingALovelyDay: class IHopeYouAreHavingALovelyDay:
def __call__(self): def __call__(self):
print("i_should_be_followed_by_only_one_newline") print("i_should_be_followed_by_only_one_newline")
- -
else: else:
-
def foo(): def foo():
pass pass
- -
@ -146,14 +137,17 @@ elif os.name == "nt":
pass pass
except ImportError: except ImportError:
def i_should_be_followed_by_only_one_newline(): def i_should_be_followed_by_only_one_newline():
pass pass
elif False: elif False:
class IHopeYouAreHavingALovelyDay: class IHopeYouAreHavingALovelyDay:
def __call__(self): def __call__(self):
print("i_should_be_followed_by_only_one_newline") print("i_should_be_followed_by_only_one_newline")
else: else:
def foo(): def foo():
pass pass

View File

@ -344,6 +344,7 @@ def with_leading_comment():
# looking from the position of the if # looking from the position of the if
# Regression test for https://github.com/python/cpython/blob/ad56340b665c5d8ac1f318964f71697bba41acb7/Lib/logging/__init__.py#L253-L260 # Regression test for https://github.com/python/cpython/blob/ad56340b665c5d8ac1f318964f71697bba41acb7/Lib/logging/__init__.py#L253-L260
if True: if True:
def f1(): def f1():
pass # a pass # a
else: else:
@ -351,6 +352,7 @@ else:
# Here it's actually a trailing comment # Here it's actually a trailing comment
if True: if True:
def f2(): def f2():
pass pass
# a # a

View File

@ -203,14 +203,17 @@ def f():
if True: if True:
def f(): def f():
pass pass
# 1 # 1
elif True: elif True:
def f(): def f():
pass pass
# 2 # 2
else: else:
def f(): def f():
pass pass
# 3 # 3

View File

@ -263,18 +263,22 @@ except RuntimeError:
raise raise
try: try:
def f(): def f():
pass pass
# a # a
except: except:
def f(): def f():
pass pass
# b # b
else: else:
def f(): def f():
pass pass
# c # c
finally: finally:
def f(): def f():
pass pass
# d # d