diff --git a/crates/ty_ide/src/completion.rs b/crates/ty_ide/src/completion.rs index 0408d972be..e58a9c7049 100644 --- a/crates/ty_ide/src/completion.rs +++ b/crates/ty_ide/src/completion.rs @@ -6613,6 +6613,27 @@ def f(zqzqzq: str): ); } + #[test] + fn auto_import_prioritizes_reusing_import_from_statements() { + let builder = completion_test_builder( + "\ +import typing +from typing import Callable +TypedDi +", + ); + assert_snapshot!( + builder.imports().build().snapshot(), + @r" + typing.TypedDict :: + typing.is_typeddict :: + _FilterConfigurationTypedDict :: from logging.config import _FilterConfigurationTypedDict + + _FormatterConfigurationTypedDict :: from logging.config import _FormatterConfigurationTypedDict + ", + ); + } + /// A way to create a simple single-file (named `main.py`) completion test /// builder. /// @@ -6638,6 +6659,7 @@ def f(zqzqzq: str): skip_builtins: bool, skip_keywords: bool, type_signatures: bool, + imports: bool, module_names: bool, // This doesn't seem like a "very complex" type to me... ---AG #[allow(clippy::type_complexity)] @@ -6670,6 +6692,7 @@ def f(zqzqzq: str): original, filtered, type_signatures: self.type_signatures, + imports: self.imports, module_names: self.module_names, } } @@ -6730,6 +6753,15 @@ def f(zqzqzq: str): self } + /// When set, include the import associated with the + /// completion. + /// + /// Not enabled by default. + fn imports(mut self) -> CompletionTestBuilder { + self.imports = true; + self + } + /// When set, the module name for each symbol is included /// in the snapshot (if available). fn module_names(mut self) -> CompletionTestBuilder { @@ -6762,6 +6794,9 @@ def f(zqzqzq: str): /// Whether type signatures should be included in the snapshot /// generated by `CompletionTest::snapshot`. type_signatures: bool, + /// Whether to show the import that will be inserted when this + /// completion is selected. + imports: bool, /// Whether module names should be included in the snapshot /// generated by `CompletionTest::snapshot`. module_names: bool, @@ -6797,6 +6832,17 @@ def f(zqzqzq: str): .unwrap_or(""); snapshot = format!("{snapshot} :: {module_name}"); } + if self.imports { + if let Some(ref edit) = c.import { + if let Some(import) = edit.content() { + snapshot = format!("{snapshot} :: {import}"); + } else { + snapshot = format!("{snapshot} :: "); + } + } else { + snapshot = format!("{snapshot} :: "); + } + } snapshot }) .collect::>() @@ -6845,6 +6891,7 @@ def f(zqzqzq: str): skip_builtins: false, skip_keywords: false, type_signatures: false, + imports: false, module_names: false, predicate: None, }