diff --git a/crates/ty_python_semantic/resources/mdtest/typed_dict.md b/crates/ty_python_semantic/resources/mdtest/typed_dict.md index a6dbce3426..17bf0533a1 100644 --- a/crates/ty_python_semantic/resources/mdtest/typed_dict.md +++ b/crates/ty_python_semantic/resources/mdtest/typed_dict.md @@ -314,6 +314,75 @@ a_person = {"name": "Alice", "age": 30, "extra": True} (a_person := {"name": "Alice", "age": 30, "extra": True}) ``` +## Union of `TypedDict` + +When assigning to a union of `TypedDict` types, the type will be narrowed based on the dictionary +literal: + +```py +from typing import TypedDict +from typing_extensions import NotRequired + +class Foo(TypedDict): + foo: int + +x1: Foo | None = {"foo": 1} +reveal_type(x1) # revealed: Foo + +class Bar(TypedDict): + bar: int + +x2: Foo | Bar = {"foo": 1} +reveal_type(x2) # revealed: Foo + +x3: Foo | Bar = {"bar": 1} +reveal_type(x3) # revealed: Bar + +x4: Foo | Bar | None = {"bar": 1} +reveal_type(x4) # revealed: Bar + +# error: [invalid-assignment] +x5: Foo | Bar = {"baz": 1} +reveal_type(x5) # revealed: Foo | Bar + +class FooBar1(TypedDict): + foo: int + bar: int + +class FooBar2(TypedDict): + foo: int + bar: int + +class FooBar3(TypedDict): + foo: int + bar: int + baz: NotRequired[int] + +x6: FooBar1 | FooBar2 = {"foo": 1, "bar": 1} +reveal_type(x6) # revealed: FooBar1 | FooBar2 + +x7: FooBar1 | FooBar3 = {"foo": 1, "bar": 1} +reveal_type(x7) # revealed: FooBar1 | FooBar3 + +x8: FooBar1 | FooBar2 | FooBar3 | None = {"foo": 1, "bar": 1} +reveal_type(x8) # revealed: FooBar1 | FooBar2 | FooBar3 +``` + +In doing so, may have to infer the same type with multiple distinct type contexts: + +```py +from typing import TypedDict + +class NestedFoo(TypedDict): + foo: list[FooBar1] + +class NestedBar(TypedDict): + foo: list[FooBar2] + +x1: NestedFoo | NestedBar = {"foo": [{"foo": 1, "bar": 1}]} +reveal_type(x1) # revealed: NestedFoo | NestedBar +``` + ## Type ignore compatibility issues Users should be able to ignore TypedDict validation errors with `# type: ignore` diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index d819726c18..4332542b81 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -8364,13 +8364,64 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let mut item_types = FxHashMap::default(); // Validate `TypedDict` dictionary literal assignments. - if let Some(tcx) = tcx.annotation - && let Some(typed_dict) = tcx - .filter_union(self.db(), Type::is_typed_dict) - .as_typed_dict() - && let Some(ty) = self.infer_typed_dict_expression(dict, typed_dict, &mut item_types) - { - return ty; + if let Some(tcx) = tcx.annotation { + let tcx = tcx.filter_union(self.db(), Type::is_typed_dict); + + if let Some(typed_dict) = tcx.as_typed_dict() { + // If there is a single typed dict annotation, infer against it directly. + if let Some(ty) = + self.infer_typed_dict_expression(dict, typed_dict, &mut item_types) + { + return ty; + } + } else if let Type::Union(tcx) = tcx { + // Otherwise, disable diagnostics as we attempt to narrow to specific elements of the union. + let old_multi_inference = self.context.set_multi_inference(true); + let old_multi_inference_state = + self.set_multi_inference_state(MultiInferenceState::Ignore); + + let mut narrowed_typed_dicts = Vec::new(); + for element in tcx.elements(self.db()) { + let typed_dict = element + .as_typed_dict() + .expect("filtered out non-typed-dict types above"); + + if self + .infer_typed_dict_expression(dict, typed_dict, &mut item_types) + .is_some() + { + narrowed_typed_dicts.push(typed_dict); + } + + item_types.clear(); + } + + if !narrowed_typed_dicts.is_empty() { + // Now that we know which typed dict annotations are valid, re-infer with diagnostics enabled, + self.context.set_multi_inference(old_multi_inference); + + // We may have to infer the same expression multiple times with distinct type context, + // so we take the intersection of all valid inferences for a given expression. + self.set_multi_inference_state(MultiInferenceState::Intersect); + + let mut narrowed_tys = Vec::new(); + for typed_dict in narrowed_typed_dicts { + let mut item_types = FxHashMap::default(); + + let ty = self + .infer_typed_dict_expression(dict, typed_dict, &mut item_types) + .expect("ensured the typed dict is valid above"); + + narrowed_tys.push(ty); + } + + self.set_multi_inference_state(old_multi_inference_state); + return UnionType::from_elements(self.db(), narrowed_tys); + } + + self.context.set_multi_inference(old_multi_inference); + self.set_multi_inference_state(old_multi_inference_state); + } } // Avoid false positives for the functional `TypedDict` form, which is currently