[ty] Fix auto import for files with `from __future__` import (#20987)

This commit is contained in:
Micha Reiser 2025-10-21 08:14:39 +02:00 committed by GitHub
parent a802d7a0ea
commit 24d0f65d62
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 57 additions and 2 deletions

View File

@ -145,8 +145,12 @@ impl<'a> Importer<'a> {
let request = request.avoid_conflicts(self.db, self.file, members);
let mut symbol_text: Box<str> = request.member.into();
let Some(response) = self.find(&request, members.at) else {
let import = Insertion::start_of_file(self.parsed.suite(), self.source, self.stylist)
.into_edit(&request.to_string());
let insertion = if let Some(future) = self.find_last_future_import() {
Insertion::end_of_statement(future.stmt, self.source, self.stylist)
} else {
Insertion::start_of_file(self.parsed.suite(), self.source, self.stylist)
};
let import = insertion.into_edit(&request.to_string());
if matches!(request.style, ImportStyle::Import) {
symbol_text = format!("{}.{}", request.module, request.member).into();
}
@ -241,6 +245,19 @@ impl<'a> Importer<'a> {
}
choice
}
/// Find the last `from __future__` import statement in the AST.
fn find_last_future_import(&self) -> Option<&'a AstImport> {
self.imports
.iter()
.take_while(|import| {
import
.stmt
.as_import_from_stmt()
.is_some_and(|import_from| import_from.module.as_deref() == Some("__future__"))
})
.last()
}
}
/// A map of symbols in scope at a particular location in a module.
@ -1293,6 +1310,44 @@ def foo():
");
}
#[test]
fn existing_future_import() {
let test = cursor_test(
"\
from __future__ import annotations
<CURSOR>
",
);
assert_snapshot!(
test.import("typing", "TypeVar"), @r"
from __future__ import annotations
import typing
typing.TypeVar
");
}
#[test]
fn existing_future_import_after_docstring() {
let test = cursor_test(
r#"
"This is a module level docstring"
from __future__ import annotations
<CURSOR>
"#,
);
assert_snapshot!(
test.import("typing", "TypeVar"), @r#"
"This is a module level docstring"
from __future__ import annotations
import typing
typing.TypeVar
"#);
}
#[test]
fn qualify_symbol_to_avoid_overwriting_other_symbol_in_scope() {
let test = cursor_test(