diff --git a/crates/ruff/src/rules/isort/helpers.rs b/crates/ruff/src/rules/isort/helpers.rs index 53ef62c9bc..cdec6c1ace 100644 --- a/crates/ruff/src/rules/isort/helpers.rs +++ b/crates/ruff/src/rules/isort/helpers.rs @@ -3,9 +3,9 @@ use rustpython_parser::{lexer, Mode, Tok}; use ruff_python_ast::helpers::is_docstring_stmt; use ruff_python_ast::newlines::StrExt; -use ruff_python_ast::source_code::Locator; +use ruff_python_ast::source_code::{Locator, Stylist}; -use super::types::TrailingComma; +use crate::rules::isort::types::TrailingComma; /// Return `true` if a `StmtKind::ImportFrom` statement ends with a magic /// trailing comma. @@ -102,23 +102,73 @@ fn match_docstring_end(body: &[Stmt]) -> Option { Some(stmt.end_location.unwrap()) } -/// Find the end of the first token that isn't a docstring, comment, or -/// whitespace. -pub fn find_splice_location(body: &[Stmt], locator: &Locator) -> Location { - // Find the first AST node that isn't a docstring. - let mut splice = match_docstring_end(body).unwrap_or_default(); +#[derive(Debug, Clone, PartialEq, Eq)] +pub(super) struct Insertion { + /// The content to add before the insertion. + pub prefix: &'static str, + /// The location at which to insert. + pub location: Location, + /// The content to add after the insertion. + pub suffix: &'static str, +} - // Find the first token that isn't a comment or whitespace. - let contents = locator.skip(splice); - for (.., tok, end) in lexer::lex_located(contents, Mode::Module, splice).flatten() { +impl Insertion { + pub fn new(prefix: &'static str, location: Location, suffix: &'static str) -> Self { + Self { + prefix, + location, + suffix, + } + } +} + +/// Find the location at which a "top-of-file" import should be inserted, +/// along with a prefix and suffix to use for the insertion. +/// +/// For example, given the following code: +/// +/// ```python +/// """Hello, world!""" +/// +/// import os +/// ``` +/// +/// The location returned will be the start of the `import os` statement, +/// along with a trailing newline suffix. +pub(super) fn top_of_file_insertion( + body: &[Stmt], + locator: &Locator, + stylist: &Stylist, +) -> Insertion { + // Skip over any docstrings. + let mut location = if let Some(location) = match_docstring_end(body) { + // If the first token after the docstring is a semicolon, insert after the semicolon as an + // inline statement; + let first_token = lexer::lex_located(locator.skip(location), Mode::Module, location) + .flatten() + .next(); + if let Some((.., Tok::Semi, end)) = first_token { + return Insertion::new(" ", end, ";"); + } + + // Otherwise, advance to the next row. + Location::new(location.row() + 1, 0) + } else { + Location::default() + }; + + // Skip over any comments and empty lines. + for (.., tok, end) in + lexer::lex_located(locator.skip(location), Mode::Module, location).flatten() + { if matches!(tok, Tok::Comment(..) | Tok::Newline) { - splice = end; + location = Location::new(end.row() + 1, 0); } else { break; } } - splice + return Insertion::new("", location, stylist.line_ending().as_str()); } #[cfg(test)] @@ -126,79 +176,120 @@ mod tests { use anyhow::Result; use rustpython_parser as parser; use rustpython_parser::ast::Location; + use rustpython_parser::lexer::LexResult; - use ruff_python_ast::source_code::Locator; + use ruff_python_ast::source_code::{LineEnding, Locator, Stylist}; - use super::find_splice_location; + use crate::rules::isort::helpers::{top_of_file_insertion, Insertion}; - fn splice_contents(contents: &str) -> Result { + fn insert(contents: &str) -> Result { let program = parser::parse_program(contents, "")?; + let tokens: Vec = ruff_rustpython::tokenize(contents); let locator = Locator::new(contents); - Ok(find_splice_location(&program, &locator)) + let stylist = Stylist::from_tokens(&tokens, &locator); + Ok(top_of_file_insertion(&program, &locator, &stylist)) } #[test] - fn splice() -> Result<()> { + fn top_of_file_insertions() -> Result<()> { let contents = ""; - assert_eq!(splice_contents(contents)?, Location::new(1, 0)); + assert_eq!( + insert(contents)?, + Insertion::new("", Location::new(1, 0), LineEnding::default().as_str()) + ); + + let contents = r#" +"""Hello, world!""""# + .trim_start(); + assert_eq!( + insert(contents)?, + Insertion::new("", Location::new(2, 0), LineEnding::default().as_str()) + ); let contents = r#" """Hello, world!""" "# - .trim(); - assert_eq!(splice_contents(contents)?, Location::new(1, 19)); + .trim_start(); + assert_eq!( + insert(contents)?, + Insertion::new("", Location::new(2, 0), "\n") + ); let contents = r#" """Hello, world!""" """Hello, world!""" "# - .trim(); - assert_eq!(splice_contents(contents)?, Location::new(2, 19)); + .trim_start(); + assert_eq!( + insert(contents)?, + Insertion::new("", Location::new(3, 0), "\n") + ); let contents = r#" x = 1 "# - .trim(); - assert_eq!(splice_contents(contents)?, Location::new(1, 0)); + .trim_start(); + assert_eq!( + insert(contents)?, + Insertion::new("", Location::new(1, 0), "\n") + ); let contents = r#" #!/usr/bin/env python3 "# - .trim(); - assert_eq!(splice_contents(contents)?, Location::new(1, 22)); + .trim_start(); + assert_eq!( + insert(contents)?, + Insertion::new("", Location::new(2, 0), "\n") + ); let contents = r#" #!/usr/bin/env python3 """Hello, world!""" "# - .trim(); - assert_eq!(splice_contents(contents)?, Location::new(2, 19)); + .trim_start(); + assert_eq!( + insert(contents)?, + Insertion::new("", Location::new(3, 0), "\n") + ); let contents = r#" """Hello, world!""" #!/usr/bin/env python3 "# - .trim(); - assert_eq!(splice_contents(contents)?, Location::new(2, 22)); + .trim_start(); + assert_eq!( + insert(contents)?, + Insertion::new("", Location::new(3, 0), "\n") + ); let contents = r#" """%s""" % "Hello, world!" "# - .trim(); - assert_eq!(splice_contents(contents)?, Location::new(1, 0)); + .trim_start(); + assert_eq!( + insert(contents)?, + Insertion::new("", Location::new(1, 0), "\n") + ); let contents = r#" """Hello, world!"""; x = 1 "# - .trim(); - assert_eq!(splice_contents(contents)?, Location::new(1, 19)); + .trim_start(); + assert_eq!( + insert(contents)?, + Insertion::new(" ", Location::new(1, 20), ";") + ); let contents = r#" """Hello, world!"""; x = 1; y = \ 2 "# - .trim(); - assert_eq!(splice_contents(contents)?, Location::new(1, 19)); + .trim_start(); + assert_eq!( + insert(contents)?, + Insertion::new(" ", Location::new(1, 20), ";") + ); Ok(()) } diff --git a/crates/ruff/src/rules/isort/rules/add_required_imports.rs b/crates/ruff/src/rules/isort/rules/add_required_imports.rs index ddf5f0158a..e7a642bbf5 100644 --- a/crates/ruff/src/rules/isort/rules/add_required_imports.rs +++ b/crates/ruff/src/rules/isort/rules/add_required_imports.rs @@ -11,9 +11,9 @@ use ruff_python_ast::source_code::{Locator, Stylist}; use ruff_python_ast::types::Range; use crate::registry::Rule; +use crate::rules::isort::helpers::{top_of_file_insertion, Insertion}; use crate::settings::{flags, Settings}; -use super::super::helpers; use super::super::track::Block; /// ## What it does @@ -169,32 +169,15 @@ fn add_required_import( Range::new(Location::default(), Location::default()), ); if autofix.into() && settings.rules.should_fix(Rule::MissingRequiredImport) { - // Determine the location at which the import should be inserted. - let splice = helpers::find_splice_location(python_ast, locator); - - // Generate the edit. - let mut contents = String::with_capacity(required_import.len() + 1); - - // Newline (LF/CRLF) - let line_sep = stylist.line_ending().as_str(); - - // If we're inserting beyond the start of the file, we add - // a newline _before_, since the splice represents the _end_ of the last - // irrelevant token (e.g., the end of a comment or the end of - // docstring). This ensures that we properly handle awkward cases like - // docstrings that are followed by semicolons. - if splice > Location::default() { - contents.push_str(line_sep); - } - contents.push_str(&required_import); - - // If we're inserting at the start of the file, add a trailing newline instead. - if splice == Location::default() { - contents.push_str(line_sep); - } - - // Construct the fix. - diagnostic.set_fix(Edit::insertion(contents, splice)); + let Insertion { + prefix, + location, + suffix, + } = top_of_file_insertion(python_ast, locator, stylist); + diagnostic.set_fix(Edit::insertion( + format!("{prefix}{required_import}{suffix}"), + location, + )); } Some(diagnostic) } @@ -224,8 +207,8 @@ pub fn add_required_imports( ); return vec![]; } - - match &body[0].node { + let stmt = &body[0]; + match &stmt.node { StmtKind::ImportFrom { module, names, diff --git a/crates/ruff/src/rules/isort/snapshots/ruff__rules__isort__tests__combined_required_imports_docstring.py.snap b/crates/ruff/src/rules/isort/snapshots/ruff__rules__isort__tests__combined_required_imports_docstring.py.snap index 3939bf2380..624bbfeb11 100644 --- a/crates/ruff/src/rules/isort/snapshots/ruff__rules__isort__tests__combined_required_imports_docstring.py.snap +++ b/crates/ruff/src/rules/isort/snapshots/ruff__rules__isort__tests__combined_required_imports_docstring.py.snap @@ -15,13 +15,13 @@ expression: diagnostics column: 0 fix: edits: - - content: "\nfrom __future__ import annotations" + - content: "from __future__ import annotations\n" location: - row: 1 - column: 19 + row: 2 + column: 0 end_location: - row: 1 - column: 19 + row: 2 + column: 0 parent: ~ - kind: name: MissingRequiredImport @@ -36,12 +36,12 @@ expression: diagnostics column: 0 fix: edits: - - content: "\nfrom __future__ import generator_stop" + - content: "from __future__ import generator_stop\n" location: - row: 1 - column: 19 + row: 2 + column: 0 end_location: - row: 1 - column: 19 + row: 2 + column: 0 parent: ~ diff --git a/crates/ruff/src/rules/isort/snapshots/ruff__rules__isort__tests__required_import_docstring.py.snap b/crates/ruff/src/rules/isort/snapshots/ruff__rules__isort__tests__required_import_docstring.py.snap index ec50107c10..e5abb9e4e9 100644 --- a/crates/ruff/src/rules/isort/snapshots/ruff__rules__isort__tests__required_import_docstring.py.snap +++ b/crates/ruff/src/rules/isort/snapshots/ruff__rules__isort__tests__required_import_docstring.py.snap @@ -15,12 +15,12 @@ expression: diagnostics column: 0 fix: edits: - - content: "\nfrom __future__ import annotations" + - content: "from __future__ import annotations\n" location: - row: 1 - column: 19 + row: 2 + column: 0 end_location: - row: 1 - column: 19 + row: 2 + column: 0 parent: ~ diff --git a/crates/ruff/src/rules/isort/snapshots/ruff__rules__isort__tests__required_import_multiline_docstring.py.snap b/crates/ruff/src/rules/isort/snapshots/ruff__rules__isort__tests__required_import_multiline_docstring.py.snap index 70ff9a5c99..cfb0b579d4 100644 --- a/crates/ruff/src/rules/isort/snapshots/ruff__rules__isort__tests__required_import_multiline_docstring.py.snap +++ b/crates/ruff/src/rules/isort/snapshots/ruff__rules__isort__tests__required_import_multiline_docstring.py.snap @@ -15,12 +15,12 @@ expression: diagnostics column: 0 fix: edits: - - content: "\nfrom __future__ import annotations" + - content: "from __future__ import annotations\n" location: - row: 3 - column: 3 + row: 4 + column: 0 end_location: - row: 3 - column: 3 + row: 4 + column: 0 parent: ~ diff --git a/crates/ruff/src/rules/isort/snapshots/ruff__rules__isort__tests__required_imports_docstring.py.snap b/crates/ruff/src/rules/isort/snapshots/ruff__rules__isort__tests__required_imports_docstring.py.snap index 3939bf2380..624bbfeb11 100644 --- a/crates/ruff/src/rules/isort/snapshots/ruff__rules__isort__tests__required_imports_docstring.py.snap +++ b/crates/ruff/src/rules/isort/snapshots/ruff__rules__isort__tests__required_imports_docstring.py.snap @@ -15,13 +15,13 @@ expression: diagnostics column: 0 fix: edits: - - content: "\nfrom __future__ import annotations" + - content: "from __future__ import annotations\n" location: - row: 1 - column: 19 + row: 2 + column: 0 end_location: - row: 1 - column: 19 + row: 2 + column: 0 parent: ~ - kind: name: MissingRequiredImport @@ -36,12 +36,12 @@ expression: diagnostics column: 0 fix: edits: - - content: "\nfrom __future__ import generator_stop" + - content: "from __future__ import generator_stop\n" location: - row: 1 - column: 19 + row: 2 + column: 0 end_location: - row: 1 - column: 19 + row: 2 + column: 0 parent: ~ diff --git a/crates/ruff/src/rules/isort/snapshots/ruff__rules__isort__tests__straight_required_import_docstring.py.snap b/crates/ruff/src/rules/isort/snapshots/ruff__rules__isort__tests__straight_required_import_docstring.py.snap index 2dd5327452..7ab2e9283e 100644 --- a/crates/ruff/src/rules/isort/snapshots/ruff__rules__isort__tests__straight_required_import_docstring.py.snap +++ b/crates/ruff/src/rules/isort/snapshots/ruff__rules__isort__tests__straight_required_import_docstring.py.snap @@ -15,12 +15,12 @@ expression: diagnostics column: 0 fix: edits: - - content: "\nimport os" + - content: "import os\n" location: - row: 1 - column: 19 + row: 2 + column: 0 end_location: - row: 1 - column: 19 + row: 2 + column: 0 parent: ~