mirror of https://github.com/astral-sh/ruff
[ty] Fix auto import for files with `from __future__` import (#20987)
This commit is contained in:
parent
a802d7a0ea
commit
24d0f65d62
|
|
@ -145,8 +145,12 @@ impl<'a> Importer<'a> {
|
||||||
let request = request.avoid_conflicts(self.db, self.file, members);
|
let request = request.avoid_conflicts(self.db, self.file, members);
|
||||||
let mut symbol_text: Box<str> = request.member.into();
|
let mut symbol_text: Box<str> = request.member.into();
|
||||||
let Some(response) = self.find(&request, members.at) else {
|
let Some(response) = self.find(&request, members.at) else {
|
||||||
let import = Insertion::start_of_file(self.parsed.suite(), self.source, self.stylist)
|
let insertion = if let Some(future) = self.find_last_future_import() {
|
||||||
.into_edit(&request.to_string());
|
Insertion::end_of_statement(future.stmt, self.source, self.stylist)
|
||||||
|
} else {
|
||||||
|
Insertion::start_of_file(self.parsed.suite(), self.source, self.stylist)
|
||||||
|
};
|
||||||
|
let import = insertion.into_edit(&request.to_string());
|
||||||
if matches!(request.style, ImportStyle::Import) {
|
if matches!(request.style, ImportStyle::Import) {
|
||||||
symbol_text = format!("{}.{}", request.module, request.member).into();
|
symbol_text = format!("{}.{}", request.module, request.member).into();
|
||||||
}
|
}
|
||||||
|
|
@ -241,6 +245,19 @@ impl<'a> Importer<'a> {
|
||||||
}
|
}
|
||||||
choice
|
choice
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Find the last `from __future__` import statement in the AST.
|
||||||
|
fn find_last_future_import(&self) -> Option<&'a AstImport> {
|
||||||
|
self.imports
|
||||||
|
.iter()
|
||||||
|
.take_while(|import| {
|
||||||
|
import
|
||||||
|
.stmt
|
||||||
|
.as_import_from_stmt()
|
||||||
|
.is_some_and(|import_from| import_from.module.as_deref() == Some("__future__"))
|
||||||
|
})
|
||||||
|
.last()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A map of symbols in scope at a particular location in a module.
|
/// A map of symbols in scope at a particular location in a module.
|
||||||
|
|
@ -1293,6 +1310,44 @@ def foo():
|
||||||
");
|
");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn existing_future_import() {
|
||||||
|
let test = cursor_test(
|
||||||
|
"\
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
<CURSOR>
|
||||||
|
",
|
||||||
|
);
|
||||||
|
assert_snapshot!(
|
||||||
|
test.import("typing", "TypeVar"), @r"
|
||||||
|
from __future__ import annotations
|
||||||
|
import typing
|
||||||
|
|
||||||
|
typing.TypeVar
|
||||||
|
");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn existing_future_import_after_docstring() {
|
||||||
|
let test = cursor_test(
|
||||||
|
r#"
|
||||||
|
"This is a module level docstring"
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
<CURSOR>
|
||||||
|
"#,
|
||||||
|
);
|
||||||
|
assert_snapshot!(
|
||||||
|
test.import("typing", "TypeVar"), @r#"
|
||||||
|
"This is a module level docstring"
|
||||||
|
from __future__ import annotations
|
||||||
|
import typing
|
||||||
|
|
||||||
|
typing.TypeVar
|
||||||
|
"#);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn qualify_symbol_to_avoid_overwriting_other_symbol_in_scope() {
|
fn qualify_symbol_to_avoid_overwriting_other_symbol_in_scope() {
|
||||||
let test = cursor_test(
|
let test = cursor_test(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue