[ty] narrow TypedDict unions with not in (#22349)

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
This commit is contained in:
Felix Scherz
2026-01-03 14:12:57 +01:00
committed by GitHub
parent d0f841bff2
commit fd86e699b5
2 changed files with 117 additions and 14 deletions

View File

@@ -2124,20 +2124,26 @@ shows up in a subset of the union members) is present, but that isn't generally
field, it could be *assigned to* with another `TypedDict` that does:
```py
from typing_extensions import Literal
class Foo(TypedDict):
foo: int
class Bar(TypedDict):
bar: int
def disappointment(u: Foo | Bar):
def disappointment(u: Foo | Bar, v: Literal["foo"]):
if "foo" in u:
# We can't narrow the union here...
reveal_type(u) # revealed: Foo | Bar
else:
# ...(even though we *can* narrow it here)...
# TODO: This should narrow to `Bar`, because "foo" is required in `Foo`.
reveal_type(u) # revealed: Bar
if v in u:
reveal_type(u) # revealed: Foo | Bar
else:
reveal_type(u) # revealed: Bar
# ...because `u` could turn out to be one of these.
class FooBar(TypedDict):
@@ -2148,6 +2154,39 @@ static_assert(is_assignable_to(FooBar, Foo))
static_assert(is_assignable_to(FooBar, Bar))
```
`not in` works in the opposite way to `in`: we can narrow in the positive case, but we cannot narrow
in the negative case. The following snippet also tests our narrowing behaviour for intersections
that contain `TypedDict`s, and unions that contain intersections that contain `TypedDict`s:
```py
from typing_extensions import Literal, Any
from ty_extensions import Intersection, is_assignable_to, static_assert
def _(t: Bar, u: Foo | Intersection[Bar, Any], v: Intersection[Bar, Any], w: Literal["bar"]):
reveal_type(u) # revealed: Foo | (Bar & Any)
reveal_type(v) # revealed: Bar & Any
if "bar" not in t:
reveal_type(t) # revealed: Never
else:
reveal_type(t) # revealed: Bar
if "bar" not in u:
reveal_type(u) # revealed: Foo
else:
reveal_type(u) # revealed: Foo | (Bar & Any)
if "bar" not in v:
reveal_type(v) # revealed: Never
else:
reveal_type(v) # revealed: Bar & Any
if w not in u:
reveal_type(u) # revealed: Foo
else:
reveal_type(u) # revealed: Foo | (Bar & Any)
```
TODO: The narrowing that we didn't do above will become possible when we add support for
`closed=True`. This is [one of the main use cases][closed] that motivated the `closed` feature.

View File

@@ -12,7 +12,7 @@ use crate::types::enums::{enum_member_literals, enum_metadata};
use crate::types::function::KnownFunction;
use crate::types::infer::{ExpressionInference, infer_same_file_expression_type};
use crate::types::typed_dict::{
SynthesizedTypedDictType, TypedDictFieldBuilder, TypedDictSchema, TypedDictType,
SynthesizedTypedDictType, TypedDictField, TypedDictFieldBuilder, TypedDictSchema, TypedDictType,
};
use crate::types::{
CallableType, ClassLiteral, ClassType, IntersectionBuilder, IntersectionType, KnownClass,
@@ -1099,6 +1099,75 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
}
}
// Narrow unions and intersections of `TypedDict` in cases where required keys are
// excluded:
//
// class Foo(TypedDict):
// foo: int
// class Bar(TypedDict):
// bar: int
//
// def _(u: Foo | Bar):
// if "foo" not in u:
// reveal_type(u) # revealed: Bar
if matches!(&**ops, [ast::CmpOp::In | ast::CmpOp::NotIn])
&& let Type::StringLiteral(key) = inference.expression_type(&**left)
&& let Some(rhs_place_expr) = place_expr(&comparators[0])
&& let rhs_type = inference.expression_type(&comparators[0])
&& is_typeddict_or_union_with_typeddicts(self.db, rhs_type)
{
let is_negative_check = is_positive == (ops[0] == ast::CmpOp::NotIn);
if is_negative_check {
let requires_key = |td: TypedDictType<'db>| -> bool {
td.items(self.db)
.get(key.value(self.db))
.is_some_and(TypedDictField::is_required)
};
let narrowed = match rhs_type {
Type::TypedDict(td) => {
if requires_key(td) {
Type::Never
} else {
rhs_type
}
}
Type::Intersection(intersection) => {
if intersection
.positive(self.db)
.iter()
.copied()
.filter_map(Type::as_typed_dict)
.any(requires_key)
{
Type::Never
} else {
rhs_type
}
}
Type::Union(union) => {
// remove all members of the union that would require the key
union.filter(self.db, |ty| match ty {
Type::TypedDict(td) => !requires_key(*td),
Type::Intersection(intersection) => !intersection
.positive(self.db)
.iter()
.copied()
.filter_map(Type::as_typed_dict)
.any(requires_key),
_ => true,
})
}
_ => rhs_type,
};
if narrowed != rhs_type {
let place = self.expect_place(&rhs_place_expr);
constraints.insert(place, NarrowingConstraint::typeguard(narrowed));
}
}
}
let mut last_rhs_ty: Option<Type> = None;
for (op, (left, right)) in std::iter::zip(&**ops, comparator_tuples) {
@@ -1677,18 +1746,13 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
fn is_typeddict_or_union_with_typeddicts<'db>(db: &'db dyn Db, ty: Type<'db>) -> bool {
match ty {
Type::TypedDict(_) => true,
Type::Union(union) => {
union
.elements(db)
.iter()
.any(|union_member_ty| match union_member_ty {
Type::TypedDict(_) => true,
Type::Intersection(intersection) => {
intersection.positive(db).iter().any(Type::is_typed_dict)
}
_ => false,
})
Type::Intersection(intersection) => {
intersection.positive(db).iter().any(Type::is_typed_dict)
}
Type::Union(union) => union
.elements(db)
.iter()
.any(|union_member_ty| is_typeddict_or_union_with_typeddicts(db, *union_member_ty)),
_ => false,
}
}