From fd86e699b524bd662843776478ce706696fca28b Mon Sep 17 00:00:00 2001 From: Felix Scherz Date: Sat, 3 Jan 2026 14:12:57 +0100 Subject: [PATCH] [ty] narrow `TypedDict` unions with `not in` (#22349) Co-authored-by: Alex Waygood --- .../resources/mdtest/typed_dict.md | 43 ++++++++- crates/ty_python_semantic/src/types/narrow.rs | 88 ++++++++++++++++--- 2 files changed, 117 insertions(+), 14 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/typed_dict.md b/crates/ty_python_semantic/resources/mdtest/typed_dict.md index 1fba53f803..7eddc51a62 100644 --- a/crates/ty_python_semantic/resources/mdtest/typed_dict.md +++ b/crates/ty_python_semantic/resources/mdtest/typed_dict.md @@ -2124,20 +2124,26 @@ shows up in a subset of the union members) is present, but that isn't generally field, it could be *assigned to* with another `TypedDict` that does: ```py +from typing_extensions import Literal + class Foo(TypedDict): foo: int class Bar(TypedDict): bar: int -def disappointment(u: Foo | Bar): +def disappointment(u: Foo | Bar, v: Literal["foo"]): if "foo" in u: # We can't narrow the union here... reveal_type(u) # revealed: Foo | Bar else: # ...(even though we *can* narrow it here)... - # TODO: This should narrow to `Bar`, because "foo" is required in `Foo`. + reveal_type(u) # revealed: Bar + + if v in u: reveal_type(u) # revealed: Foo | Bar + else: + reveal_type(u) # revealed: Bar # ...because `u` could turn out to be one of these. class FooBar(TypedDict): @@ -2148,6 +2154,39 @@ static_assert(is_assignable_to(FooBar, Foo)) static_assert(is_assignable_to(FooBar, Bar)) ``` +`not in` works in the opposite way to `in`: we can narrow in the positive case, but we cannot narrow +in the negative case. The following snippet also tests our narrowing behaviour for intersections +that contain `TypedDict`s, and unions that contain intersections that contain `TypedDict`s: + +```py +from typing_extensions import Literal, Any +from ty_extensions import Intersection, is_assignable_to, static_assert + +def _(t: Bar, u: Foo | Intersection[Bar, Any], v: Intersection[Bar, Any], w: Literal["bar"]): + reveal_type(u) # revealed: Foo | (Bar & Any) + reveal_type(v) # revealed: Bar & Any + + if "bar" not in t: + reveal_type(t) # revealed: Never + else: + reveal_type(t) # revealed: Bar + + if "bar" not in u: + reveal_type(u) # revealed: Foo + else: + reveal_type(u) # revealed: Foo | (Bar & Any) + + if "bar" not in v: + reveal_type(v) # revealed: Never + else: + reveal_type(v) # revealed: Bar & Any + + if w not in u: + reveal_type(u) # revealed: Foo + else: + reveal_type(u) # revealed: Foo | (Bar & Any) +``` + TODO: The narrowing that we didn't do above will become possible when we add support for `closed=True`. This is [one of the main use cases][closed] that motivated the `closed` feature. diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index 26da111527..88958752c2 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -12,7 +12,7 @@ use crate::types::enums::{enum_member_literals, enum_metadata}; use crate::types::function::KnownFunction; use crate::types::infer::{ExpressionInference, infer_same_file_expression_type}; use crate::types::typed_dict::{ - SynthesizedTypedDictType, TypedDictFieldBuilder, TypedDictSchema, TypedDictType, + SynthesizedTypedDictType, TypedDictField, TypedDictFieldBuilder, TypedDictSchema, TypedDictType, }; use crate::types::{ CallableType, ClassLiteral, ClassType, IntersectionBuilder, IntersectionType, KnownClass, @@ -1099,6 +1099,75 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { } } + // Narrow unions and intersections of `TypedDict` in cases where required keys are + // excluded: + // + // class Foo(TypedDict): + // foo: int + // class Bar(TypedDict): + // bar: int + // + // def _(u: Foo | Bar): + // if "foo" not in u: + // reveal_type(u) # revealed: Bar + if matches!(&**ops, [ast::CmpOp::In | ast::CmpOp::NotIn]) + && let Type::StringLiteral(key) = inference.expression_type(&**left) + && let Some(rhs_place_expr) = place_expr(&comparators[0]) + && let rhs_type = inference.expression_type(&comparators[0]) + && is_typeddict_or_union_with_typeddicts(self.db, rhs_type) + { + let is_negative_check = is_positive == (ops[0] == ast::CmpOp::NotIn); + if is_negative_check { + let requires_key = |td: TypedDictType<'db>| -> bool { + td.items(self.db) + .get(key.value(self.db)) + .is_some_and(TypedDictField::is_required) + }; + + let narrowed = match rhs_type { + Type::TypedDict(td) => { + if requires_key(td) { + Type::Never + } else { + rhs_type + } + } + Type::Intersection(intersection) => { + if intersection + .positive(self.db) + .iter() + .copied() + .filter_map(Type::as_typed_dict) + .any(requires_key) + { + Type::Never + } else { + rhs_type + } + } + Type::Union(union) => { + // remove all members of the union that would require the key + union.filter(self.db, |ty| match ty { + Type::TypedDict(td) => !requires_key(*td), + Type::Intersection(intersection) => !intersection + .positive(self.db) + .iter() + .copied() + .filter_map(Type::as_typed_dict) + .any(requires_key), + _ => true, + }) + } + _ => rhs_type, + }; + + if narrowed != rhs_type { + let place = self.expect_place(&rhs_place_expr); + constraints.insert(place, NarrowingConstraint::typeguard(narrowed)); + } + } + } + let mut last_rhs_ty: Option = None; for (op, (left, right)) in std::iter::zip(&**ops, comparator_tuples) { @@ -1677,18 +1746,13 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { fn is_typeddict_or_union_with_typeddicts<'db>(db: &'db dyn Db, ty: Type<'db>) -> bool { match ty { Type::TypedDict(_) => true, - Type::Union(union) => { - union - .elements(db) - .iter() - .any(|union_member_ty| match union_member_ty { - Type::TypedDict(_) => true, - Type::Intersection(intersection) => { - intersection.positive(db).iter().any(Type::is_typed_dict) - } - _ => false, - }) + Type::Intersection(intersection) => { + intersection.positive(db).iter().any(Type::is_typed_dict) } + Type::Union(union) => union + .elements(db) + .iter() + .any(|union_member_ty| is_typeddict_or_union_with_typeddicts(db, *union_member_ty)), _ => false, } }