mirror of
https://github.com/astral-sh/ruff
synced 2026-01-21 13:30:49 -05:00
[ty] narrow TypedDict unions with not in (#22349)
Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
This commit is contained in:
@@ -2124,20 +2124,26 @@ shows up in a subset of the union members) is present, but that isn't generally
|
||||
field, it could be *assigned to* with another `TypedDict` that does:
|
||||
|
||||
```py
|
||||
from typing_extensions import Literal
|
||||
|
||||
class Foo(TypedDict):
|
||||
foo: int
|
||||
|
||||
class Bar(TypedDict):
|
||||
bar: int
|
||||
|
||||
def disappointment(u: Foo | Bar):
|
||||
def disappointment(u: Foo | Bar, v: Literal["foo"]):
|
||||
if "foo" in u:
|
||||
# We can't narrow the union here...
|
||||
reveal_type(u) # revealed: Foo | Bar
|
||||
else:
|
||||
# ...(even though we *can* narrow it here)...
|
||||
# TODO: This should narrow to `Bar`, because "foo" is required in `Foo`.
|
||||
reveal_type(u) # revealed: Bar
|
||||
|
||||
if v in u:
|
||||
reveal_type(u) # revealed: Foo | Bar
|
||||
else:
|
||||
reveal_type(u) # revealed: Bar
|
||||
|
||||
# ...because `u` could turn out to be one of these.
|
||||
class FooBar(TypedDict):
|
||||
@@ -2148,6 +2154,39 @@ static_assert(is_assignable_to(FooBar, Foo))
|
||||
static_assert(is_assignable_to(FooBar, Bar))
|
||||
```
|
||||
|
||||
`not in` works in the opposite way to `in`: we can narrow in the positive case, but we cannot narrow
|
||||
in the negative case. The following snippet also tests our narrowing behaviour for intersections
|
||||
that contain `TypedDict`s, and unions that contain intersections that contain `TypedDict`s:
|
||||
|
||||
```py
|
||||
from typing_extensions import Literal, Any
|
||||
from ty_extensions import Intersection, is_assignable_to, static_assert
|
||||
|
||||
def _(t: Bar, u: Foo | Intersection[Bar, Any], v: Intersection[Bar, Any], w: Literal["bar"]):
|
||||
reveal_type(u) # revealed: Foo | (Bar & Any)
|
||||
reveal_type(v) # revealed: Bar & Any
|
||||
|
||||
if "bar" not in t:
|
||||
reveal_type(t) # revealed: Never
|
||||
else:
|
||||
reveal_type(t) # revealed: Bar
|
||||
|
||||
if "bar" not in u:
|
||||
reveal_type(u) # revealed: Foo
|
||||
else:
|
||||
reveal_type(u) # revealed: Foo | (Bar & Any)
|
||||
|
||||
if "bar" not in v:
|
||||
reveal_type(v) # revealed: Never
|
||||
else:
|
||||
reveal_type(v) # revealed: Bar & Any
|
||||
|
||||
if w not in u:
|
||||
reveal_type(u) # revealed: Foo
|
||||
else:
|
||||
reveal_type(u) # revealed: Foo | (Bar & Any)
|
||||
```
|
||||
|
||||
TODO: The narrowing that we didn't do above will become possible when we add support for
|
||||
`closed=True`. This is [one of the main use cases][closed] that motivated the `closed` feature.
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ use crate::types::enums::{enum_member_literals, enum_metadata};
|
||||
use crate::types::function::KnownFunction;
|
||||
use crate::types::infer::{ExpressionInference, infer_same_file_expression_type};
|
||||
use crate::types::typed_dict::{
|
||||
SynthesizedTypedDictType, TypedDictFieldBuilder, TypedDictSchema, TypedDictType,
|
||||
SynthesizedTypedDictType, TypedDictField, TypedDictFieldBuilder, TypedDictSchema, TypedDictType,
|
||||
};
|
||||
use crate::types::{
|
||||
CallableType, ClassLiteral, ClassType, IntersectionBuilder, IntersectionType, KnownClass,
|
||||
@@ -1099,6 +1099,75 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
|
||||
}
|
||||
}
|
||||
|
||||
// Narrow unions and intersections of `TypedDict` in cases where required keys are
|
||||
// excluded:
|
||||
//
|
||||
// class Foo(TypedDict):
|
||||
// foo: int
|
||||
// class Bar(TypedDict):
|
||||
// bar: int
|
||||
//
|
||||
// def _(u: Foo | Bar):
|
||||
// if "foo" not in u:
|
||||
// reveal_type(u) # revealed: Bar
|
||||
if matches!(&**ops, [ast::CmpOp::In | ast::CmpOp::NotIn])
|
||||
&& let Type::StringLiteral(key) = inference.expression_type(&**left)
|
||||
&& let Some(rhs_place_expr) = place_expr(&comparators[0])
|
||||
&& let rhs_type = inference.expression_type(&comparators[0])
|
||||
&& is_typeddict_or_union_with_typeddicts(self.db, rhs_type)
|
||||
{
|
||||
let is_negative_check = is_positive == (ops[0] == ast::CmpOp::NotIn);
|
||||
if is_negative_check {
|
||||
let requires_key = |td: TypedDictType<'db>| -> bool {
|
||||
td.items(self.db)
|
||||
.get(key.value(self.db))
|
||||
.is_some_and(TypedDictField::is_required)
|
||||
};
|
||||
|
||||
let narrowed = match rhs_type {
|
||||
Type::TypedDict(td) => {
|
||||
if requires_key(td) {
|
||||
Type::Never
|
||||
} else {
|
||||
rhs_type
|
||||
}
|
||||
}
|
||||
Type::Intersection(intersection) => {
|
||||
if intersection
|
||||
.positive(self.db)
|
||||
.iter()
|
||||
.copied()
|
||||
.filter_map(Type::as_typed_dict)
|
||||
.any(requires_key)
|
||||
{
|
||||
Type::Never
|
||||
} else {
|
||||
rhs_type
|
||||
}
|
||||
}
|
||||
Type::Union(union) => {
|
||||
// remove all members of the union that would require the key
|
||||
union.filter(self.db, |ty| match ty {
|
||||
Type::TypedDict(td) => !requires_key(*td),
|
||||
Type::Intersection(intersection) => !intersection
|
||||
.positive(self.db)
|
||||
.iter()
|
||||
.copied()
|
||||
.filter_map(Type::as_typed_dict)
|
||||
.any(requires_key),
|
||||
_ => true,
|
||||
})
|
||||
}
|
||||
_ => rhs_type,
|
||||
};
|
||||
|
||||
if narrowed != rhs_type {
|
||||
let place = self.expect_place(&rhs_place_expr);
|
||||
constraints.insert(place, NarrowingConstraint::typeguard(narrowed));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut last_rhs_ty: Option<Type> = None;
|
||||
|
||||
for (op, (left, right)) in std::iter::zip(&**ops, comparator_tuples) {
|
||||
@@ -1677,18 +1746,13 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
|
||||
fn is_typeddict_or_union_with_typeddicts<'db>(db: &'db dyn Db, ty: Type<'db>) -> bool {
|
||||
match ty {
|
||||
Type::TypedDict(_) => true,
|
||||
Type::Union(union) => {
|
||||
union
|
||||
.elements(db)
|
||||
.iter()
|
||||
.any(|union_member_ty| match union_member_ty {
|
||||
Type::TypedDict(_) => true,
|
||||
Type::Intersection(intersection) => {
|
||||
intersection.positive(db).iter().any(Type::is_typed_dict)
|
||||
}
|
||||
_ => false,
|
||||
})
|
||||
Type::Intersection(intersection) => {
|
||||
intersection.positive(db).iter().any(Type::is_typed_dict)
|
||||
}
|
||||
Type::Union(union) => union
|
||||
.elements(db)
|
||||
.iter()
|
||||
.any(|union_member_ty| is_typeddict_or_union_with_typeddicts(db, *union_member_ty)),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user