[ty] Fix subtyping with `type[T]` and unions (#21740)

## Summary

Resolves
https://github.com/astral-sh/ruff/pull/21685#issuecomment-3591695954.
This commit is contained in:
Ibraheem Ahmed 2025-12-01 18:20:13 -05:00 committed by GitHub
parent edc6ed5077
commit ec854c7199
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 81 additions and 54 deletions

View File

@ -123,11 +123,11 @@ class A:
A class `A` is a subtype of `type[T]` if any instance of `A` is a subtype of `T`.
```py
from typing import Callable, Protocol
from typing import Any, Callable, Protocol
from ty_extensions import is_assignable_to, is_subtype_of, is_disjoint_from, static_assert
class IntCallback(Protocol):
def __call__(self, *args, **kwargs) -> int: ...
class Callback[T](Protocol):
def __call__(self, *args, **kwargs) -> T: ...
def _[T](_: T):
static_assert(not is_subtype_of(type[T], T))
@ -141,8 +141,11 @@ def _[T](_: T):
static_assert(is_assignable_to(type[T], Callable[..., T]))
static_assert(not is_disjoint_from(type[T], Callable[..., T]))
static_assert(not is_assignable_to(type[T], IntCallback))
static_assert(not is_disjoint_from(type[T], IntCallback))
static_assert(is_assignable_to(type[T], Callable[..., T] | Callable[..., Any]))
static_assert(not is_disjoint_from(type[T], Callable[..., T] | Callable[..., Any]))
static_assert(not is_assignable_to(type[T], Callback[int]))
static_assert(not is_disjoint_from(type[T], Callback[int]))
def _[T: int](_: T):
static_assert(not is_subtype_of(type[T], T))
@ -157,14 +160,23 @@ def _[T: int](_: T):
static_assert(is_subtype_of(type[T], type[int]))
static_assert(not is_disjoint_from(type[T], type[int]))
static_assert(is_subtype_of(type[T], type[int] | None))
static_assert(not is_disjoint_from(type[T], type[int] | None))
static_assert(is_subtype_of(type[T], type[T]))
static_assert(not is_disjoint_from(type[T], type[T]))
static_assert(is_assignable_to(type[T], Callable[..., T]))
static_assert(not is_disjoint_from(type[T], Callable[..., T]))
static_assert(is_assignable_to(type[T], IntCallback))
static_assert(not is_disjoint_from(type[T], IntCallback))
static_assert(is_assignable_to(type[T], Callable[..., T] | Callable[..., Any]))
static_assert(not is_disjoint_from(type[T], Callable[..., T] | Callable[..., Any]))
static_assert(is_assignable_to(type[T], Callback[int]))
static_assert(not is_disjoint_from(type[T], Callback[int]))
static_assert(is_assignable_to(type[T], Callback[int] | Callback[Any]))
static_assert(not is_disjoint_from(type[T], Callback[int] | Callback[Any]))
static_assert(is_subtype_of(type[T], type[T] | None))
static_assert(not is_disjoint_from(type[T], type[T] | None))
@ -183,8 +195,14 @@ def _[T: (int, str)](_: T):
static_assert(is_assignable_to(type[T], Callable[..., T]))
static_assert(not is_disjoint_from(type[T], Callable[..., T]))
static_assert(not is_assignable_to(type[T], IntCallback))
static_assert(not is_disjoint_from(type[T], IntCallback))
static_assert(is_assignable_to(type[T], Callable[..., T] | Callable[..., Any]))
static_assert(not is_disjoint_from(type[T], Callable[..., T] | Callable[..., Any]))
static_assert(not is_assignable_to(type[T], Callback[int]))
static_assert(not is_disjoint_from(type[T], Callback[int]))
static_assert(is_assignable_to(type[T], Callback[int | str]))
static_assert(not is_disjoint_from(type[T], Callback[int] | Callback[str]))
static_assert(is_subtype_of(type[T], type[T] | None))
static_assert(not is_disjoint_from(type[T], type[T] | None))

View File

@ -2089,18 +2089,25 @@ impl<'db> Type<'db> {
// `type[T]` is a subtype of the class object `A` if every instance of `T` is a subtype of an instance
// of `A`, and vice versa.
(Type::SubclassOf(subclass_of), _)
if subclass_of.is_type_var()
&& !matches!(target, Type::Callable(_) | Type::ProtocolInstance(_)) =>
if !subclass_of
.into_type_var()
.zip(target.to_instance(db))
.when_some_and(|(this_instance, other_instance)| {
Type::TypeVar(this_instance).has_relation_to_impl(
db,
other_instance,
inferable,
relation,
relation_visitor,
disjointness_visitor,
)
})
.is_never_satisfied(db) =>
{
// TODO: The repetition here isn't great, but we really need the fallthrough logic,
// where this arm only engages if it returns true.
let this_instance = Type::TypeVar(subclass_of.into_type_var().unwrap());
let other_instance = match target {
Type::Union(union) => Some(
union.map(db, |element| element.to_instance(db).unwrap_or(Type::Never)),
),
_ => target.to_instance(db),
};
other_instance.when_some_and(|other_instance| {
target.to_instance(db).when_some_and(|other_instance| {
this_instance.has_relation_to_impl(
db,
other_instance,
@ -2111,6 +2118,7 @@ impl<'db> Type<'db> {
)
})
}
(_, Type::SubclassOf(subclass_of)) if subclass_of.is_type_var() => {
let other_instance = Type::TypeVar(subclass_of.into_type_var().unwrap());
self.to_instance(db).when_some_and(|this_instance| {
@ -2647,6 +2655,10 @@ impl<'db> Type<'db> {
disjointness_visitor,
),
(Type::SubclassOf(subclass_of), _) if subclass_of.is_type_var() => {
ConstraintSet::from(false)
}
// `Literal[<class 'C'>]` is a subtype of `type[B]` if `C` is a subclass of `B`,
// since `type[B]` describes all possible runtime subclasses of the class object `B`.
(Type::ClassLiteral(class), Type::SubclassOf(target_subclass_ty)) => target_subclass_ty
@ -3081,8 +3093,7 @@ impl<'db> Type<'db> {
ConstraintSet::from(false)
}
// `type[T]` is disjoint from a callable or protocol instance if its upper bound or
// constraints are.
// `type[T]` is disjoint from a callable or protocol instance if its upper bound or constraints are.
(Type::SubclassOf(subclass_of), Type::Callable(_) | Type::ProtocolInstance(_))
| (Type::Callable(_) | Type::ProtocolInstance(_), Type::SubclassOf(subclass_of))
if subclass_of.is_type_var() =>
@ -3104,13 +3115,14 @@ impl<'db> Type<'db> {
// `type[T]` is disjoint from a class object `A` if every instance of `T` is disjoint from an instance of `A`.
(Type::SubclassOf(subclass_of), other) | (other, Type::SubclassOf(subclass_of))
if subclass_of.is_type_var() =>
if subclass_of.is_type_var()
&& (other.to_instance(db).is_some()
|| other.as_typevar().is_some_and(|type_var| {
type_var.typevar(db).bound_or_constraints(db).is_none()
})) =>
{
let this_instance = Type::TypeVar(subclass_of.into_type_var().unwrap());
let other_instance = match other {
Type::Union(union) => Some(
union.map(db, |element| element.to_instance(db).unwrap_or(Type::Never)),
),
// An unbounded typevar `U` may have instances of type `object` if specialized to
// an instance of `type`.
Type::TypeVar(typevar)
@ -3464,6 +3476,12 @@ impl<'db> Type<'db> {
})
}
(Type::SubclassOf(subclass_of_ty), _) | (_, Type::SubclassOf(subclass_of_ty))
if subclass_of_ty.is_type_var() =>
{
ConstraintSet::from(true)
}
(Type::SubclassOf(subclass_of_ty), Type::ClassLiteral(class_b))
| (Type::ClassLiteral(class_b), Type::SubclassOf(subclass_of_ty)) => {
match subclass_of_ty.subclass_of() {
@ -3493,31 +3511,27 @@ impl<'db> Type<'db> {
// for `type[Any]`/`type[Unknown]`/`type[Todo]`, we know the type cannot be any larger than `type`,
// so although the type is dynamic we can still determine disjointedness in some situations
(Type::SubclassOf(subclass_of_ty), other)
| (other, Type::SubclassOf(subclass_of_ty))
if !subclass_of_ty.is_type_var() =>
{
match subclass_of_ty.subclass_of() {
SubclassOfInner::Dynamic(_) => {
KnownClass::Type.to_instance(db).is_disjoint_from_impl(
db,
other,
inferable,
disjointness_visitor,
relation_visitor,
)
}
SubclassOfInner::Class(class) => {
class.metaclass_instance_type(db).is_disjoint_from_impl(
db,
other,
inferable,
disjointness_visitor,
relation_visitor,
)
}
SubclassOfInner::TypeVar(_) => unreachable!(),
| (other, Type::SubclassOf(subclass_of_ty)) => match subclass_of_ty.subclass_of() {
SubclassOfInner::Dynamic(_) => {
KnownClass::Type.to_instance(db).is_disjoint_from_impl(
db,
other,
inferable,
disjointness_visitor,
relation_visitor,
)
}
}
SubclassOfInner::Class(class) => {
class.metaclass_instance_type(db).is_disjoint_from_impl(
db,
other,
inferable,
disjointness_visitor,
relation_visitor,
)
}
SubclassOfInner::TypeVar(_) => unreachable!(),
},
(Type::SpecialForm(special_form), Type::NominalInstance(instance))
| (Type::NominalInstance(instance), Type::SpecialForm(special_form)) => {
@ -3779,11 +3793,6 @@ impl<'db> Type<'db> {
relation_visitor,
)
}
(Type::SubclassOf(_), _) | (_, Type::SubclassOf(_)) => {
// All cases should have been handled above.
unreachable!()
}
}
}