[ty] implement typing.TypeGuard (#20974)

## Summary

Resolve(s) astral-sh/ty#117, astral-sh/ty#1569

Implement `typing.TypeGuard`. Due to the fact that it [overrides
anything previously known about the checked
value](https://typing.python.org/en/latest/spec/narrowing.html#typeguard)---

> When a conditional statement includes a call to a user-defined type
guard function, and that function returns true, the expression passed as
the first positional argument to the type guard function should be
assumed by a static type checker to take on the type specified in the
TypeGuard return type, unless and until it is further narrowed within
the conditional code block.

---we have to substantially rework the constraints system. In
particular, we make constraints represented as a disjunctive normal form
(DNF) where each term includes a regular constraint, and one or more
disjuncts with a typeguard constraint. Some test cases (including some
with more complex boolean logic) are added to `type_guards.md`.


## Test Plan

- update existing tests
- add new tests for more complex boolean logic with `TypeGuard`
- add new tests for `TypeGuard` variance

---------

Co-authored-by: Carl Meyer <carl@astral.sh>
This commit is contained in:
Eric Mark Martin
2025-12-29 20:54:17 -05:00
committed by GitHub
parent 9dadf2724c
commit 8716b4e230
19 changed files with 802 additions and 185 deletions

View File

@@ -356,6 +356,7 @@ impl<'db> Completion<'db> {
Type::IntLiteral(_)
| Type::BooleanLiteral(_)
| Type::TypeIs(_)
| Type::TypeGuard(_)
| Type::StringLiteral(_)
| Type::LiteralString
| Type::BytesLiteral(_) => CompletionKind::Value,

View File

@@ -16,7 +16,6 @@ def f(*args: Unpack[Ts]) -> tuple[Unpack[Ts]]:
reveal_type(args) # revealed: tuple[@Todo(`Unpack[]` special form), ...]
return args
def g() -> TypeGuard[int]: ...
def i(callback: Callable[Concatenate[int, P], R_co], *args: P.args, **kwargs: P.kwargs) -> R_co:
reveal_type(args) # revealed: P@i.args
reveal_type(kwargs) # revealed: P@i.kwargs

View File

@@ -790,6 +790,44 @@ static_assert(not is_assignable_to(C[B], C[A]))
static_assert(not is_assignable_to(C[A], C[B]))
```
## TypeGuard
`TypeGuard[T]` is covariant in `T`. The typing spec doesn't explicitly call this out, but it follows
from similar logic to invariance of `TypeIs` except without the negative case.
Formally, suppose we have types `A` and `B` with `B < A`. Take `x: object` to be the value that all
subsequent `TypeGuard`s are narrowing.
We can assign `p: TypeGuard[A] = q` where `q: TypeGuard[B]` because
- if `q` is `False`, then no constraints were learned on `x` before and none are now learned, so
nothing changes
- if `q` is `True`, then we know `x: B`. From `B < A`, we conclude `x: A`.
We _cannot_ assign `p: TypeGuard[B] = q` where `q: TypeGuard[A]` because if `q` is `True`, we would
be concluding `x: B` from `x: A`, which is an unsafe downcast.
```py
from typing import TypeGuard
from ty_extensions import is_assignable_to, is_subtype_of, static_assert
class A:
pass
class B(A):
pass
class C[T]:
def check(x: object) -> TypeGuard[T]:
# this is a bad check, but we only care about it type-checking
return False
static_assert(is_subtype_of(C[B], C[A]))
static_assert(not is_subtype_of(C[A], C[B]))
static_assert(is_assignable_to(C[B], C[A]))
static_assert(not is_assignable_to(C[A], C[B]))
```
## Type aliases
The variance of the type alias matches the variance of the value type (RHS type).

View File

@@ -12,21 +12,19 @@ from typing_extensions import TypeGuard, TypeIs
def _(
a: TypeGuard[str],
b: TypeIs[str | int],
c: TypeGuard[Intersection[complex, Not[int], Not[float]]],
c: TypeGuard[bool],
d: TypeIs[tuple[TypeOf[bytes]]],
e: TypeGuard, # error: [invalid-type-form]
f: TypeIs, # error: [invalid-type-form]
):
# TODO: Should be `TypeGuard[str]`
reveal_type(a) # revealed: @Todo(`TypeGuard[]` special form)
reveal_type(a) # revealed: TypeGuard[str]
reveal_type(b) # revealed: TypeIs[str | int]
# TODO: Should be `TypeGuard[complex & ~int & ~float]`
reveal_type(c) # revealed: @Todo(`TypeGuard[]` special form)
reveal_type(c) # revealed: TypeGuard[bool]
reveal_type(d) # revealed: TypeIs[tuple[<class 'bytes'>]]
reveal_type(e) # revealed: Unknown
reveal_type(f) # revealed: Unknown
# TODO: error: [invalid-return-type] "Function always implicitly returns `None`, which is not assignable to return type `TypeGuard[str]`"
# error: [invalid-return-type] "Function always implicitly returns `None`, which is not assignable to return type `TypeGuard[str]`"
def _(a) -> TypeGuard[str]: ...
# error: [invalid-return-type] "Function always implicitly returns `None`, which is not assignable to return type `TypeIs[str]`"
@@ -38,8 +36,7 @@ def g(a) -> TypeIs[str]:
return True
def _(a: object):
# TODO: Should be `TypeGuard[str @ a]`
reveal_type(f(a)) # revealed: @Todo(`TypeGuard[]` special form)
reveal_type(f(a)) # revealed: TypeGuard[str @ a]
reveal_type(g(a)) # revealed: TypeIs[str @ a]
```
@@ -96,6 +93,72 @@ def _(a: int) -> TypeIs[str]: ...
def _(a: bool | str) -> TypeIs[int]: ...
```
## Methods
Methods narrow the first positional argument after `self` or `cls`
```py
from typing import TypeGuard
class C:
def f(self, x: object) -> TypeGuard[str]:
return True
@classmethod
def g(cls, x: object) -> TypeGuard[int]:
return True
# TODO: this could error at definition time
def h(self) -> TypeGuard[str]:
return True
# TODO: this could error at definition time
@classmethod
def j(cls) -> TypeGuard[int]:
return True
def _(x: object):
if C().f(x):
reveal_type(x) # revealed: str
if C.f(C(), x):
# TODO: should be str
reveal_type(x) # revealed: object
if C.g(x):
reveal_type(x) # revealed: int
if C().g(x):
reveal_type(x) # revealed: int
if C().h(): # error: [invalid-type-guard-call] "Type guard call does not have a target"
pass
if C.j(): # error: [invalid-type-guard-call] "Type guard call does not have a target"
pass
```
```py
from typing_extensions import TypeIs
def is_int(val: object) -> TypeIs[int]:
return isinstance(val, int)
class A:
def is_int(self, val: object) -> TypeIs[int]:
return isinstance(val, int)
@classmethod
def is_int2(cls, val: object) -> TypeIs[int]:
return isinstance(val, int)
def _(x: object):
if is_int(x):
reveal_type(x) # revealed: int
if A().is_int(x):
reveal_type(x) # revealed: int
if A().is_int2(x):
reveal_type(x) # revealed: int
if A.is_int2(x):
reveal_type(x) # revealed: int
```
## Arguments to special forms
`TypeGuard` and `TypeIs` accept exactly one type argument.
@@ -105,15 +168,14 @@ from typing_extensions import TypeGuard, TypeIs
a = 123
# TODO: error: [invalid-type-form]
# error: [invalid-type-form] "Special form `typing.TypeGuard` expected exactly one type parameter"
def f(_) -> TypeGuard[int, str]: ...
# error: [invalid-type-form] "Special form `typing.TypeIs` expected exactly one type parameter"
# error: [invalid-type-form] "Variable of type `Literal[123]` is not allowed in a type expression"
def g(_) -> TypeIs[a, str]: ...
# TODO: Should be `Unknown`
reveal_type(f(0)) # revealed: @Todo(`TypeGuard[]` special form)
reveal_type(f(0)) # revealed: Unknown
reveal_type(g(0)) # revealed: Unknown
```
@@ -126,9 +188,10 @@ from typing_extensions import Literal, TypeGuard, TypeIs, assert_never
def _(a: object, flag: bool) -> TypeGuard[str]:
if flag:
# error: [invalid-return-type] "Return type does not match returned value: expected `TypeGuard[str]`, found `Literal[0]`"
return 0
# TODO: error: [invalid-return-type] "Return type does not match returned value: expected `TypeIs[str]`, found `Literal["foo"]`"
# error: [invalid-return-type] "Return type does not match returned value: expected `TypeGuard[str]`, found `Literal["foo"]`"
return "foo"
# error: [invalid-return-type] "Function can implicitly return `None`, which is not assignable to return type `TypeIs[str]`"
@@ -193,8 +256,7 @@ def is_bar(a: object) -> TypeIs[Bar]:
def _(a: Foo | Bar):
if guard_foo(a):
# TODO: Should be `Foo`
reveal_type(a) # revealed: Foo | Bar
reveal_type(a) # revealed: Foo
else:
reveal_type(a) # revealed: Foo | Bar
@@ -204,6 +266,26 @@ def _(a: Foo | Bar):
reveal_type(a) # revealed: Foo & ~Bar
```
```py
from typing import TypeGuard, reveal_type
class P:
pass
class A:
pass
class B:
pass
def is_b(val: object) -> TypeGuard[B]:
return isinstance(val, B)
def _(x: P):
if isinstance(x, A) or is_b(x):
reveal_type(x) # revealed: B | (P & A)
```
Attribute and subscript narrowing is supported:
```py
@@ -215,23 +297,17 @@ class C(Generic[T]):
v: T
def _(a: tuple[Foo, Bar] | tuple[Bar, Foo], c: C[Any]):
# TODO: Should be `TypeGuard[Foo @ a[1]]`
if reveal_type(guard_foo(a[1])): # revealed: @Todo(`TypeGuard[]` special form)
# TODO: Should be `tuple[Bar, Foo]`
if reveal_type(guard_foo(a[1])): # revealed: TypeGuard[Foo @ a[1]]
reveal_type(a) # revealed: tuple[Foo, Bar] | tuple[Bar, Foo]
# TODO: Should be `Foo`
reveal_type(a[1]) # revealed: Bar | Foo
reveal_type(a[1]) # revealed: Foo
if reveal_type(is_bar(a[0])): # revealed: TypeIs[Bar @ a[0]]
# TODO: Should be `tuple[Bar, Bar & Foo]`
reveal_type(a) # revealed: tuple[Foo, Bar] | tuple[Bar, Foo]
reveal_type(a[0]) # revealed: Bar
# TODO: Should be `TypeGuard[Foo @ c.v]`
if reveal_type(guard_foo(c.v)): # revealed: @Todo(`TypeGuard[]` special form)
if reveal_type(guard_foo(c.v)): # revealed: TypeGuard[Foo @ c.v]
reveal_type(c) # revealed: C[Any]
# TODO: Should be `Foo`
reveal_type(c.v) # revealed: Any
reveal_type(c.v) # revealed: Foo
if reveal_type(is_bar(c.v)): # revealed: TypeIs[Bar @ c.v]
reveal_type(c) # revealed: C[Any]
@@ -246,8 +322,7 @@ def _(a: Foo | Bar):
c = is_bar(a)
reveal_type(a) # revealed: Foo | Bar
# TODO: Should be `TypeGuard[Foo @ a]`
reveal_type(b) # revealed: @Todo(`TypeGuard[]` special form)
reveal_type(b) # revealed: TypeGuard[Foo @ a]
reveal_type(c) # revealed: TypeIs[Bar @ a]
if b:
@@ -345,25 +420,82 @@ class Baz(Bar): ...
def guard_foo(a: object) -> TypeGuard[Foo]:
return True
def guard_bar(a: object) -> TypeGuard[Bar]:
return True
def is_bar(a: object) -> TypeIs[Bar]:
return True
def does_not_narrow_in_negative_case(a: Foo | Bar):
if not guard_foo(a):
# TODO: Should be `Bar`
reveal_type(a) # revealed: Foo | Bar
else:
reveal_type(a) # revealed: Foo | Bar
reveal_type(a) # revealed: Foo
def narrowed_type_must_be_exact(a: object, b: Baz):
if guard_foo(b):
# TODO: Should be `Foo`
reveal_type(b) # revealed: Baz
reveal_type(b) # revealed: Foo
if isinstance(a, Baz) and is_bar(a):
reveal_type(a) # revealed: Baz
if isinstance(a, Bar) and guard_foo(a):
# TODO: Should be `Foo`
reveal_type(a) # revealed: Bar
reveal_type(a) # revealed: Foo
if guard_bar(a):
reveal_type(a) # revealed: Bar
```
## TypeGuard overrides normal constraints
TypeGuard constraints override any previous narrowing, but additional "regular" constraints can be
added on to TypeGuard constraints.
```py
from typing_extensions import TypeGuard, TypeIs
class A: ...
class B: ...
class C: ...
def f(x: object) -> TypeGuard[A]:
return True
def g(x: object) -> TypeGuard[B]:
return True
def h(x: object) -> TypeIs[C]:
return True
def _(x: object):
if f(x) and g(x) and h(x):
reveal_type(x) # revealed: B & C
```
## Boolean logic with TypeGuard and TypeIs
TypeGuard constraints need to properly distribute through boolean operations.
```py
from typing_extensions import TypeGuard, TypeIs
class A: ...
class B: ...
class C: ...
def f(x: object) -> TypeIs[A]:
return True
def g(x: object) -> TypeGuard[B]:
return True
def h(x: object) -> TypeIs[C]:
return True
def _(x: object):
# g(x) or h(x) should give B | C
# Then f(x) and (...) should distribute: (f(x) and g(x)) or (f(x) and h(x))
# Which is (Regular(A) & TypeGuard(B)) | (Regular(A) & Regular(C))
# TypeGuard clobbers in the first branch, giving: B | (A & C)
if f(x) and (g(x) or h(x)):
reveal_type(x) # revealed: B | (A & C)
```

View File

@@ -1383,8 +1383,7 @@ from typing_extensions import Any, TypeGuard, TypeIs
static_assert(is_assignable_to(TypeGuard[Unknown], bool))
static_assert(is_assignable_to(TypeIs[Any], bool))
# TODO no error
static_assert(not is_assignable_to(TypeGuard[Unknown], str)) # error: [static-assert-error]
static_assert(not is_assignable_to(TypeGuard[Unknown], str))
static_assert(not is_assignable_to(TypeIs[Any], str))
```

View File

@@ -578,8 +578,7 @@ from typing_extensions import TypeGuard, TypeIs
static_assert(not is_disjoint_from(bool, TypeGuard[str]))
static_assert(not is_disjoint_from(bool, TypeIs[str]))
# TODO no error
static_assert(is_disjoint_from(str, TypeGuard[str])) # error: [static-assert-error]
static_assert(is_disjoint_from(str, TypeGuard[str]))
static_assert(is_disjoint_from(str, TypeIs[str]))
```

View File

@@ -670,9 +670,8 @@ Fully-static `TypeGuard[...]` and `TypeIs[...]` are subtypes of `bool`.
from ty_extensions import is_subtype_of, static_assert
from typing_extensions import TypeGuard, TypeIs
# TODO: TypeGuard
# static_assert(is_subtype_of(TypeGuard[int], bool))
# static_assert(is_subtype_of(TypeGuard[int], int))
static_assert(is_subtype_of(TypeGuard[str], bool))
static_assert(is_subtype_of(TypeGuard[str], int))
static_assert(is_subtype_of(TypeIs[str], bool))
static_assert(is_subtype_of(TypeIs[str], int))
```
@@ -683,12 +682,12 @@ static_assert(is_subtype_of(TypeIs[str], int))
from ty_extensions import is_equivalent_to, is_subtype_of, static_assert
from typing_extensions import TypeGuard, TypeIs
# TODO: TypeGuard
# static_assert(is_subtype_of(TypeGuard[int], TypeGuard[int]))
# static_assert(is_subtype_of(TypeGuard[bool], TypeGuard[int]))
static_assert(is_subtype_of(TypeGuard[int], TypeGuard[int]))
static_assert(is_subtype_of(TypeGuard[bool], TypeGuard[int]))
static_assert(is_subtype_of(TypeIs[int], TypeIs[int]))
static_assert(is_subtype_of(TypeIs[int], TypeIs[int]))
static_assert(is_subtype_of(TypeGuard[bool], TypeGuard[int]))
static_assert(not is_subtype_of(TypeGuard[int], TypeGuard[bool]))
static_assert(not is_subtype_of(TypeIs[bool], TypeIs[int]))
static_assert(not is_subtype_of(TypeIs[int], TypeIs[bool]))

View File

@@ -266,7 +266,7 @@ use crate::semantic_index::use_def::place_state::{
LiveDeclarationsIterator, PlaceState, PreviousDefinitions, ScopedDefinitionId,
};
use crate::semantic_index::{EnclosingSnapshotResult, SemanticIndex};
use crate::types::{IntersectionBuilder, Truthiness, Type, infer_narrowing_constraint};
use crate::types::{NarrowingConstraint, Truthiness, Type, infer_narrowing_constraint};
mod place_state;
@@ -757,22 +757,22 @@ impl<'db> ConstraintsIterator<'_, 'db> {
base_ty: Type<'db>,
place: ScopedPlaceId,
) -> Type<'db> {
let constraint_tys: Vec<_> = self
.filter_map(|constraint| infer_narrowing_constraint(db, constraint, place))
.collect();
if constraint_tys.is_empty() {
base_ty
} else {
constraint_tys
.into_iter()
.rev()
.fold(
IntersectionBuilder::new(db).add_positive(base_ty),
IntersectionBuilder::add_positive,
)
.build()
}
// Constraints are in reverse-source order. Due to TypeGuard semantics
// constraint AND is non-commutative and so we _must_ apply in
// source order.
//
// Fortunately, constraint AND is still associative, so we can still iterate left-to-right
// and accumulate rightward.
self.filter_map(|constraint| infer_narrowing_constraint(db, constraint, place))
.reduce(|acc, constraint| {
// See above---note the reverse application
constraint.merge_constraint_and(acc, db)
})
.map_or(base_ty, |constraint| {
NarrowingConstraint::regular(base_ty)
.merge_constraint_and(constraint, db)
.evaluate_constraint_type(db)
})
}
}

View File

@@ -65,7 +65,7 @@ use crate::types::generics::{
walk_generic_context,
};
use crate::types::mro::{Mro, MroError, MroIterator};
pub(crate) use crate::types::narrow::infer_narrowing_constraint;
pub(crate) use crate::types::narrow::{NarrowingConstraint, infer_narrowing_constraint};
use crate::types::newtype::NewType;
pub(crate) use crate::types::signatures::{Parameter, Parameters};
use crate::types::signatures::{ParameterForm, walk_signature};
@@ -865,6 +865,8 @@ pub enum Type<'db> {
BoundSuper(BoundSuperType<'db>),
/// A subtype of `bool` that allows narrowing in both positive and negative cases.
TypeIs(TypeIsType<'db>),
/// A subtype of `bool` that allows narrowing in only the positive case.
TypeGuard(TypeGuardType<'db>),
/// A type that represents an inhabitant of a `TypedDict`.
TypedDict(TypedDictType<'db>),
/// An aliased type (lazily not-yet-unpacked to its value type).
@@ -878,6 +880,26 @@ pub enum Type<'db> {
NewTypeInstance(NewType<'db>),
}
/// Helper for `recursive_type_normalized_impl` for `TypeGuardLike` types.
fn recursive_type_normalize_type_guard_like<'db, T: TypeGuardLike<'db>>(
db: &'db dyn Db,
guard: T,
div: Type<'db>,
nested: bool,
) -> Option<Type<'db>> {
let ty = if nested {
guard
.return_type(db)
.recursive_type_normalized_impl(db, div, true)?
} else {
guard
.return_type(db)
.recursive_type_normalized_impl(db, div, true)
.unwrap_or(div)
};
Some(guard.with_type(db, ty))
}
#[salsa::tracked]
impl<'db> Type<'db> {
pub(crate) const fn any() -> Self {
@@ -1618,6 +1640,9 @@ impl<'db> Type<'db> {
Type::TypeIs(type_is) => visitor.visit(self, || {
type_is.with_type(db, type_is.return_type(db).normalized_impl(db, visitor))
}),
Type::TypeGuard(type_guard) => visitor.visit(self, || {
type_guard.with_type(db, type_guard.return_type(db).normalized_impl(db, visitor))
}),
Type::Dynamic(dynamic) => Type::Dynamic(dynamic.normalized()),
Type::EnumLiteral(enum_literal)
if is_single_member_enum(db, enum_literal.enum_class(db)) =>
@@ -1741,17 +1766,10 @@ impl<'db> Type<'db> {
.recursive_type_normalized_impl(db, div, nested)
.map(Type::KnownInstance),
Type::TypeIs(type_is) => {
let ty = if nested {
type_is
.return_type(db)
.recursive_type_normalized_impl(db, div, true)?
} else {
type_is
.return_type(db)
.recursive_type_normalized_impl(db, div, true)
.unwrap_or(div)
};
Some(type_is.with_type(db, ty))
recursive_type_normalize_type_guard_like(db, type_is, div, nested)
}
Type::TypeGuard(type_guard) => {
recursive_type_normalize_type_guard_like(db, type_guard, div, nested)
}
Type::Dynamic(dynamic) => Some(Type::Dynamic(dynamic.recursive_type_normalized())),
Type::TypedDict(_) => {
@@ -1825,6 +1843,7 @@ impl<'db> Type<'db> {
| Type::TypeVar(_)
| Type::BoundSuper(_)
| Type::TypeIs(_)
| Type::TypeGuard(_)
| Type::TypedDict(_)
| Type::TypeAlias(_)
| Type::NewTypeInstance(_) => false,
@@ -1937,6 +1956,7 @@ impl<'db> Type<'db> {
| Type::LiteralString
| Type::BytesLiteral(_)
| Type::TypeIs(_)
| Type::TypeGuard(_)
| Type::TypedDict(_) => None,
// TODO
@@ -2807,15 +2827,29 @@ impl<'db> Type<'db> {
)
}),
// `TypeIs[T]` is a subtype of `bool`.
(Type::TypeIs(_), _) => KnownClass::Bool.to_instance(db).has_relation_to_impl(
db,
target,
inferable,
relation,
relation_visitor,
disjointness_visitor,
),
// `TypeGuard` is covariant.
(Type::TypeGuard(left), Type::TypeGuard(right)) => {
left.return_type(db).has_relation_to_impl(
db,
right.return_type(db),
inferable,
relation,
relation_visitor,
disjointness_visitor,
)
}
// `TypeIs[T]` and `TypeGuard[T]` are subtypes of `bool`.
(Type::TypeIs(_) | Type::TypeGuard(_), _) => {
KnownClass::Bool.to_instance(db).has_relation_to_impl(
db,
target,
inferable,
relation,
relation_visitor,
disjointness_visitor,
)
}
// Function-like callables are subtypes of `FunctionType`
(Type::Callable(callable), _) if callable.is_function_like(db) => {
@@ -3785,8 +3819,14 @@ impl<'db> Type<'db> {
ConstraintSet::from(!known_instance.is_instance_of(db, instance.class(db)))
}
(Type::BooleanLiteral(..) | Type::TypeIs(_), Type::NominalInstance(instance))
| (Type::NominalInstance(instance), Type::BooleanLiteral(..) | Type::TypeIs(_)) => {
(
Type::BooleanLiteral(..) | Type::TypeIs(_) | Type::TypeGuard(_),
Type::NominalInstance(instance),
)
| (
Type::NominalInstance(instance),
Type::BooleanLiteral(..) | Type::TypeIs(_) | Type::TypeGuard(_),
) => {
// A `Type::BooleanLiteral()` must be an instance of exactly `bool`
// (it cannot be an instance of a `bool` subclass)
KnownClass::Bool
@@ -3794,8 +3834,10 @@ impl<'db> Type<'db> {
.negate(db)
}
(Type::BooleanLiteral(..) | Type::TypeIs(_), _)
| (_, Type::BooleanLiteral(..) | Type::TypeIs(_)) => ConstraintSet::from(true),
(Type::BooleanLiteral(..) | Type::TypeIs(_) | Type::TypeGuard(_), _)
| (_, Type::BooleanLiteral(..) | Type::TypeIs(_) | Type::TypeGuard(_)) => {
ConstraintSet::from(true)
}
(Type::IntLiteral(..), Type::NominalInstance(instance))
| (Type::NominalInstance(instance), Type::IntLiteral(..)) => {
@@ -4261,6 +4303,7 @@ impl<'db> Type<'db> {
}
Type::AlwaysTruthy | Type::AlwaysFalsy => false,
Type::TypeIs(type_is) => type_is.is_bound(db),
Type::TypeGuard(type_guard) => type_guard.is_bound(db),
Type::TypedDict(_) => false,
Type::TypeAlias(alias) => alias.value_type(db).is_singleton(db),
Type::NewTypeInstance(newtype) => newtype.concrete_base_type(db).is_singleton(db),
@@ -4322,6 +4365,7 @@ impl<'db> Type<'db> {
}
Type::TypeIs(type_is) => type_is.is_bound(db),
Type::TypeGuard(type_guard) => type_guard.is_bound(db),
Type::TypeAlias(alias) => alias.value_type(db).is_single_valued(db),
@@ -4478,6 +4522,7 @@ impl<'db> Type<'db> {
| Type::ProtocolInstance(_)
| Type::PropertyInstance(_)
| Type::TypeIs(_)
| Type::TypeGuard(_)
| Type::TypedDict(_)
| Type::NewTypeInstance(_) => None,
}
@@ -4606,7 +4651,7 @@ impl<'db> Type<'db> {
}
Type::IntLiteral(_) => KnownClass::Int.to_instance(db).instance_member(db, name),
Type::BooleanLiteral(_) | Type::TypeIs(_) => {
Type::BooleanLiteral(_) | Type::TypeIs(_) | Type::TypeGuard(_) => {
KnownClass::Bool.to_instance(db).instance_member(db, name)
}
Type::StringLiteral(_) | Type::LiteralString => {
@@ -5276,6 +5321,7 @@ impl<'db> Type<'db> {
| Type::AlwaysTruthy
| Type::AlwaysFalsy
| Type::TypeIs(..)
| Type::TypeGuard(..)
| Type::TypedDict(_) => {
let fallback = self.instance_member(db, name_str);
@@ -5603,7 +5649,8 @@ impl<'db> Type<'db> {
| Type::Never
| Type::Callable(_)
| Type::LiteralString
| Type::TypeIs(_) => Truthiness::Ambiguous,
| Type::TypeIs(_)
| Type::TypeGuard(_) => Truthiness::Ambiguous,
Type::TypedDict(_) => {
// TODO: We could do better here, but it's unclear if this is important.
@@ -6547,6 +6594,7 @@ impl<'db> Type<'db> {
| Type::BoundSuper(_)
| Type::ModuleLiteral(_)
| Type::TypeIs(_)
| Type::TypeGuard(_)
| Type::TypedDict(_) => CallableBinding::not_callable(self).into(),
}
}
@@ -6776,6 +6824,7 @@ impl<'db> Type<'db> {
| Type::EnumLiteral(_)
| Type::BoundSuper(_)
| Type::TypeIs(_)
| Type::TypeGuard(_)
| Type::TypedDict(_) => None
}
}
@@ -7390,6 +7439,7 @@ impl<'db> Type<'db> {
| Type::AlwaysTruthy
| Type::AlwaysFalsy
| Type::TypeIs(_)
| Type::TypeGuard(_)
| Type::TypedDict(_)
| Type::NewTypeInstance(_) => None,
}
@@ -7446,6 +7496,7 @@ impl<'db> Type<'db> {
| Type::ProtocolInstance(_)
| Type::PropertyInstance(_)
| Type::TypeIs(_)
| Type::TypeGuard(_)
| Type::TypedDict(_) => Err(InvalidTypeExpressionError {
invalid_expressions: smallvec::smallvec_inline![
InvalidTypeExpression::InvalidType(*self, scope_id)
@@ -7735,7 +7786,9 @@ impl<'db> Type<'db> {
Type::SpecialForm(special_form) => special_form.to_meta_type(db),
Type::PropertyInstance(_) => KnownClass::Property.to_class_literal(db),
Type::Union(union) => union.map(db, |ty| ty.to_meta_type(db)),
Type::BooleanLiteral(_) | Type::TypeIs(_) => KnownClass::Bool.to_class_literal(db),
Type::BooleanLiteral(_) | Type::TypeIs(_) | Type::TypeGuard(_) => {
KnownClass::Bool.to_class_literal(db)
}
Type::BytesLiteral(_) => KnownClass::Bytes.to_class_literal(db),
Type::IntLiteral(_) => KnownClass::Int.to_class_literal(db),
Type::EnumLiteral(enum_literal) => Type::ClassLiteral(enum_literal.enum_class(db)),
@@ -8036,6 +8089,8 @@ impl<'db> Type<'db> {
// TODO(jelle): Materialize should be handled differently, since TypeIs is invariant
Type::TypeIs(type_is) => type_is.with_type(db, type_is.return_type(db).apply_type_mapping(db, type_mapping, tcx)),
Type::TypeGuard(type_guard) => type_guard.with_type(db, type_guard.return_type(db).apply_type_mapping(db, type_mapping, tcx)),
Type::TypeAlias(alias) => {
if TypeMapping::EagerExpansion == *type_mapping {
return alias.raw_value_type(db).expand_eagerly(db);
@@ -8242,6 +8297,15 @@ impl<'db> Type<'db> {
);
}
Type::TypeGuard(type_guard) => {
type_guard.return_type(db).find_legacy_typevars_impl(
db,
binding_context,
typevars,
visitor,
);
}
Type::TypeAlias(alias) => {
visitor.visit(self, || {
alias.value_type(db).find_legacy_typevars_impl(
@@ -8504,7 +8568,8 @@ impl<'db> Type<'db> {
// These types have no definition
Self::Dynamic(DynamicType::Divergent(_) | DynamicType::Todo(_) | DynamicType::TodoUnpack | DynamicType::TodoStarredExpression)
| Self::Callable(_)
| Self::TypeIs(_) => None,
| Self::TypeIs(_)
| Self::TypeGuard(_) => None,
}
}
@@ -8668,6 +8733,7 @@ impl<'db> VarianceInferable<'db> for Type<'db> {
.collect(),
Type::SubclassOf(subclass_of_type) => subclass_of_type.variance_of(db, typevar),
Type::TypeIs(type_is_type) => type_is_type.variance_of(db, typevar),
Type::TypeGuard(type_guard_type) => type_guard_type.variance_of(db, typevar),
Type::KnownInstance(known_instance) => known_instance.variance_of(db, typevar),
Type::Dynamic(_)
| Type::Never
@@ -14667,6 +14733,144 @@ impl<'db> VarianceInferable<'db> for TypeIsType<'db> {
}
}
#[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)]
pub struct TypeGuardType<'db> {
return_type: Type<'db>,
/// The ID of the scope to which the place belongs
/// and the ID of the place itself within that scope.
place_info: Option<(ScopeId<'db>, ScopedPlaceId)>,
}
fn walk_typeguard_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>(
db: &'db dyn Db,
typeguard_type: TypeGuardType<'db>,
visitor: &V,
) {
visitor.visit_type(db, typeguard_type.return_type(db));
}
// The Salsa heap is tracked separately.
impl get_size2::GetSize for TypeGuardType<'_> {}
impl<'db> TypeGuardType<'db> {
pub(crate) fn place_name(self, db: &'db dyn Db) -> Option<String> {
let (scope, place) = self.place_info(db)?;
let table = place_table(db, scope);
Some(format!("{}", table.place(place)))
}
pub(crate) fn unbound(db: &'db dyn Db, ty: Type<'db>) -> Type<'db> {
Type::TypeGuard(Self::new(db, ty, None))
}
pub(crate) fn bound(
db: &'db dyn Db,
return_type: Type<'db>,
scope: ScopeId<'db>,
place: ScopedPlaceId,
) -> Type<'db> {
Type::TypeGuard(Self::new(db, return_type, Some((scope, place))))
}
#[must_use]
pub(crate) fn bind(
self,
db: &'db dyn Db,
scope: ScopeId<'db>,
place: ScopedPlaceId,
) -> Type<'db> {
Self::bound(db, self.return_type(db), scope, place)
}
#[must_use]
pub(crate) fn with_type(self, db: &'db dyn Db, ty: Type<'db>) -> Type<'db> {
Type::TypeGuard(Self::new(db, ty, self.place_info(db)))
}
pub(crate) fn is_bound(self, db: &'db dyn Db) -> bool {
self.place_info(db).is_some()
}
}
impl<'db> VarianceInferable<'db> for TypeGuardType<'db> {
// `TypeGuard` is covariant in its type parameter. See the `TypeGuard`
// section of mdtest/generics/pep695/variance.md for details.
fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance {
self.return_type(db).variance_of(db, typevar)
}
}
/// Common trait for `TypeIs` and `TypeGuard` types that share similar structure
/// but have different semantic behaviors.
pub(crate) trait TypeGuardLike<'db>: Copy {
/// The name of this type guard form (for error messages and display)
const FORM_NAME: &'static str;
/// Get the return type that the type guard narrows to
fn return_type(self, db: &'db dyn Db) -> Type<'db>;
/// Get the place info (scope and place ID) if bound
fn place_info(self, db: &'db dyn Db) -> Option<(ScopeId<'db>, ScopedPlaceId)>;
/// Get the human-readable place name if bound
fn place_name(self, db: &'db dyn Db) -> Option<String>;
/// Create a new instance with a different return type, wrapped in Type
fn with_type(self, db: &'db dyn Db, ty: Type<'db>) -> Type<'db>;
/// The `SpecialFormType` for display purposes
fn special_form() -> SpecialFormType;
}
impl<'db> TypeGuardLike<'db> for TypeIsType<'db> {
const FORM_NAME: &'static str = "TypeIs";
fn return_type(self, db: &'db dyn Db) -> Type<'db> {
TypeIsType::return_type(self, db)
}
fn place_info(self, db: &'db dyn Db) -> Option<(ScopeId<'db>, ScopedPlaceId)> {
TypeIsType::place_info(self, db)
}
fn place_name(self, db: &'db dyn Db) -> Option<String> {
TypeIsType::place_name(self, db)
}
fn with_type(self, db: &'db dyn Db, ty: Type<'db>) -> Type<'db> {
TypeIsType::with_type(self, db, ty)
}
fn special_form() -> SpecialFormType {
SpecialFormType::TypeIs
}
}
impl<'db> TypeGuardLike<'db> for TypeGuardType<'db> {
const FORM_NAME: &'static str = "TypeGuard";
fn return_type(self, db: &'db dyn Db) -> Type<'db> {
TypeGuardType::return_type(self, db)
}
fn place_info(self, db: &'db dyn Db) -> Option<(ScopeId<'db>, ScopedPlaceId)> {
TypeGuardType::place_info(self, db)
}
fn place_name(self, db: &'db dyn Db) -> Option<String> {
TypeGuardType::place_name(self, db)
}
fn with_type(self, db: &'db dyn Db, ty: Type<'db>) -> Type<'db> {
TypeGuardType::with_type(self, db, ty)
}
fn special_form() -> SpecialFormType {
SpecialFormType::TypeGuard
}
}
/// Walk the MRO of this class and return the last class just before the specified known base.
/// This can be used to determine upper bounds for `Self` type variables on methods that are
/// being added to the given class.

View File

@@ -389,7 +389,7 @@ impl<'db> BoundSuperType<'db> {
None => delegate_with_error_mapped(Type::object(), Some(type_var)),
};
}
Type::BooleanLiteral(_) | Type::TypeIs(_) => {
Type::BooleanLiteral(_) | Type::TypeIs(_) | Type::TypeGuard(_) => {
return delegate_to(KnownClass::Bool.to_instance(db));
}
Type::IntLiteral(_) => return delegate_to(KnownClass::Int.to_instance(db)),

View File

@@ -177,6 +177,7 @@ impl<'db> ClassBase<'db> {
| Type::AlwaysFalsy
| Type::AlwaysTruthy
| Type::TypeIs(_)
| Type::TypeGuard(_)
| Type::TypedDict(_) => None,
Type::KnownInstance(known_instance) => match known_instance {

View File

@@ -27,8 +27,8 @@ use crate::types::visitor::TypeVisitor;
use crate::types::{
BoundTypeVarIdentity, CallableType, CallableTypeKind, IntersectionType, KnownBoundMethodType,
KnownClass, KnownInstanceType, MaterializationKind, Protocol, ProtocolInstanceType,
SpecialFormType, StringLiteralType, SubclassOfInner, Type, TypedDictType, UnionType,
WrapperDescriptorKind, visitor,
SpecialFormType, StringLiteralType, SubclassOfInner, Type, TypeGuardLike, TypedDictType,
UnionType, WrapperDescriptorKind, visitor,
};
/// Settings for displaying types and signatures
@@ -590,6 +590,28 @@ impl Display for ClassDisplay<'_> {
}
}
/// Helper for displaying `TypeGuardLike` types `TypeIs` and `TypeGuard`.
fn fmt_type_guard_like<'db, T: TypeGuardLike<'db>>(
db: &'db dyn Db,
guard: T,
settings: &DisplaySettings<'db>,
f: &mut TypeWriter<'_, '_, 'db>,
) -> fmt::Result {
f.with_type(Type::SpecialForm(T::special_form()))
.write_str(T::FORM_NAME)?;
f.write_char('[')?;
guard
.return_type(db)
.display_with(db, settings.singleline())
.fmt_detailed(f)?;
if let Some(name) = guard.place_name(db) {
f.set_invalid_type_annotation();
f.write_str(" @ ")?;
f.write_str(&name)?;
}
f.write_str("]")
}
/// Writes the string representation of a type, which is the value displayed either as
/// `Literal[<repr>]` or `Literal[<repr1>, <repr2>]` for literal types or as `<repr>` for
/// non literals
@@ -971,20 +993,9 @@ impl<'db> FmtDetailed<'db> for DisplayRepresentation<'db> {
.fmt_detailed(f)?;
f.write_str(">")
}
Type::TypeIs(type_is) => {
f.with_type(Type::SpecialForm(SpecialFormType::TypeIs))
.write_str("TypeIs")?;
f.write_char('[')?;
type_is
.return_type(self.db)
.display_with(self.db, self.settings.singleline())
.fmt_detailed(f)?;
if let Some(name) = type_is.place_name(self.db) {
f.set_invalid_type_annotation();
f.write_str(" @ ")?;
f.write_str(&name)?;
}
f.write_str("]")
Type::TypeIs(type_is) => fmt_type_guard_like(self.db, type_is, &self.settings, f),
Type::TypeGuard(type_guard) => {
fmt_type_guard_like(self.db, type_guard, &self.settings, f)
}
Type::TypedDict(TypedDictType::Class(defining_class)) => match defining_class {
ClassType::NonGeneric(class) => class

View File

@@ -1303,6 +1303,7 @@ fn is_instance_truthiness<'db>(
| Type::AlwaysFalsy
| Type::BoundSuper(..)
| Type::TypeIs(..)
| Type::TypeGuard(..)
| Type::Callable(..)
| Type::Dynamic(..)
| Type::Never

View File

@@ -2059,7 +2059,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let declared_ty = self.file_expression_type(returns);
let expected_ty = match declared_ty {
Type::TypeIs(_) => KnownClass::Bool.to_instance(self.db()),
Type::TypeIs(_) | Type::TypeGuard(_) => KnownClass::Bool.to_instance(self.db()),
ty => ty,
};
@@ -4665,6 +4665,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
| Type::AlwaysTruthy
| Type::AlwaysFalsy
| Type::TypeIs(_)
| Type::TypeGuard(_)
| Type::TypedDict(_)
| Type::NewTypeInstance(_) => {
// TODO: We could use the annotated parameter type of `__setattr__` as type context here.
@@ -8939,11 +8940,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
};
match return_ty {
// TODO: TypeGuard
Type::TypeIs(type_is) => match find_narrowed_place() {
Some(place) => type_is.bind(db, scope, place),
None => return_ty,
},
Type::TypeGuard(type_guard) => match find_narrowed_place() {
Some(place) => type_guard.bind(db, scope, place),
None => return_ty,
},
_ => return_ty,
}
}
@@ -10039,6 +10043,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
| Type::BoundSuper(_)
| Type::TypeVar(_)
| Type::TypeIs(_)
| Type::TypeGuard(_)
| Type::TypedDict(_)
| Type::NewTypeInstance(_),
) => {
@@ -10539,6 +10544,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
| Type::BoundSuper(_)
| Type::TypeVar(_)
| Type::TypeIs(_)
| Type::TypeGuard(_)
| Type::TypedDict(_)
| Type::NewTypeInstance(_),
Type::FunctionLiteral(_)
@@ -10569,6 +10575,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
| Type::BoundSuper(_)
| Type::TypeVar(_)
| Type::TypeIs(_)
| Type::TypeGuard(_)
| Type::TypedDict(_)
| Type::NewTypeInstance(_),
op,

View File

@@ -16,8 +16,8 @@ use crate::types::tuple::{TupleSpecBuilder, TupleType};
use crate::types::{
BindingContext, CallableType, DynamicType, GenericContext, IntersectionBuilder, KnownClass,
KnownInstanceType, LintDiagnosticGuard, Parameter, Parameters, SpecialFormType, SubclassOfType,
Type, TypeAliasType, TypeContext, TypeIsType, TypeMapping, TypeVarKind, UnionBuilder,
UnionType, any_over_type, todo_type,
Type, TypeAliasType, TypeContext, TypeGuardType, TypeIsType, TypeMapping, TypeVarKind,
UnionBuilder, UnionType, any_over_type, todo_type,
};
/// Type expressions
@@ -1521,10 +1521,26 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
.top_materialization(self.db()),
),
},
SpecialFormType::TypeGuard => {
self.infer_type_expression(arguments_slice);
todo_type!("`TypeGuard[]` special form")
}
SpecialFormType::TypeGuard => match arguments_slice {
ast::Expr::Tuple(_) => {
self.infer_type_expression(arguments_slice);
if let Some(builder) = self.context.report_lint(&INVALID_TYPE_FORM, subscript) {
let diag = builder.into_diagnostic(format_args!(
"Special form `typing.TypeGuard` expected exactly one type parameter",
));
diagnostic::add_type_expression_reference_link(diag);
}
Type::unknown()
}
_ => TypeGuardType::unbound(
self.db(),
// Unlike `TypeIs`, don't use top materialization, because
// `TypeGuard` clobbering behavior makes it counterintuitive
self.infer_type_expression(arguments_slice),
),
},
SpecialFormType::Concatenate => {
let arguments = if let ast::Expr::Tuple(tuple) = arguments_slice {
&*tuple.elts

View File

@@ -282,7 +282,8 @@ impl<'db> AllMembers<'db> {
| Type::SpecialForm(_)
| Type::KnownInstance(_)
| Type::BoundSuper(_)
| Type::TypeIs(_) => match ty.to_meta_type(db) {
| Type::TypeIs(_)
| Type::TypeGuard(_) => match ty.to_meta_type(db) {
Type::ClassLiteral(class_literal) => {
self.extend_with_class_members(db, ty, class_literal);
}

View File

@@ -10,7 +10,7 @@ use crate::semantic_index::scope::ScopeId;
use crate::subscript::PyIndex;
use crate::types::enums::{enum_member_literals, enum_metadata};
use crate::types::function::KnownFunction;
use crate::types::infer::infer_same_file_expression_type;
use crate::types::infer::{ExpressionInference, infer_same_file_expression_type};
use crate::types::typed_dict::{
SynthesizedTypedDictType, TypedDictFieldBuilder, TypedDictSchema, TypedDictType,
};
@@ -29,6 +29,7 @@ use itertools::Itertools;
use ruff_python_ast as ast;
use ruff_python_ast::{BoolOp, ExprBoolOp};
use rustc_hash::FxHashMap;
use smallvec::{SmallVec, smallvec};
use std::collections::hash_map::Entry;
/// Return the type constraint that `test` (if true) would place on `symbol`, if any.
@@ -51,7 +52,7 @@ pub(crate) fn infer_narrowing_constraint<'db>(
db: &'db dyn Db,
predicate: Predicate<'db>,
place: ScopedPlaceId,
) -> Option<Type<'db>> {
) -> Option<NarrowingConstraint<'db>> {
let constraints = match predicate.node {
PredicateNode::Expression(expression) => {
if predicate.is_positive {
@@ -70,11 +71,8 @@ pub(crate) fn infer_narrowing_constraint<'db>(
PredicateNode::ReturnsNever(_) => return None,
PredicateNode::StarImportPlaceholder(_) => return None,
};
if let Some(constraints) = constraints {
constraints.get(&place).copied()
} else {
None
}
constraints.and_then(|constraints| constraints.get(&place).cloned())
}
#[salsa::tracked(returns(as_ref), heap_size=ruff_memory_usage::heap_size)]
@@ -269,6 +267,7 @@ impl ClassInfoConstraintFunction {
| Type::IntLiteral(_)
| Type::KnownInstance(_)
| Type::TypeIs(_)
| Type::TypeGuard(_)
| Type::WrapperDescriptor(_)
| Type::DataclassTransformer(_)
| Type::TypedDict(_)
@@ -277,48 +276,186 @@ impl ClassInfoConstraintFunction {
}
}
type NarrowingConstraints<'db> = FxHashMap<ScopedPlaceId, Type<'db>>;
/// Represents narrowing constraints in Disjunctive Normal Form (DNF).
///
/// This is a disjunction (OR) of conjunctions (AND) of constraints.
/// The DNF representation allows us to properly track `TypeGuard` constraints
/// through boolean operations.
///
/// For example:
/// - `f(x) and g(x)` where f returns `TypeIs[A]` and g returns `TypeGuard[B]`
/// => and
/// ===> `NarrowingConstraint { regular_disjunct: Some(A), typeguard_disjuncts: [] }`
/// ===> `NarrowingConstraint { regular_disjunct: None, typeguard_disjuncts: [B] }`
/// => `NarrowingConstraint { regular_disjunct: None, typeguard_disjuncts: [B] }`
/// => evaluates to `B` (`TypeGuard` clobbers any previous type information)
///
/// - `f(x) or g(x)` where f returns `TypeIs[A]` and g returns `TypeGuard[B]`
/// => or
/// ===> `NarrowingConstraint { regular_disjunct: Some(A), typeguard_disjuncts: [] }`
/// ===> `NarrowingConstraint { regular_disjunct: None, typeguard_disjuncts: [B] }`
/// => `NarrowingConstraint { regular_disjunct: Some(A), typeguard_disjuncts: [B] }`
/// => evaluates to `(P & A) | B`, where `P` is our previously-known type
#[derive(Hash, PartialEq, Debug, Eq, Clone, salsa::Update, get_size2::GetSize)]
pub(crate) struct NarrowingConstraint<'db> {
/// Regular constraint (from narrowing comparisons or `TypeIs`). We can use a single type here
/// because we can eagerly union disjunctions and eagerly intersect conjunctions.
regular_disjunct: Option<Type<'db>>,
/// `TypeGuard` constraints. We can't eagerly union disjunctions because `TypeGuard` clobbers
/// the previously-known type; within each `TypeGuard` disjunct, we may eagerly intersect
/// conjunctions with a later regular narrowing.
typeguard_disjuncts: SmallVec<[Type<'db>; 1]>,
}
impl<'db> NarrowingConstraint<'db> {
/// Create a constraint from a regular (non-`TypeGuard`) type
pub(crate) fn regular(constraint: Type<'db>) -> Self {
Self {
regular_disjunct: Some(constraint),
typeguard_disjuncts: smallvec![],
}
}
/// Create a constraint from a `TypeGuard` type
fn typeguard(constraint: Type<'db>) -> Self {
Self {
regular_disjunct: None,
typeguard_disjuncts: smallvec![constraint],
}
}
/// Merge two constraints, taking their intersection but respecting `TypeGuard` semantics (with
/// `other` winning)
pub(crate) fn merge_constraint_and(&self, other: Self, db: &'db dyn Db) -> Self {
// Distribute AND over OR: (A1 | A2 | ...) AND (B1 | B2 | ...)
// becomes (A1 & B1) | (A1 & B2) | ... | (A2 & B1) | ...
//
// In our representation, the RHS `typeguard_disjuncts` will all clobber the LHS disjuncts
// when they are anded, so they'll just stay as is.
//
// The thing we actually need to deal with is the RHS `regular_disjunct`. It gets
// intersected with the LHS `regular_disjunct` to form the new `regular_disjunct`, and
// intersected with each LHS `typeguard_disjunct` to form new additional
// `typeguard_disjuncts`.
let Some(other_regular_disjunct) = other.regular_disjunct else {
return other;
};
let new_regular_disjunct = self.regular_disjunct.map(|regular_disjunct| {
IntersectionBuilder::new(db)
.add_positive(regular_disjunct)
.add_positive(other_regular_disjunct)
.build()
});
let additional_typeguard_disjuncts =
self.typeguard_disjuncts.iter().map(|typeguard_disjunct| {
IntersectionBuilder::new(db)
.add_positive(*typeguard_disjunct)
.add_positive(other_regular_disjunct)
.build()
});
let mut new_typeguard_disjuncts = other.typeguard_disjuncts;
new_typeguard_disjuncts.extend(additional_typeguard_disjuncts);
NarrowingConstraint {
regular_disjunct: new_regular_disjunct,
typeguard_disjuncts: new_typeguard_disjuncts,
}
}
/// Evaluate the type this effectively constrains to
///
/// Forgets whether each constraint originated from a `TypeGuard` or not
pub(crate) fn evaluate_constraint_type(self, db: &'db dyn Db) -> Type<'db> {
UnionType::from_elements(
db,
self.typeguard_disjuncts
.into_iter()
.chain(self.regular_disjunct),
)
}
}
impl<'db> From<Type<'db>> for NarrowingConstraint<'db> {
fn from(constraint: Type<'db>) -> Self {
Self::regular(constraint)
}
}
type NarrowingConstraints<'db> = FxHashMap<ScopedPlaceId, NarrowingConstraint<'db>>;
/// Merge constraints with AND semantics (intersection/conjunction).
///
/// When we have `constraint1 & constraint2`, we need to distribute AND over the OR
/// in the DNF representations:
/// `(A | B) & (C | D)` becomes `(A & C) | (A & D) | (B & C) | (B & D)`
///
/// For each conjunction pair, we:
/// - Take the right conjunct if it has a `TypeGuard`
/// - Intersect the constraints normally otherwise
fn merge_constraints_and<'db>(
into: &mut NarrowingConstraints<'db>,
from: &NarrowingConstraints<'db>,
from: NarrowingConstraints<'db>,
db: &'db dyn Db,
) {
for (key, value) in from {
match into.entry(*key) {
for (key, from_constraint) in from {
match into.entry(key) {
Entry::Occupied(mut entry) => {
*entry.get_mut() = IntersectionBuilder::new(db)
.add_positive(*entry.get())
.add_positive(*value)
.build();
let into_constraint = entry.get();
entry.insert(into_constraint.merge_constraint_and(from_constraint, db));
}
Entry::Vacant(entry) => {
entry.insert(*value);
entry.insert(from_constraint);
}
}
}
}
/// Merge constraints with OR semantics (union/disjunction).
///
/// When we have `constraint1 OR constraint2`, we simply concatenate the disjuncts
/// from both constraints: `(A | B) OR (C | D)` becomes `A | B | C | D`
///
/// However, if a place appears in only one branch of the OR, we need to widen it
/// to `object` in the overall result (because the other branch doesn't constrain it).
fn merge_constraints_or<'db>(
into: &mut NarrowingConstraints<'db>,
from: &NarrowingConstraints<'db>,
from: NarrowingConstraints<'db>,
db: &'db dyn Db,
) {
for (key, value) in from {
match into.entry(*key) {
// For places that appear in `into` but not in `from`, widen to object
into.retain(|key, _| from.contains_key(key));
for (key, from_constraint) in from {
match into.entry(key) {
Entry::Occupied(mut entry) => {
*entry.get_mut() = UnionBuilder::new(db).add(*entry.get()).add(*value).build();
let into_constraint = entry.get_mut();
// Union the regular constraints
into_constraint.regular_disjunct = match (
into_constraint.regular_disjunct,
from_constraint.regular_disjunct,
) {
(Some(a), Some(b)) => Some(UnionType::from_elements(db, [a, b])),
(Some(a), None) => Some(a),
(None, Some(b)) => Some(b),
(None, None) => None,
};
// Concatenate typeguard disjuncts
into_constraint
.typeguard_disjuncts
.extend(from_constraint.typeguard_disjuncts);
}
Entry::Vacant(entry) => {
entry.insert(Type::object());
Entry::Vacant(_) => {
// Place only appears in `from`, not in `into`. No constraint needed.
}
}
}
for (key, value) in into.iter_mut() {
if !from.contains_key(key) {
*value = Type::object();
}
}
}
fn place_expr(expr: &ast::Expr) -> Option<PlaceExpr> {
@@ -384,7 +521,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
}
fn finish(mut self) -> Option<NarrowingConstraints<'db>> {
let constraints: Option<NarrowingConstraints<'db>> = match self.predicate {
let mut constraints: Option<NarrowingConstraints<'db>> = match self.predicate {
PredicateNode::Expression(expression) => {
self.evaluate_expression_predicate(expression, self.is_positive)
}
@@ -394,12 +531,12 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
PredicateNode::ReturnsNever(_) => return None,
PredicateNode::StarImportPlaceholder(_) => return None,
};
if let Some(mut constraints) = constraints {
if let Some(ref mut constraints) = constraints {
constraints.shrink_to_fit();
Some(constraints)
} else {
None
}
constraints
}
fn evaluate_expression_predicate(
@@ -586,7 +723,10 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
Type::AlwaysTruthy.negate(self.db)
};
Some(NarrowingConstraints::from_iter([(place, ty)]))
Some(NarrowingConstraints::from_iter([(
place,
NarrowingConstraint::regular(ty),
)]))
}
fn evaluate_expr_named(
@@ -917,7 +1057,10 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
.collect();
if filtered.len() < union.elements(self.db).len() {
let place = self.expect_place(&subscript_place_expr);
constraints.insert(place, UnionType::from_elements(self.db, filtered));
constraints.insert(
place,
NarrowingConstraint::regular(UnionType::from_elements(self.db, filtered)),
);
}
}
@@ -983,7 +1126,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
// As mentioned above, the synthesized `TypedDict` is always negated.
let intersection = Type::TypedDict(synthesized_typeddict).negate(self.db);
let place = self.expect_place(&subscript_place_expr);
constraints.insert(place, intersection);
constraints.insert(place, NarrowingConstraint::regular(intersection));
}
}
@@ -1004,7 +1147,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
self.evaluate_expr_compare_op(lhs_ty, rhs_ty, *op, is_positive)
{
let place = self.expect_place(&left);
constraints.insert(place, ty);
constraints.insert(place, NarrowingConstraint::regular(ty));
}
}
ast::Expr::Call(ast::ExprCall {
@@ -1052,8 +1195,10 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
let place = self.expect_place(&target);
constraints.insert(
place,
Type::instance(self.db, rhs_class.unknown_specialization(self.db))
.negate_if(self.db, !is_positive),
NarrowingConstraint::regular(
Type::instance(self.db, rhs_class.unknown_specialization(self.db))
.negate_if(self.db, !is_positive),
),
);
}
}
@@ -1080,21 +1225,10 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
None | Some(KnownFunction::RevealType)
) =>
{
let return_ty = inference.expression_type(expr_call);
let (guarded_ty, place) = match return_ty {
// TODO: TypeGuard
Type::TypeIs(type_is) => {
let (_, place) = type_is.place_info(self.db)?;
(type_is.return_type(self.db), place)
}
_ => return None,
};
Some(NarrowingConstraints::from_iter([(
place,
guarded_ty.negate_if(self.db, !is_positive),
)]))
self.evaluate_type_guard_call(inference, expr_call, is_positive)
}
Type::BoundMethod(_) => {
self.evaluate_type_guard_call(inference, expr_call, is_positive)
}
// For the expression `len(E)`, we narrow the type based on whether len(E) is truthy
// (i.e., whether E is non-empty). We only narrow the parts of the type where we know
@@ -1112,7 +1246,10 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
if let Some(narrowed_ty) = Self::narrow_type_by_len(self.db, arg_ty, is_positive) {
let target = place_expr(arg)?;
let place = self.expect_place(&target);
Some(NarrowingConstraints::from_iter([(place, narrowed_ty)]))
Some(NarrowingConstraints::from_iter([(
place,
NarrowingConstraint::regular(narrowed_ty),
)]))
} else {
None
}
@@ -1142,7 +1279,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
return Some(NarrowingConstraints::from_iter([(
place,
constraint.negate_if(self.db, !is_positive),
NarrowingConstraint::regular(constraint.negate_if(self.db, !is_positive)),
)]));
}
@@ -1155,7 +1292,9 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
.map(|constraint| {
NarrowingConstraints::from_iter([(
place,
constraint.negate_if(self.db, !is_positive),
NarrowingConstraint::regular(
constraint.negate_if(self.db, !is_positive),
),
)])
})
}
@@ -1175,6 +1314,42 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
}
}
// Helper to evaluate TypeGuard/TypeIs narrowing for a call expression.
// Used for both direct function calls and bound method calls.
fn evaluate_type_guard_call(
&mut self,
inference: &ExpressionInference<'db>,
expr_call: &ast::ExprCall,
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> {
let return_ty = inference.expression_type(expr_call);
let place_and_constraint = match return_ty {
Type::TypeIs(type_is) => {
let (_, place) = type_is.place_info(self.db)?;
Some((
place,
NarrowingConstraint::regular(
type_is
.return_type(self.db)
.negate_if(self.db, !is_positive),
),
))
}
// TypeGuard only narrows in the positive case
Type::TypeGuard(type_guard) if is_positive => {
let (_, place) = type_guard.place_info(self.db)?;
Some((
place,
NarrowingConstraint::typeguard(type_guard.return_type(self.db)),
))
}
_ => None,
}?;
Some(NarrowingConstraints::from_iter([place_and_constraint]))
}
fn evaluate_match_pattern_singleton(
&mut self,
subject: Expression<'db>,
@@ -1190,7 +1365,10 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
ast::Singleton::False => Type::BooleanLiteral(false),
};
let ty = ty.negate_if(self.db, !is_positive);
Some(NarrowingConstraints::from_iter([(place, ty)]))
Some(NarrowingConstraints::from_iter([(
place,
NarrowingConstraint::regular(ty),
)]))
}
fn evaluate_match_pattern_class(
@@ -1223,7 +1401,10 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
_ => return None,
};
Some(NarrowingConstraints::from_iter([(place, narrowed_type)]))
Some(NarrowingConstraints::from_iter([(
place,
NarrowingConstraint::regular(narrowed_type),
)]))
}
fn evaluate_match_pattern_value(
@@ -1243,7 +1424,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
infer_same_file_expression_type(self.db, value, TypeContext::default(), self.module);
self.evaluate_expr_compare_op(subject_ty, value_ty, ast::CmpOp::Eq, is_positive)
.map(|ty| NarrowingConstraints::from_iter([(place, ty)]))
.map(|ty| NarrowingConstraints::from_iter([(place, NarrowingConstraint::regular(ty))]))
}
fn evaluate_match_pattern_or(
@@ -1267,7 +1448,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
self.evaluate_pattern_predicate_kind(predicate, subject, is_positive)
})
.reduce(|mut constraints, constraints_| {
merge_constraints(&mut constraints, &constraints_, db);
merge_constraints(&mut constraints, constraints_, db);
constraints
})
}
@@ -1279,7 +1460,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> {
let inference = infer_expression_types(self.db, expression, TypeContext::default());
let mut sub_constraints = expr_bool_op
let sub_constraints = expr_bool_op
.values
.iter()
// filter our arms with statically known truthiness
@@ -1299,7 +1480,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
let mut aggregation: Option<NarrowingConstraints> = None;
for sub_constraint in sub_constraints.into_iter().flatten() {
if let Some(ref mut some_aggregation) = aggregation {
merge_constraints_and(some_aggregation, &sub_constraint, self.db);
merge_constraints_and(some_aggregation, sub_constraint, self.db);
} else {
aggregation = Some(sub_constraint);
}
@@ -1307,8 +1488,12 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
aggregation
}
(BoolOp::Or, true) | (BoolOp::And, false) => {
let (first, rest) = sub_constraints.split_first_mut()?;
if let Some(first) = first {
let (mut first, rest) = {
let mut it = sub_constraints.into_iter();
(it.next()?, it)
};
if let Some(ref mut first) = first {
for rest_constraint in rest {
if let Some(rest_constraint) = rest_constraint {
merge_constraints_or(first, rest_constraint, self.db);
@@ -1317,7 +1502,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
}
}
}
first.clone()
first
}
}
}

View File

@@ -5,7 +5,8 @@ use salsa::plumbing::AsId;
use crate::{db::Db, types::bound_super::SuperOwnerKind};
use super::{
DynamicType, TodoType, Type, TypeIsType, class_base::ClassBase, subclass_of::SubclassOfInner,
DynamicType, TodoType, Type, TypeGuardLike, TypeGuardType, TypeIsType, class_base::ClassBase,
subclass_of::SubclassOfInner,
};
/// Return an [`Ordering`] that describes the canonical order in which two types should appear
@@ -132,6 +133,10 @@ pub(super) fn union_or_intersection_elements_ordering<'db>(
(Type::TypeIs(_), _) => Ordering::Less,
(_, Type::TypeIs(_)) => Ordering::Greater,
(Type::TypeGuard(left), Type::TypeGuard(right)) => typeguard_ordering(db, *left, *right),
(Type::TypeGuard(_), _) => Ordering::Less,
(_, Type::TypeGuard(_)) => Ordering::Greater,
(Type::NominalInstance(left), Type::NominalInstance(right)) => {
left.class(db).cmp(&right.class(db))
}
@@ -286,13 +291,13 @@ fn dynamic_elements_ordering(left: DynamicType, right: DynamicType) -> Ordering
}
}
/// Determine a canonical order for two instances of [`TypeIsType`].
/// Generic helper for ordering type guard-like types.
///
/// The following criteria are considered, in order:
/// * Boundness: Unbound precedes bound
/// * Symbol name: String comparison
/// * Guarded type: [`union_or_intersection_elements_ordering`]
fn typeis_ordering(db: &dyn Db, left: TypeIsType, right: TypeIsType) -> Ordering {
fn guard_like_ordering<'db, T: TypeGuardLike<'db>>(db: &'db dyn Db, left: T, right: T) -> Ordering {
let (left_ty, right_ty) = (left.return_type(db), right.return_type(db));
match (left.place_info(db), right.place_info(db)) {
@@ -307,3 +312,13 @@ fn typeis_ordering(db: &dyn Db, left: TypeIsType, right: TypeIsType) -> Ordering
},
}
}
/// Determine a canonical order for two instances of [`TypeIsType`].
fn typeis_ordering(db: &dyn Db, left: TypeIsType, right: TypeIsType) -> Ordering {
guard_like_ordering(db, left, right)
}
/// Determine a canonical order for two instances of [`TypeGuardType`].
fn typeguard_ordering(db: &dyn Db, left: TypeGuardType, right: TypeGuardType) -> Ordering {
guard_like_ordering(db, left, right)
}

View File

@@ -4,7 +4,7 @@ use crate::{
BoundMethodType, BoundSuperType, BoundTypeVarInstance, CallableType, GenericAlias,
IntersectionType, KnownBoundMethodType, KnownInstanceType, NominalInstanceType,
PropertyInstanceType, ProtocolInstanceType, SubclassOfType, Type, TypeAliasType,
TypeIsType, TypeVarInstance, TypedDictType, UnionType,
TypeGuardType, TypeIsType, TypeVarInstance, TypedDictType, UnionType,
bound_super::walk_bound_super_type,
class::walk_generic_alias,
function::{FunctionType, walk_function_type},
@@ -14,7 +14,7 @@ use crate::{
walk_bound_method_type, walk_bound_type_var_type, walk_callable_type,
walk_intersection_type, walk_known_instance_type, walk_method_wrapper_type,
walk_property_instance_type, walk_type_alias_type, walk_type_var_type,
walk_typed_dict_type, walk_typeis_type, walk_union,
walk_typed_dict_type, walk_typeguard_type, walk_typeis_type, walk_union,
},
};
use std::cell::{Cell, RefCell};
@@ -50,6 +50,10 @@ pub(crate) trait TypeVisitor<'db> {
walk_typeis_type(db, type_is, self);
}
fn visit_typeguard_type(&self, db: &'db dyn Db, type_is: TypeGuardType<'db>) {
walk_typeguard_type(db, type_is, self);
}
fn visit_subclass_of_type(&self, db: &'db dyn Db, subclass_of: SubclassOfType<'db>) {
walk_subclass_of_type(db, subclass_of, self);
}
@@ -127,6 +131,7 @@ pub(super) enum NonAtomicType<'db> {
NominalInstance(NominalInstanceType<'db>),
PropertyInstance(PropertyInstanceType<'db>),
TypeIs(TypeIsType<'db>),
TypeGuard(TypeGuardType<'db>),
TypeVar(BoundTypeVarInstance<'db>),
ProtocolInstance(ProtocolInstanceType<'db>),
TypedDict(TypedDictType<'db>),
@@ -195,6 +200,9 @@ impl<'db> From<Type<'db>> for TypeKind<'db> {
TypeKind::NonAtomic(NonAtomicType::TypeVar(bound_typevar))
}
Type::TypeIs(type_is) => TypeKind::NonAtomic(NonAtomicType::TypeIs(type_is)),
Type::TypeGuard(type_guard) => {
TypeKind::NonAtomic(NonAtomicType::TypeGuard(type_guard))
}
Type::TypedDict(typed_dict) => {
TypeKind::NonAtomic(NonAtomicType::TypedDict(typed_dict))
}
@@ -233,6 +241,7 @@ pub(super) fn walk_non_atomic_type<'db, V: TypeVisitor<'db> + ?Sized>(
visitor.visit_property_instance_type(db, property);
}
NonAtomicType::TypeIs(type_is) => visitor.visit_typeis_type(db, type_is),
NonAtomicType::TypeGuard(type_guard) => visitor.visit_typeguard_type(db, type_guard),
NonAtomicType::TypeVar(bound_typevar) => {
visitor.visit_bound_type_var_type(db, bound_typevar);
}