mirror of
https://github.com/astral-sh/ruff
synced 2026-01-21 13:30:49 -05:00
[ty] Narrow TypedDict literal access in match statements (#22299)
## Summary Closes https://github.com/astral-sh/ty/issues/2279.
This commit is contained in:
@@ -2151,6 +2151,82 @@ static_assert(is_assignable_to(FooBar, Bar))
|
||||
TODO: The narrowing that we didn't do above will become possible when we add support for
|
||||
`closed=True`. This is [one of the main use cases][closed] that motivated the `closed` feature.
|
||||
|
||||
## Narrowing tagged unions of `TypedDict`s with `match` statements
|
||||
|
||||
Just like with `if` statements, we can narrow tagged unions of `TypedDict`s in `match` statements:
|
||||
|
||||
```toml
|
||||
[environment]
|
||||
python-version = "3.10"
|
||||
```
|
||||
|
||||
```py
|
||||
from typing import TypedDict, Literal
|
||||
|
||||
class Foo(TypedDict):
|
||||
tag: Literal["foo"]
|
||||
|
||||
class Bar(TypedDict):
|
||||
tag: Literal[42]
|
||||
|
||||
class Baz(TypedDict):
|
||||
tag: Literal[b"baz"]
|
||||
|
||||
class Bing(TypedDict):
|
||||
tag: Literal["bing"]
|
||||
|
||||
def match_statements(u: Foo | Bar | Baz | Bing):
|
||||
match u["tag"]:
|
||||
case "foo":
|
||||
reveal_type(u) # revealed: Foo
|
||||
case 42:
|
||||
reveal_type(u) # revealed: Bar
|
||||
case b"baz":
|
||||
reveal_type(u) # revealed: Baz
|
||||
case _:
|
||||
reveal_type(u) # revealed: Bing
|
||||
```
|
||||
|
||||
We can also narrow a single `TypedDict` type to `Never`:
|
||||
|
||||
```py
|
||||
def match_single(u: Foo):
|
||||
match u["tag"]:
|
||||
case "foo":
|
||||
reveal_type(u) # revealed: Foo
|
||||
case _:
|
||||
reveal_type(u) # revealed: Never
|
||||
```
|
||||
|
||||
Narrowing is restricted to `Literal` tags:
|
||||
|
||||
```py
|
||||
from ty_extensions import is_assignable_to, static_assert
|
||||
|
||||
class NonLiteralTD(TypedDict):
|
||||
tag: int
|
||||
|
||||
def match_non_literal(u: Foo | NonLiteralTD):
|
||||
match u["tag"]:
|
||||
case "foo":
|
||||
# We can't narrow the union here...
|
||||
reveal_type(u) # revealed: Foo | NonLiteralTD
|
||||
case _:
|
||||
# ...(but we *can* narrow here)...
|
||||
reveal_type(u) # revealed: NonLiteralTD
|
||||
```
|
||||
|
||||
We can still narrow `Literal` tags even when non-`TypedDict` types are present in the union:
|
||||
|
||||
```py
|
||||
def match_with_dict(u: Foo | Bar | dict):
|
||||
match u["tag"]:
|
||||
case "foo":
|
||||
# TODO: `dict & ~<TypedDict ...>` should simplify to `dict` here, but that's currently a
|
||||
# false negative in `is_disjoint_impl`.
|
||||
reveal_type(u) # revealed: Foo | (dict[Unknown, Unknown] & ~<TypedDict with items 'tag'>)
|
||||
```
|
||||
|
||||
[closed]: https://peps.python.org/pep-0728/#disallowing-extra-items-explicitly
|
||||
[subtyping section]: https://typing.python.org/en/latest/spec/typeddict.html#subtyping-between-typeddict-types
|
||||
[`typeddict`]: https://typing.python.org/en/latest/spec/typeddict.html
|
||||
|
||||
@@ -1078,55 +1078,18 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
|
||||
// Instead, we're going to constrain `my_typeddict_union` itself.
|
||||
if matches!(&**ops, [ast::CmpOp::Eq | ast::CmpOp::NotEq])
|
||||
&& let ast::Expr::Subscript(subscript) = &**left
|
||||
&& let lhs_value_type = inference.expression_type(&*subscript.value)
|
||||
// Checking for `TypedDict`s up front isn't strictly necessary, since the intersection
|
||||
// we're going to build is compatible with non-`TypedDict` types, but we don't want to
|
||||
// do the work to build it and intersect it (or for that matter, let the user see it)
|
||||
// in the common case where there are no `TypedDict`s.
|
||||
&& is_typeddict_or_union_with_typeddicts(self.db, lhs_value_type)
|
||||
&& let Some(subscript_place_expr) = place_expr(&subscript.value)
|
||||
&& let Type::StringLiteral(key_literal) = inference.expression_type(&*subscript.slice)
|
||||
&& let rhs_type = inference.expression_type(&comparators[0])
|
||||
&& is_supported_typeddict_tag_literal(rhs_type)
|
||||
{
|
||||
// If we have an equality constraint (either `==` on the `if` side, or `!=` on the
|
||||
// `else` side), we have to be careful. If all the matching fields in all the
|
||||
// `TypedDict`s here have literal types, then yes, equality is as good as a type check.
|
||||
// However, if any of them are e.g. `int` or `str` or some random class, then we can't
|
||||
// narrow their type at all, because subclasses of those types can implement `__eq__`
|
||||
// in any perverse way they like. On the other hand, if this is an *inequality*
|
||||
// constraint, then we can go ahead and assert "you can't be this exact literal type"
|
||||
// without worrying about what other types might be present.
|
||||
// For `==`, we use equality semantics on the `if` branch (is_positive=true).
|
||||
// For `!=`, we use equality semantics on the `else` branch (is_positive=false).
|
||||
let constrain_with_equality = is_positive == (ops[0] == ast::CmpOp::Eq);
|
||||
if !constrain_with_equality
|
||||
|| all_matching_typeddict_fields_have_literal_types(
|
||||
self.db,
|
||||
lhs_value_type,
|
||||
key_literal.value(self.db),
|
||||
)
|
||||
{
|
||||
let field_name = Name::from(key_literal.value(self.db));
|
||||
let rhs_type = inference.expression_type(&comparators[0]);
|
||||
// To avoid excluding non-`TypedDict` types, our constraints are always expressed
|
||||
// as a negative intersection (i.e. "you're *not* this kind of `TypedDict`"). If
|
||||
// `constrain_with_equality` is true, the whole constraint is going to be a double
|
||||
// negative, i.e. "you're *not* a `TypedDict` *without* this literal field". As the
|
||||
// first step of building that, we negate the right hand side.
|
||||
let field_type = rhs_type.negate_if(self.db, constrain_with_equality);
|
||||
// Create the synthesized `TypedDict` with that (possibly negated) field. We don't
|
||||
// want to constrain the mutability or required-ness of the field, so the most
|
||||
// compatible form is not-required and read-only.
|
||||
let field = TypedDictFieldBuilder::new(field_type)
|
||||
.required(false)
|
||||
.read_only(true)
|
||||
.build();
|
||||
let schema = TypedDictSchema::from_iter([(field_name, field)]);
|
||||
let synthesized_typeddict =
|
||||
TypedDictType::Synthesized(SynthesizedTypedDictType::new(self.db, schema));
|
||||
// As mentioned above, the synthesized `TypedDict` is always negated.
|
||||
let intersection = Type::TypedDict(synthesized_typeddict).negate(self.db);
|
||||
let place = self.expect_place(&subscript_place_expr);
|
||||
constraints.insert(place, NarrowingConstraint::regular(intersection));
|
||||
if let Some((place, constraint)) = self.narrow_typeddict_subscript(
|
||||
inference.expression_type(&*subscript.value),
|
||||
&subscript.value,
|
||||
inference.expression_type(&*subscript.slice),
|
||||
inference.expression_type(&comparators[0]),
|
||||
constrain_with_equality,
|
||||
) {
|
||||
constraints.insert(place, constraint);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1413,8 +1376,9 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
|
||||
value: Expression<'db>,
|
||||
is_positive: bool,
|
||||
) -> Option<NarrowingConstraints<'db>> {
|
||||
let subject_node = subject.node_ref(self.db, self.module);
|
||||
let place = {
|
||||
let subject = place_expr(subject.node_ref(self.db, self.module))?;
|
||||
let subject = place_expr(subject_node)?;
|
||||
self.expect_place(&subject)
|
||||
};
|
||||
let subject_ty =
|
||||
@@ -1423,8 +1387,38 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
|
||||
let value_ty =
|
||||
infer_same_file_expression_type(self.db, value, TypeContext::default(), self.module);
|
||||
|
||||
self.evaluate_expr_compare_op(subject_ty, value_ty, ast::CmpOp::Eq, is_positive)
|
||||
.map(|ty| NarrowingConstraints::from_iter([(place, NarrowingConstraint::regular(ty))]))
|
||||
let mut constraints = self
|
||||
.evaluate_expr_compare_op(subject_ty, value_ty, ast::CmpOp::Eq, is_positive)
|
||||
.map(|ty| {
|
||||
NarrowingConstraints::from_iter([(place, NarrowingConstraint::regular(ty))])
|
||||
})?;
|
||||
|
||||
// Narrow tagged unions of `TypedDict`s with `Literal` keys, for example:
|
||||
//
|
||||
// class Foo(TypedDict):
|
||||
// tag: Literal["foo"]
|
||||
// class Bar(TypedDict):
|
||||
// tag: Literal["bar"]
|
||||
// def _(union: Foo | Bar):
|
||||
// match union["tag"]:
|
||||
// case "foo":
|
||||
// reveal_type(union) # Foo
|
||||
//
|
||||
// Like in the `if` statement case, we're constraining `union` itself, not `union["tag"]`.
|
||||
if let ast::Expr::Subscript(subscript) = subject_node {
|
||||
let inference = infer_expression_types(self.db, subject, TypeContext::default());
|
||||
if let Some((place, constraint)) = self.narrow_typeddict_subscript(
|
||||
inference.expression_type(&*subscript.value),
|
||||
&subscript.value,
|
||||
inference.expression_type(&*subscript.slice),
|
||||
value_ty,
|
||||
is_positive,
|
||||
) {
|
||||
constraints.insert(place, constraint);
|
||||
}
|
||||
}
|
||||
|
||||
Some(constraints)
|
||||
}
|
||||
|
||||
fn evaluate_match_pattern_or(
|
||||
@@ -1506,6 +1500,73 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Narrow tagged unions of `TypedDict`s with `Literal` keys.
|
||||
///
|
||||
/// Given a subscript expression like `union["tag"]` where `union` is a `TypedDict` (or union
|
||||
/// containing `TypedDict`s), and a comparison value like `"foo"`, this method creates a
|
||||
/// constraint on `union` (not `union["tag"]`) that narrows it based on the tag value.
|
||||
///
|
||||
/// Returns `Some((place, constraint))` if narrowing is possible, `None` otherwise.
|
||||
fn narrow_typeddict_subscript(
|
||||
&self,
|
||||
subscript_value_type: Type<'db>,
|
||||
subscript_value_expr: &ast::Expr,
|
||||
subscript_key_type: Type<'db>,
|
||||
rhs_type: Type<'db>,
|
||||
constrain_with_equality: bool,
|
||||
) -> Option<(ScopedPlaceId, NarrowingConstraint<'db>)> {
|
||||
// Check preconditions: we need a TypedDict, a string key, and a supported tag literal.
|
||||
if !is_typeddict_or_union_with_typeddicts(self.db, subscript_value_type) {
|
||||
return None;
|
||||
}
|
||||
let subscript_place_expr = place_expr(subscript_value_expr)?;
|
||||
let Type::StringLiteral(key_literal) = subscript_key_type else {
|
||||
return None;
|
||||
};
|
||||
if !is_supported_typeddict_tag_literal(rhs_type) {
|
||||
return None;
|
||||
}
|
||||
|
||||
// If we have an equality constraint, we have to be careful. If all the matching fields
|
||||
// in all the `TypedDict`s here have literal types, then yes, equality is as good as a
|
||||
// type check. However, if any of them are e.g. `int` or `str` or some random class,
|
||||
// then we can't narrow their type at all, because subclasses of those types can
|
||||
// implement `__eq__` in any perverse way they like. On the other hand, if this is an
|
||||
// *inequality* constraint, then we can go ahead and assert "you can't be this exact
|
||||
// literal type" without worrying about what other types might be present.
|
||||
if constrain_with_equality
|
||||
&& !all_matching_typeddict_fields_have_literal_types(
|
||||
self.db,
|
||||
subscript_value_type,
|
||||
key_literal.value(self.db),
|
||||
)
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let field_name = Name::from(key_literal.value(self.db));
|
||||
// To avoid excluding non-`TypedDict` types, our constraints are always expressed
|
||||
// as a negative intersection (i.e. "you're *not* this kind of `TypedDict`"). If
|
||||
// `constrain_with_equality` is true, the whole constraint is going to be a double
|
||||
// negative, i.e. "you're *not* a `TypedDict` *without* this literal field". As the
|
||||
// first step of building that, we negate the right hand side.
|
||||
let field_type = rhs_type.negate_if(self.db, constrain_with_equality);
|
||||
// Create the synthesized `TypedDict` with that (possibly negated) field. We don't
|
||||
// want to constrain the mutability or required-ness of the field, so the most
|
||||
// compatible form is not-required and read-only.
|
||||
let field = TypedDictFieldBuilder::new(field_type)
|
||||
.required(false)
|
||||
.read_only(true)
|
||||
.build();
|
||||
let schema = TypedDictSchema::from_iter([(field_name, field)]);
|
||||
let synthesized_typeddict =
|
||||
TypedDictType::Synthesized(SynthesizedTypedDictType::new(self.db, schema));
|
||||
// As mentioned above, the synthesized `TypedDict` is always negated.
|
||||
let intersection = Type::TypedDict(synthesized_typeddict).negate(self.db);
|
||||
let place = self.expect_place(&subscript_place_expr);
|
||||
Some((place, NarrowingConstraint::regular(intersection)))
|
||||
}
|
||||
}
|
||||
|
||||
// Return true if the given type is a `TypedDict`, or if it's a union that includes at least one
|
||||
|
||||
Reference in New Issue
Block a user