[ty] narrow tagged unions of TypedDict (#22104)

Identify and narrow cases like this:

```py
class Foo(TypedDict):
    tag: Literal["foo"]

class Bar(TypedDict):
    tag: Literal["bar"]

def _(union: Foo | Bar):
    if union["tag"] == "foo":
        reveal_type(union)  # Foo
```

Fixes part of https://github.com/astral-sh/ty/issues/1479.

---------

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
This commit is contained in:
Jack O'Connor
2025-12-23 11:30:08 -08:00
committed by GitHub
parent 4c175fa0e1
commit e245c1d76e
2 changed files with 271 additions and 0 deletions

View File

@@ -2019,5 +2019,138 @@ static_assert(is_disjoint_from(TD, dict[str, int])) # error: [static-assert-err
static_assert(is_disjoint_from(TD, dict[str, str])) # error: [static-assert-error]
```
## Narrowing tagged unions of `TypedDict`s
In a tagged union of `TypedDict`s, a common field in each member (often `"type"` or `"tag"`) is
given a distinct `Literal` type/value. We can narrow the union by constraining this field:
```py
from typing import TypedDict, Literal
class Foo(TypedDict):
tag: Literal["foo"]
class Bar(TypedDict):
tag: Literal[42]
class Baz(TypedDict):
tag: Literal[b"baz"] # `BytesLiteral` is supported.
class Bing(TypedDict):
tag: Literal["bing"]
def _(u: Foo | Bar | Baz | Bing):
if u["tag"] == "foo":
reveal_type(u) # revealed: Foo
elif u["tag"] == 42:
reveal_type(u) # revealed: Bar
elif u["tag"] == b"baz":
reveal_type(u) # revealed: Baz
else:
reveal_type(u) # revealed: Bing
```
We can descend into intersections to discover `TypedDict` types that need narrowing:
```py
from collections.abc import Mapping
from ty_extensions import Intersection
def _(u: Foo | Intersection[Bar, Mapping[str, int]]):
if u["tag"] == "foo":
reveal_type(u) # revealed: Foo
else:
reveal_type(u) # revealed: Bar & Mapping[str, int]
```
We can also narrow a single `TypedDict` type to `Never`:
```py
def _(u: Foo):
if u["tag"] == "foo":
reveal_type(u) # revealed: Foo
else:
reveal_type(u) # revealed: Never
```
Narrowing is restricted to `Literal` tags, though, because `x == "foo"` doesn't generally tell us
anything about the type of `x`. Here's an example where narrowing would be tempting but unsound:
```py
from ty_extensions import is_assignable_to, static_assert
class NonLiteralTD(TypedDict):
tag: int
def _(u: Foo | NonLiteralTD):
if u["tag"] == "foo":
# We can't narrow the union here...
reveal_type(u) # revealed: Foo | NonLiteralTD
else:
# ...(even though we can here)...
reveal_type(u) # revealed: NonLiteralTD
# ...because `NonLiteralTD["tag"]` could be assigned to with one of these, which would make the
# first condition above true at runtime!
class WackyInt(int):
def __eq__(self, other):
return True
_: NonLiteralTD = {"tag": WackyInt(99)} # allowed
```
We can still narrow `Literal` tags even when non-`TypedDict` types are present in the union:
```py
def _(u: Foo | Bar | dict):
if u["tag"] == "foo":
# TODO: `dict & ~<TypedDict ...>` should simplify to `dict` here, but that's currently a
# false negative in `is_disjoint_impl`.
reveal_type(u) # revealed: Foo | (dict[Unknown, Unknown] & ~<TypedDict with items 'tag'>)
# The negation(s) will simplify out if we add something to the union that doesn't inherit from
# `dict`. It just needs to support indexing with a string key.
class NotADict:
def __getitem__(self, key): ...
def _(u: Foo | Bar | NotADict):
if u["tag"] == 42:
reveal_type(u) # revealed: Bar | NotADict
```
It would be nice if we could also narrow `TypedDict` unions by checking whether a key (which only
shows up in a subset of the union members) is present, but that isn't generally correct, because
"extra items" are allowed by default. For example, even though `Bar` here doesn't define a `"foo"`
field, it could be *assigned to* with another `TypedDict` that does:
```py
class Foo(TypedDict):
foo: int
class Bar(TypedDict):
bar: int
def disappointment(u: Foo | Bar):
if "foo" in u:
# We can't narrow the union here...
reveal_type(u) # revealed: Foo | Bar
else:
# ...(even though we *can* narrow it here)...
# TODO: This should narrow to `Bar`, because "foo" is required in `Foo`.
reveal_type(u) # revealed: Foo | Bar
# ...because `u` could turn out to be one of these.
class FooBar(TypedDict):
foo: int
bar: int
static_assert(is_assignable_to(FooBar, Foo))
static_assert(is_assignable_to(FooBar, Bar))
```
TODO: The narrowing that we didn't do above will become possible when we add support for
`closed=True`. This is [one of the main use cases][closed] that motivated the `closed` feature.
[closed]: https://peps.python.org/pep-0728/#disallowing-extra-items-explicitly
[subtyping section]: https://typing.python.org/en/latest/spec/typeddict.html#subtyping-between-typeddict-types
[`typeddict`]: https://typing.python.org/en/latest/spec/typeddict.html

View File

@@ -10,6 +10,9 @@ use crate::semantic_index::scope::ScopeId;
use crate::types::enums::{enum_member_literals, enum_metadata};
use crate::types::function::KnownFunction;
use crate::types::infer::infer_same_file_expression_type;
use crate::types::typed_dict::{
SynthesizedTypedDictType, TypedDictFieldBuilder, TypedDictSchema, TypedDictType,
};
use crate::types::{
CallableType, ClassLiteral, ClassType, IntersectionBuilder, KnownClass, KnownInstanceType,
SpecialFormType, SubclassOfInner, SubclassOfType, Truthiness, Type, TypeContext,
@@ -17,6 +20,7 @@ use crate::types::{
};
use ruff_db::parsed::{ParsedModuleRef, parsed_module};
use ruff_python_ast::name::Name;
use ruff_python_stdlib::identifiers::is_identifier;
use itertools::Itertools;
@@ -877,6 +881,72 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
.tuple_windows::<(&ruff_python_ast::Expr, &ruff_python_ast::Expr)>();
let mut constraints = NarrowingConstraints::default();
// Narrow tagged unions of `TypedDict`s with `Literal` keys, for example:
//
// class Foo(TypedDict):
// tag: Literal["foo"]
// class Bar(TypedDict):
// tag: Literal["bar"]
// def _(union: Foo | Bar):
// if union["tag"] == "foo":
// reveal_type(union) # Foo
//
// Importantly, `my_typeddict_union["tag"]` isn't the place we're going to constraint.
// Instead, we're going to constrain `my_typeddict_union` itself.
if matches!(&**ops, [ast::CmpOp::Eq | ast::CmpOp::NotEq])
&& let ast::Expr::Subscript(subscript) = &**left
&& let lhs_value_type = inference.expression_type(&*subscript.value)
// Checking for `TypedDict`s up front isn't strictly necessary, since the intersection
// we're going to build is compatible with non-`TypedDict` types, but we don't want to
// do the work to build it and intersect it (or for that matter, let the user see it)
// in the common case where there are no `TypedDict`s.
&& is_typeddict_or_union_with_typeddicts(self.db, lhs_value_type)
&& let Some(subscript_place_expr) = place_expr(&subscript.value)
&& let Type::StringLiteral(key_literal) = inference.expression_type(&*subscript.slice)
&& let rhs_type = inference.expression_type(&comparators[0])
&& is_supported_typeddict_tag_literal(rhs_type)
{
// If we have an equality constraint (either `==` on the `if` side, or `!=` on the
// `else` side), we have to be careful. If all the matching fields in all the
// `TypedDict`s here have literal types, then yes, equality is as good as a type check.
// However, if any of them are e.g. `int` or `str` or some random class, then we can't
// narrow their type at all, because subclasses of those types can implement `__eq__`
// in any perverse way they like. On the other hand, if this is an *inequality*
// constraint, then we can go ahead and assert "you can't be this exact literal type"
// without worrying about what other types might be present.
let constrain_with_equality = is_positive == (ops[0] == ast::CmpOp::Eq);
if !constrain_with_equality
|| all_matching_typeddict_fields_have_literal_types(
self.db,
lhs_value_type,
key_literal.value(self.db),
)
{
let field_name = Name::from(key_literal.value(self.db));
let rhs_type = inference.expression_type(&comparators[0]);
// To avoid excluding non-`TypedDict` types, our constraints are always expressed
// as a negative intersection (i.e. "you're *not* this kind of `TypedDict`"). If
// `constrain_with_equality` is true, the whole constraint is going to be a double
// negative, i.e. "you're *not* a `TypedDict` *without* this literal field". As the
// first step of building that, we negate the right hand side.
let field_type = rhs_type.negate_if(self.db, constrain_with_equality);
// Create the synthesized `TypedDict` with that (possibly negated) field. We don't
// want to constrain the mutability or required-ness of the field, so the most
// compatible form is not-required and read-only.
let field = TypedDictFieldBuilder::new(field_type)
.required(false)
.read_only(true)
.build();
let schema = TypedDictSchema::from_iter([(field_name, field)]);
let synthesized_typeddict =
TypedDictType::Synthesized(SynthesizedTypedDictType::new(self.db, schema));
// As mentioned above, the synthesized `TypedDict` is always negated.
let intersection = Type::TypedDict(synthesized_typeddict).negate(self.db);
let place = self.expect_place(&subscript_place_expr);
constraints.insert(place, intersection);
}
}
let mut last_rhs_ty: Option<Type> = None;
for (op, (left, right)) in std::iter::zip(&**ops, comparator_tuples) {
@@ -1212,3 +1282,71 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
}
}
}
// Return true if the given type is a `TypedDict`, or if it's a union that includes at least one
// `TypedDict` (even if other types are present).
fn is_typeddict_or_union_with_typeddicts<'db>(db: &'db dyn Db, ty: Type<'db>) -> bool {
match ty {
Type::TypedDict(_) => true,
Type::Union(union) => {
union
.elements(db)
.iter()
.any(|union_member_ty| match union_member_ty {
Type::TypedDict(_) => true,
Type::Intersection(intersection) => {
intersection.positive(db).iter().any(Type::is_typed_dict)
}
_ => false,
})
}
_ => false,
}
}
fn is_supported_typeddict_tag_literal(ty: Type) -> bool {
matches!(
ty,
// TODO: We'd like to support `EnumLiteral` also, but we have to be careful with types like
// `IntEnum` and `StrEnum` that have custom `__eq__` methods.
Type::StringLiteral(_) | Type::BytesLiteral(_) | Type::IntLiteral(_)
)
}
// See the comment above the call to this function.
fn all_matching_typeddict_fields_have_literal_types<'db>(
db: &'db dyn Db,
ty: Type<'db>,
field_name: &str,
) -> bool {
let matching_field_is_literal = |typeddict: &TypedDictType<'db>| {
// There's no matching field to check if `.get()` returns `None`.
typeddict
.items(db)
.get(field_name)
.is_none_or(|field| is_supported_typeddict_tag_literal(field.declared_ty))
};
match ty {
Type::TypedDict(td) => matching_field_is_literal(&td),
Type::Union(union) => {
union
.elements(db)
.iter()
.all(|union_member_ty| match union_member_ty {
Type::TypedDict(td) => matching_field_is_literal(td),
Type::Intersection(intersection) => {
intersection
.positive(db)
.iter()
.all(|intersection_member_ty| match intersection_member_ty {
Type::TypedDict(td) => matching_field_is_literal(td),
_ => true,
})
}
_ => true,
})
}
_ => true,
}
}