diff --git a/crates/ty_python_semantic/resources/mdtest/typed_dict.md b/crates/ty_python_semantic/resources/mdtest/typed_dict.md index 961c542ee1..1fba53f803 100644 --- a/crates/ty_python_semantic/resources/mdtest/typed_dict.md +++ b/crates/ty_python_semantic/resources/mdtest/typed_dict.md @@ -2151,6 +2151,82 @@ static_assert(is_assignable_to(FooBar, Bar)) TODO: The narrowing that we didn't do above will become possible when we add support for `closed=True`. This is [one of the main use cases][closed] that motivated the `closed` feature. +## Narrowing tagged unions of `TypedDict`s with `match` statements + +Just like with `if` statements, we can narrow tagged unions of `TypedDict`s in `match` statements: + +```toml +[environment] +python-version = "3.10" +``` + +```py +from typing import TypedDict, Literal + +class Foo(TypedDict): + tag: Literal["foo"] + +class Bar(TypedDict): + tag: Literal[42] + +class Baz(TypedDict): + tag: Literal[b"baz"] + +class Bing(TypedDict): + tag: Literal["bing"] + +def match_statements(u: Foo | Bar | Baz | Bing): + match u["tag"]: + case "foo": + reveal_type(u) # revealed: Foo + case 42: + reveal_type(u) # revealed: Bar + case b"baz": + reveal_type(u) # revealed: Baz + case _: + reveal_type(u) # revealed: Bing +``` + +We can also narrow a single `TypedDict` type to `Never`: + +```py +def match_single(u: Foo): + match u["tag"]: + case "foo": + reveal_type(u) # revealed: Foo + case _: + reveal_type(u) # revealed: Never +``` + +Narrowing is restricted to `Literal` tags: + +```py +from ty_extensions import is_assignable_to, static_assert + +class NonLiteralTD(TypedDict): + tag: int + +def match_non_literal(u: Foo | NonLiteralTD): + match u["tag"]: + case "foo": + # We can't narrow the union here... + reveal_type(u) # revealed: Foo | NonLiteralTD + case _: + # ...(but we *can* narrow here)... 
+ reveal_type(u) # revealed: NonLiteralTD +``` + +We can still narrow `Literal` tags even when non-`TypedDict` types are present in the union: + +```py +def match_with_dict(u: Foo | Bar | dict): + match u["tag"]: + case "foo": + # TODO: `dict & ~<TypedDict ...>` should simplify to `dict` here, but that's currently a + # false negative in `is_disjoint_impl`. + reveal_type(u) # revealed: Foo | (dict[Unknown, Unknown] & ~<TypedDict with items 'tag'>) +``` + [closed]: https://peps.python.org/pep-0728/#disallowing-extra-items-explicitly [subtyping section]: https://typing.python.org/en/latest/spec/typeddict.html#subtyping-between-typeddict-types [`typeddict`]: https://typing.python.org/en/latest/spec/typeddict.html diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index e85792252a..e1c840a008 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -1078,55 +1078,18 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { // Instead, we're going to constrain `my_typeddict_union` itself. if matches!(&**ops, [ast::CmpOp::Eq | ast::CmpOp::NotEq]) && let ast::Expr::Subscript(subscript) = &**left - && let lhs_value_type = inference.expression_type(&*subscript.value) - // Checking for `TypedDict`s up front isn't strictly necessary, since the intersection - // we're going to build is compatible with non-`TypedDict` types, but we don't want to - // do the work to build it and intersect it (or for that matter, let the user see it) - // in the common case where there are no `TypedDict`s. 
- && is_typeddict_or_union_with_typeddicts(self.db, lhs_value_type) - && let Some(subscript_place_expr) = place_expr(&subscript.value) - && let Type::StringLiteral(key_literal) = inference.expression_type(&*subscript.slice) - && let rhs_type = inference.expression_type(&comparators[0]) - && is_supported_typeddict_tag_literal(rhs_type) { - // If we have an equality constraint (either `==` on the `if` side, or `!=` on the - // `else` side), we have to be careful. If all the matching fields in all the - // `TypedDict`s here have literal types, then yes, equality is as good as a type check. - // However, if any of them are e.g. `int` or `str` or some random class, then we can't - // narrow their type at all, because subclasses of those types can implement `__eq__` - // in any perverse way they like. On the other hand, if this is an *inequality* - // constraint, then we can go ahead and assert "you can't be this exact literal type" - // without worrying about what other types might be present. + // For `==`, we use equality semantics on the `if` branch (is_positive=true). + // For `!=`, we use equality semantics on the `else` branch (is_positive=false). let constrain_with_equality = is_positive == (ops[0] == ast::CmpOp::Eq); - if !constrain_with_equality - || all_matching_typeddict_fields_have_literal_types( - self.db, - lhs_value_type, - key_literal.value(self.db), - ) - { - let field_name = Name::from(key_literal.value(self.db)); - let rhs_type = inference.expression_type(&comparators[0]); - // To avoid excluding non-`TypedDict` types, our constraints are always expressed - // as a negative intersection (i.e. "you're *not* this kind of `TypedDict`"). If - // `constrain_with_equality` is true, the whole constraint is going to be a double - // negative, i.e. "you're *not* a `TypedDict` *without* this literal field". As the - // first step of building that, we negate the right hand side. 
- let field_type = rhs_type.negate_if(self.db, constrain_with_equality); - // Create the synthesized `TypedDict` with that (possibly negated) field. We don't - // want to constrain the mutability or required-ness of the field, so the most - // compatible form is not-required and read-only. - let field = TypedDictFieldBuilder::new(field_type) - .required(false) - .read_only(true) - .build(); - let schema = TypedDictSchema::from_iter([(field_name, field)]); - let synthesized_typeddict = - TypedDictType::Synthesized(SynthesizedTypedDictType::new(self.db, schema)); - // As mentioned above, the synthesized `TypedDict` is always negated. - let intersection = Type::TypedDict(synthesized_typeddict).negate(self.db); - let place = self.expect_place(&subscript_place_expr); - constraints.insert(place, NarrowingConstraint::regular(intersection)); + if let Some((place, constraint)) = self.narrow_typeddict_subscript( + inference.expression_type(&*subscript.value), + &subscript.value, + inference.expression_type(&*subscript.slice), + inference.expression_type(&comparators[0]), + constrain_with_equality, + ) { + constraints.insert(place, constraint); } } @@ -1413,8 +1376,9 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { value: Expression<'db>, is_positive: bool, ) -> Option> { + let subject_node = subject.node_ref(self.db, self.module); let place = { - let subject = place_expr(subject.node_ref(self.db, self.module))?; + let subject = place_expr(subject_node)?; self.expect_place(&subject) }; let subject_ty = @@ -1423,8 +1387,38 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { let value_ty = infer_same_file_expression_type(self.db, value, TypeContext::default(), self.module); - self.evaluate_expr_compare_op(subject_ty, value_ty, ast::CmpOp::Eq, is_positive) - .map(|ty| NarrowingConstraints::from_iter([(place, NarrowingConstraint::regular(ty))])) + let mut constraints = self + .evaluate_expr_compare_op(subject_ty, value_ty, ast::CmpOp::Eq, is_positive) + 
.map(|ty| { + NarrowingConstraints::from_iter([(place, NarrowingConstraint::regular(ty))]) + })?; + + // Narrow tagged unions of `TypedDict`s with `Literal` keys, for example: + // + // class Foo(TypedDict): + // tag: Literal["foo"] + // class Bar(TypedDict): + // tag: Literal["bar"] + // def _(union: Foo | Bar): + // match union["tag"]: + // case "foo": + // reveal_type(union) # Foo + // + // Like in the `if` statement case, we're constraining `union` itself, not `union["tag"]`. + if let ast::Expr::Subscript(subscript) = subject_node { + let inference = infer_expression_types(self.db, subject, TypeContext::default()); + if let Some((place, constraint)) = self.narrow_typeddict_subscript( + inference.expression_type(&*subscript.value), + &subscript.value, + inference.expression_type(&*subscript.slice), + value_ty, + is_positive, + ) { + constraints.insert(place, constraint); + } + } + + Some(constraints) } fn evaluate_match_pattern_or( @@ -1506,6 +1500,73 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { } } } + + /// Narrow tagged unions of `TypedDict`s with `Literal` keys. + /// + /// Given a subscript expression like `union["tag"]` where `union` is a `TypedDict` (or union + /// containing `TypedDict`s), and a comparison value like `"foo"`, this method creates a + /// constraint on `union` (not `union["tag"]`) that narrows it based on the tag value. + /// + /// Returns `Some((place, constraint))` if narrowing is possible, `None` otherwise. + fn narrow_typeddict_subscript( + &self, + subscript_value_type: Type<'db>, + subscript_value_expr: &ast::Expr, + subscript_key_type: Type<'db>, + rhs_type: Type<'db>, + constrain_with_equality: bool, + ) -> Option<(ScopedPlaceId, NarrowingConstraint<'db>)> { + // Check preconditions: we need a TypedDict, a string key, and a supported tag literal. 
+ if !is_typeddict_or_union_with_typeddicts(self.db, subscript_value_type) { + return None; + } + let subscript_place_expr = place_expr(subscript_value_expr)?; + let Type::StringLiteral(key_literal) = subscript_key_type else { + return None; + }; + if !is_supported_typeddict_tag_literal(rhs_type) { + return None; + } + + // If we have an equality constraint, we have to be careful. If all the matching fields + // in all the `TypedDict`s here have literal types, then yes, equality is as good as a + // type check. However, if any of them are e.g. `int` or `str` or some random class, + // then we can't narrow their type at all, because subclasses of those types can + // implement `__eq__` in any perverse way they like. On the other hand, if this is an + // *inequality* constraint, then we can go ahead and assert "you can't be this exact + // literal type" without worrying about what other types might be present. + if constrain_with_equality + && !all_matching_typeddict_fields_have_literal_types( + self.db, + subscript_value_type, + key_literal.value(self.db), + ) + { + return None; + } + + let field_name = Name::from(key_literal.value(self.db)); + // To avoid excluding non-`TypedDict` types, our constraints are always expressed + // as a negative intersection (i.e. "you're *not* this kind of `TypedDict`"). If + // `constrain_with_equality` is true, the whole constraint is going to be a double + // negative, i.e. "you're *not* a `TypedDict` *without* this literal field". As the + // first step of building that, we negate the right hand side. + let field_type = rhs_type.negate_if(self.db, constrain_with_equality); + // Create the synthesized `TypedDict` with that (possibly negated) field. We don't + // want to constrain the mutability or required-ness of the field, so the most + // compatible form is not-required and read-only. 
+ let field = TypedDictFieldBuilder::new(field_type) + .required(false) + .read_only(true) + .build(); + let schema = TypedDictSchema::from_iter([(field_name, field)]); + let synthesized_typeddict = + TypedDictType::Synthesized(SynthesizedTypedDictType::new(self.db, schema)); + // As mentioned above, the synthesized `TypedDict` is always negated. + let intersection = Type::TypedDict(synthesized_typeddict).negate(self.db); + let place = self.expect_place(&subscript_place_expr); + Some((place, NarrowingConstraint::regular(intersection))) + } } // Return true if the given type is a `TypedDict`, or if it's a union that includes at least one