diff --git a/crates/ty_python_semantic/resources/mdtest/typed_dict.md b/crates/ty_python_semantic/resources/mdtest/typed_dict.md index 1577d3ea0c..91e957298c 100644 --- a/crates/ty_python_semantic/resources/mdtest/typed_dict.md +++ b/crates/ty_python_semantic/resources/mdtest/typed_dict.md @@ -2019,5 +2019,138 @@ static_assert(is_disjoint_from(TD, dict[str, int])) # error: [static-assert-err static_assert(is_disjoint_from(TD, dict[str, str])) # error: [static-assert-error] ``` +## Narrowing tagged unions of `TypedDict`s + +In a tagged union of `TypedDict`s, a common field in each member (often `"type"` or `"tag"`) is +given a distinct `Literal` type/value. We can narrow the union by constraining this field: + +```py +from typing import TypedDict, Literal + +class Foo(TypedDict): + tag: Literal["foo"] + +class Bar(TypedDict): + tag: Literal[42] + +class Baz(TypedDict): + tag: Literal[b"baz"] # `BytesLiteral` is supported. + +class Bing(TypedDict): + tag: Literal["bing"] + +def _(u: Foo | Bar | Baz | Bing): + if u["tag"] == "foo": + reveal_type(u) # revealed: Foo + elif u["tag"] == 42: + reveal_type(u) # revealed: Bar + elif u["tag"] == b"baz": + reveal_type(u) # revealed: Baz + else: + reveal_type(u) # revealed: Bing +``` + +We can descend into intersections to discover `TypedDict` types that need narrowing: + +```py +from collections.abc import Mapping +from ty_extensions import Intersection + +def _(u: Foo | Intersection[Bar, Mapping[str, int]]): + if u["tag"] == "foo": + reveal_type(u) # revealed: Foo + else: + reveal_type(u) # revealed: Bar & Mapping[str, int] +``` + +We can also narrow a single `TypedDict` type to `Never`: + +```py +def _(u: Foo): + if u["tag"] == "foo": + reveal_type(u) # revealed: Foo + else: + reveal_type(u) # revealed: Never +``` + +Narrowing is restricted to `Literal` tags, though, because `x == "foo"` doesn't generally tell us +anything about the type of `x`. 
Here's an example where narrowing would be tempting but unsound: + +```py +from ty_extensions import is_assignable_to, static_assert + +class NonLiteralTD(TypedDict): + tag: int + +def _(u: Foo | NonLiteralTD): + if u["tag"] == "foo": + # We can't narrow the union here... + reveal_type(u) # revealed: Foo | NonLiteralTD + else: + # ...(even though we can here)... + reveal_type(u) # revealed: NonLiteralTD + +# ...because `NonLiteralTD["tag"]` could be assigned to with one of these, which would make the +# first condition above true at runtime! +class WackyInt(int): + def __eq__(self, other): + return True + +_: NonLiteralTD = {"tag": WackyInt(99)} # allowed +``` + +We can still narrow `Literal` tags even when non-`TypedDict` types are present in the union: + +```py +def _(u: Foo | Bar | dict): + if u["tag"] == "foo": + # TODO: `dict & ~<TypedDict ...>` should simplify to `dict` here, but that's currently a + # false negative in `is_disjoint_impl`. + reveal_type(u) # revealed: Foo | (dict[Unknown, Unknown] & ~<TypedDict with items 'tag'>) + +# The negation(s) will simplify out if we add something to the union that doesn't inherit from +# `dict`. It just needs to support indexing with a string key. +class NotADict: + def __getitem__(self, key): ... + +def _(u: Foo | Bar | NotADict): + if u["tag"] == 42: + reveal_type(u) # revealed: Bar | NotADict +``` + +It would be nice if we could also narrow `TypedDict` unions by checking whether a key (which only +shows up in a subset of the union members) is present, but that isn't generally correct, because +"extra items" are allowed by default. For example, even though `Bar` here doesn't define a `"foo"` +field, it could be *assigned to* with another `TypedDict` that does: + +```py +class Foo(TypedDict): + foo: int + +class Bar(TypedDict): + bar: int + +def disappointment(u: Foo | Bar): + if "foo" in u: + # We can't narrow the union here... + reveal_type(u) # revealed: Foo | Bar + else: + # ...(even though we *can* narrow it here)... 
+ # TODO: This should narrow to `Bar`, because "foo" is required in `Foo`. + reveal_type(u) # revealed: Foo | Bar + +# ...because `u` could turn out to be one of these. +class FooBar(TypedDict): + foo: int + bar: int + +static_assert(is_assignable_to(FooBar, Foo)) +static_assert(is_assignable_to(FooBar, Bar)) +``` + +TODO: The narrowing that we didn't do above will become possible when we add support for +`closed=True`. This is [one of the main use cases][closed] that motivated the `closed` feature. + +[closed]: https://peps.python.org/pep-0728/#disallowing-extra-items-explicitly [subtyping section]: https://typing.python.org/en/latest/spec/typeddict.html#subtyping-between-typeddict-types [`typeddict`]: https://typing.python.org/en/latest/spec/typeddict.html diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index ae36ea47ed..1ea3c28590 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -10,6 +10,9 @@ use crate::semantic_index::scope::ScopeId; use crate::types::enums::{enum_member_literals, enum_metadata}; use crate::types::function::KnownFunction; use crate::types::infer::infer_same_file_expression_type; +use crate::types::typed_dict::{ + SynthesizedTypedDictType, TypedDictFieldBuilder, TypedDictSchema, TypedDictType, +}; use crate::types::{ CallableType, ClassLiteral, ClassType, IntersectionBuilder, KnownClass, KnownInstanceType, SpecialFormType, SubclassOfInner, SubclassOfType, Truthiness, Type, TypeContext, @@ -17,6 +20,7 @@ use crate::types::{ }; use ruff_db::parsed::{ParsedModuleRef, parsed_module}; +use ruff_python_ast::name::Name; use ruff_python_stdlib::identifiers::is_identifier; use itertools::Itertools; @@ -877,6 +881,72 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { .tuple_windows::<(&ruff_python_ast::Expr, &ruff_python_ast::Expr)>(); let mut constraints = NarrowingConstraints::default(); + // Narrow tagged unions of 
`TypedDict`s with `Literal` keys, for example: + // + // class Foo(TypedDict): + // tag: Literal["foo"] + // class Bar(TypedDict): + // tag: Literal["bar"] + // def _(union: Foo | Bar): + // if union["tag"] == "foo": + // reveal_type(union) # Foo + // + // Importantly, `my_typeddict_union["tag"]` isn't the place we're going to constrain. + // Instead, we're going to constrain `my_typeddict_union` itself. + if matches!(&**ops, [ast::CmpOp::Eq | ast::CmpOp::NotEq]) + && let ast::Expr::Subscript(subscript) = &**left + && let lhs_value_type = inference.expression_type(&*subscript.value) + // Checking for `TypedDict`s up front isn't strictly necessary, since the intersection + // we're going to build is compatible with non-`TypedDict` types, but we don't want to + // do the work to build it and intersect it (or for that matter, let the user see it) + // in the common case where there are no `TypedDict`s. + && is_typeddict_or_union_with_typeddicts(self.db, lhs_value_type) + && let Some(subscript_place_expr) = place_expr(&subscript.value) + && let Type::StringLiteral(key_literal) = inference.expression_type(&*subscript.slice) + && let rhs_type = inference.expression_type(&comparators[0]) + && is_supported_typeddict_tag_literal(rhs_type) + { + // If we have an equality constraint (either `==` on the `if` side, or `!=` on the + // `else` side), we have to be careful. If all the matching fields in all the + // `TypedDict`s here have literal types, then yes, equality is as good as a type check. + // However, if any of them are e.g. `int` or `str` or some random class, then we can't + // narrow their type at all, because subclasses of those types can implement `__eq__` + // in any perverse way they like. On the other hand, if this is an *inequality* + // constraint, then we can go ahead and assert "you can't be this exact literal type" + // without worrying about what other types might be present. 
+ let constrain_with_equality = is_positive == (ops[0] == ast::CmpOp::Eq); + if !constrain_with_equality + || all_matching_typeddict_fields_have_literal_types( + self.db, + lhs_value_type, + key_literal.value(self.db), + ) + { + let field_name = Name::from(key_literal.value(self.db)); + let rhs_type = inference.expression_type(&comparators[0]); + // To avoid excluding non-`TypedDict` types, our constraints are always expressed + // as a negative intersection (i.e. "you're *not* this kind of `TypedDict`"). If + // `constrain_with_equality` is true, the whole constraint is going to be a double + // negative, i.e. "you're *not* a `TypedDict` *without* this literal field". As the + // first step of building that, we negate the right hand side. + let field_type = rhs_type.negate_if(self.db, constrain_with_equality); + // Create the synthesized `TypedDict` with that (possibly negated) field. We don't + // want to constrain the mutability or required-ness of the field, so the most + // compatible form is not-required and read-only. + let field = TypedDictFieldBuilder::new(field_type) + .required(false) + .read_only(true) + .build(); + let schema = TypedDictSchema::from_iter([(field_name, field)]); + let synthesized_typeddict = + TypedDictType::Synthesized(SynthesizedTypedDictType::new(self.db, schema)); + // As mentioned above, the synthesized `TypedDict` is always negated. + let intersection = Type::TypedDict(synthesized_typeddict).negate(self.db); + let place = self.expect_place(&subscript_place_expr); + constraints.insert(place, intersection); + } + } + let mut last_rhs_ty: Option = None; for (op, (left, right)) in std::iter::zip(&**ops, comparator_tuples) { @@ -1212,3 +1282,71 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { } } } + +// Return true if the given type is a `TypedDict`, or if it's a union that includes at least one +// `TypedDict` (even if other types are present). 
+fn is_typeddict_or_union_with_typeddicts<'db>(db: &'db dyn Db, ty: Type<'db>) -> bool { + match ty { + Type::TypedDict(_) => true, + Type::Union(union) => { + union + .elements(db) + .iter() + .any(|union_member_ty| match union_member_ty { + Type::TypedDict(_) => true, + Type::Intersection(intersection) => { + intersection.positive(db).iter().any(Type::is_typed_dict) + } + _ => false, + }) + } + _ => false, + } +} + +fn is_supported_typeddict_tag_literal(ty: Type) -> bool { + matches!( + ty, + // TODO: We'd like to support `EnumLiteral` also, but we have to be careful with types like + // `IntEnum` and `StrEnum` that have custom `__eq__` methods. + Type::StringLiteral(_) | Type::BytesLiteral(_) | Type::IntLiteral(_) + ) +} + +// See the comment above the call to this function. +fn all_matching_typeddict_fields_have_literal_types<'db>( + db: &'db dyn Db, + ty: Type<'db>, + field_name: &str, +) -> bool { + let matching_field_is_literal = |typeddict: &TypedDictType<'db>| { + // There's no matching field to check if `.get()` returns `None`. + typeddict + .items(db) + .get(field_name) + .is_none_or(|field| is_supported_typeddict_tag_literal(field.declared_ty)) + }; + + match ty { + Type::TypedDict(td) => matching_field_is_literal(&td), + Type::Union(union) => { + union + .elements(db) + .iter() + .all(|union_member_ty| match union_member_ty { + Type::TypedDict(td) => matching_field_is_literal(td), + Type::Intersection(intersection) => { + intersection + .positive(db) + .iter() + .all(|intersection_member_ty| match intersection_member_ty { + Type::TypedDict(td) => matching_field_is_literal(td), + _ => true, + }) + } + _ => true, + }) + } + _ => true, + } +}