diff --git a/crates/ruff/src/rules/isort/rules/add_required_imports.rs b/crates/ruff/src/rules/isort/rules/add_required_imports.rs index e7a642bbf5..0ab96e53f3 100644 --- a/crates/ruff/src/rules/isort/rules/add_required_imports.rs +++ b/crates/ruff/src/rules/isort/rules/add_required_imports.rs @@ -1,5 +1,3 @@ -use std::fmt; - use log::error; use rustpython_parser as parser; use rustpython_parser::ast::{Location, StmtKind, Suite}; @@ -7,6 +5,7 @@ use rustpython_parser::ast::{Location, StmtKind, Suite}; use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic, Edit}; use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::helpers::is_docstring_stmt; +use ruff_python_ast::imports::{Alias, AnyImport, Import, ImportFrom}; use ruff_python_ast::source_code::{Locator, Stylist}; use ruff_python_ast::types::Range; @@ -55,59 +54,6 @@ impl AlwaysAutofixableViolation for MissingRequiredImport { } } -struct Alias<'a> { - name: &'a str, - as_name: Option<&'a str>, -} - -struct ImportFrom<'a> { - module: Option<&'a str>, - name: Alias<'a>, - level: Option<&'a usize>, -} - -struct Import<'a> { - name: Alias<'a>, -} - -enum AnyImport<'a> { - Import(Import<'a>), - ImportFrom(ImportFrom<'a>), -} - -impl fmt::Display for ImportFrom<'_> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "from ")?; - if let Some(level) = self.level { - write!(f, "{}", ".".repeat(*level))?; - } - if let Some(module) = self.module { - write!(f, "{module}")?; - } - write!(f, " import {}", self.name.name)?; - Ok(()) - } -} - -impl fmt::Display for Import<'_> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "import {}", self.name.name)?; - if let Some(as_name) = self.name.as_name { - write!(f, " as {as_name}")?; - } - Ok(()) - } -} - -impl fmt::Display for AnyImport<'_> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - AnyImport::Import(import) => write!(f, "{import}"), - AnyImport::ImportFrom(import_from) => write!(f, "{import_from}"), - } - } -} - fn contains(block: &Block, required_import: &AnyImport) -> bool { block.imports.iter().any(|import| match required_import { AnyImport::Import(required_import) => { @@ -130,7 +76,7 @@ fn contains(block: &Block, required_import: &AnyImport) -> bool { return false; }; module.as_deref() == required_import.module - && level.as_ref() == required_import.level + && *level == required_import.level && names.iter().any(|alias| { alias.node.name == required_import.name.name && alias.node.asname.as_deref() == required_import.name.as_name @@ -223,7 +169,7 @@ pub fn add_required_imports( name: name.node.name.as_str(), as_name: name.node.asname.as_deref(), }, - level: level.as_ref(), + level: *level, }), blocks, python_ast, diff --git a/crates/ruff_python_ast/src/imports.rs b/crates/ruff_python_ast/src/imports.rs new file mode 100644 index 0000000000..0565efc89a --- /dev/null +++ b/crates/ruff_python_ast/src/imports.rs @@ -0,0 +1,61 @@ +use std::fmt; + +/// A representation of an individual name imported via any import statement. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AnyImport<'a> { + Import(Import<'a>), + ImportFrom(ImportFrom<'a>), +} + +/// A representation of an individual name imported via an `import` statement. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Import<'a> { + pub name: Alias<'a>, +} + +/// A representation of an individual name imported via a `from ... import` statement. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ImportFrom<'a> { + pub module: Option<&'a str>, + pub name: Alias<'a>, + pub level: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Alias<'a> { + pub name: &'a str, + pub as_name: Option<&'a str>, +} + +impl fmt::Display for AnyImport<'_> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + AnyImport::Import(import) => write!(f, "{import}"), + AnyImport::ImportFrom(import_from) => write!(f, "{import_from}"), + } + } +} + +impl fmt::Display for Import<'_> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "import {}", self.name.name)?; + if let Some(as_name) = self.name.as_name { + write!(f, " as {as_name}")?; + } + Ok(()) + } +} + +impl fmt::Display for ImportFrom<'_> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "from ")?; + if let Some(level) = self.level { + write!(f, "{}", ".".repeat(level))?; + } + if let Some(module) = self.module { + write!(f, "{module}")?; + } + write!(f, " import {}", self.name.name)?; + Ok(()) + } +} diff --git a/crates/ruff_python_ast/src/lib.rs b/crates/ruff_python_ast/src/lib.rs index 872547620f..f7f0223c29 100644 --- a/crates/ruff_python_ast/src/lib.rs +++ b/crates/ruff_python_ast/src/lib.rs @@ -5,6 +5,7 @@ pub mod context; pub mod function_type; pub mod hashable; pub mod helpers; +pub mod imports; pub mod logging; pub mod newlines; pub mod operations;