[ty] Fix disjointness checks with type-of `@final` classes (#21770)

## Summary

We currently perform a subtyping check, similar to what we were doing
for `@final` instances before
https://github.com/astral-sh/ruff/pull/21167, which is incorrect, e.g.
we currently consider `type[X[Any]]` and `type[X[T]]]` disjoint (where
`X` is `@final`).
This commit is contained in:
Ibraheem Ahmed 2025-12-10 15:15:10 -05:00 committed by GitHub
parent 3e00221a6c
commit a2fb2ee06c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 218 additions and 32 deletions

View File

@ -92,8 +92,7 @@ def f(x: A[int] | B):
reveal_type(x) # revealed: A[int] | B reveal_type(x) # revealed: A[int] | B
if type(x) is A: if type(x) is A:
# TODO: this should be `A[int]`, but `A[int] | B` would be better than `Never` reveal_type(x) # revealed: A[int]
reveal_type(x) # revealed: Never
else: else:
reveal_type(x) # revealed: A[int] | B reveal_type(x) # revealed: A[int] | B
@ -111,8 +110,7 @@ def f(x: A[int] | B):
if type(x) is not A: if type(x) is not A:
reveal_type(x) # revealed: A[int] | B reveal_type(x) # revealed: A[int] | B
else: else:
# TODO: this should be `A[int]`, but `A[int] | B` would be better than `Never` reveal_type(x) # revealed: A[int]
reveal_type(x) # revealed: Never
if type(x) is not B: if type(x) is not B:
reveal_type(x) # revealed: A[int] | B reveal_type(x) # revealed: A[int] | B

View File

@ -160,7 +160,7 @@ same also applies to enum classes with members, which are implicitly final:
```toml ```toml
[environment] [environment]
python-version = "3.10" python-version = "3.12"
``` ```
```py ```py
@ -180,3 +180,177 @@ def _(x: type[Foo], y: type[EllipsisType], z: type[Answer]):
reveal_type(y) # revealed: <class 'EllipsisType'> reveal_type(y) # revealed: <class 'EllipsisType'>
reveal_type(z) # revealed: <class 'Answer'> reveal_type(z) # revealed: <class 'Answer'>
``` ```
## Subtyping `@final` classes
```toml
[environment]
python-version = "3.12"
```
```py
from typing import final, Any
from ty_extensions import is_assignable_to, is_subtype_of, is_disjoint_from, static_assert
class Biv[T]: ...
class Cov[T]:
def pop(self) -> T:
raise NotImplementedError
class Contra[T]:
def push(self, value: T) -> None:
pass
class Inv[T]:
x: T
@final
class BivSub[T](Biv[T]): ...
@final
class CovSub[T](Cov[T]): ...
@final
class ContraSub[T](Contra[T]): ...
@final
class InvSub[T](Inv[T]): ...
def _[T, U]():
static_assert(is_subtype_of(type[BivSub[T]], type[BivSub[U]]))
static_assert(not is_disjoint_from(type[BivSub[U]], type[BivSub[T]]))
# `T` and `U` could specialize to the same type.
static_assert(not is_subtype_of(type[CovSub[T]], type[CovSub[U]]))
static_assert(not is_disjoint_from(type[CovSub[U]], type[CovSub[T]]))
static_assert(not is_subtype_of(type[ContraSub[T]], type[ContraSub[U]]))
static_assert(not is_disjoint_from(type[ContraSub[U]], type[ContraSub[T]]))
static_assert(not is_subtype_of(type[InvSub[T]], type[InvSub[U]]))
static_assert(not is_disjoint_from(type[InvSub[U]], type[InvSub[T]]))
def _():
static_assert(is_subtype_of(type[BivSub[bool]], type[BivSub[int]]))
static_assert(is_subtype_of(type[BivSub[int]], type[BivSub[bool]]))
static_assert(not is_disjoint_from(type[BivSub[bool]], type[BivSub[int]]))
# `BivSub[int]` and `BivSub[str]` are mutual subtypes.
static_assert(not is_disjoint_from(type[BivSub[int]], type[BivSub[str]]))
static_assert(is_subtype_of(type[CovSub[bool]], type[CovSub[int]]))
static_assert(not is_subtype_of(type[CovSub[int]], type[CovSub[bool]]))
static_assert(not is_disjoint_from(type[CovSub[bool]], type[CovSub[int]]))
# `CovSub[Never]` is a subtype of both `CovSub[int]` and `CovSub[str]`.
static_assert(not is_disjoint_from(type[CovSub[int]], type[CovSub[str]]))
static_assert(not is_subtype_of(type[ContraSub[bool]], type[ContraSub[int]]))
static_assert(is_subtype_of(type[ContraSub[int]], type[ContraSub[bool]]))
static_assert(not is_disjoint_from(type[ContraSub[bool]], type[ContraSub[int]]))
# `ContraSub[int | str]` is a subtype of both `ContraSub[int]` and `ContraSub[str]`.
static_assert(not is_disjoint_from(type[ContraSub[int]], type[ContraSub[str]]))
static_assert(not is_subtype_of(type[InvSub[bool]], type[InvSub[int]]))
static_assert(not is_subtype_of(type[InvSub[int]], type[InvSub[bool]]))
static_assert(is_disjoint_from(type[InvSub[int]], type[InvSub[str]]))
# TODO: These are disjoint.
static_assert(not is_disjoint_from(type[InvSub[bool]], type[InvSub[int]]))
def _[T]():
static_assert(is_subtype_of(type[BivSub[T]], type[BivSub[Any]]))
static_assert(is_subtype_of(type[BivSub[Any]], type[BivSub[T]]))
static_assert(is_assignable_to(type[BivSub[T]], type[BivSub[Any]]))
static_assert(is_assignable_to(type[BivSub[Any]], type[BivSub[T]]))
static_assert(not is_disjoint_from(type[BivSub[T]], type[BivSub[Any]]))
static_assert(not is_subtype_of(type[CovSub[T]], type[CovSub[Any]]))
static_assert(not is_subtype_of(type[CovSub[Any]], type[CovSub[T]]))
static_assert(is_assignable_to(type[CovSub[T]], type[CovSub[Any]]))
static_assert(is_assignable_to(type[CovSub[Any]], type[CovSub[T]]))
static_assert(not is_disjoint_from(type[CovSub[T]], type[CovSub[Any]]))
static_assert(not is_subtype_of(type[ContraSub[T]], type[ContraSub[Any]]))
static_assert(not is_subtype_of(type[ContraSub[Any]], type[ContraSub[T]]))
static_assert(is_assignable_to(type[ContraSub[T]], type[ContraSub[Any]]))
static_assert(is_assignable_to(type[ContraSub[Any]], type[ContraSub[T]]))
static_assert(not is_disjoint_from(type[ContraSub[T]], type[ContraSub[Any]]))
static_assert(not is_subtype_of(type[InvSub[T]], type[InvSub[Any]]))
static_assert(not is_subtype_of(type[InvSub[Any]], type[InvSub[T]]))
static_assert(is_assignable_to(type[InvSub[T]], type[InvSub[Any]]))
static_assert(is_assignable_to(type[InvSub[Any]], type[InvSub[T]]))
static_assert(not is_disjoint_from(type[InvSub[T]], type[InvSub[Any]]))
def _[T, U]():
static_assert(is_subtype_of(type[BivSub[T]], type[Biv[T]]))
static_assert(not is_subtype_of(type[Biv[T]], type[BivSub[T]]))
static_assert(not is_disjoint_from(type[BivSub[T]], type[Biv[T]]))
static_assert(not is_disjoint_from(type[BivSub[U]], type[Biv[T]]))
static_assert(not is_disjoint_from(type[BivSub[U]], type[Biv[U]]))
static_assert(is_subtype_of(type[CovSub[T]], type[Cov[T]]))
static_assert(not is_subtype_of(type[Cov[T]], type[CovSub[T]]))
static_assert(not is_disjoint_from(type[CovSub[T]], type[Cov[T]]))
static_assert(not is_disjoint_from(type[CovSub[U]], type[Cov[T]]))
static_assert(not is_disjoint_from(type[CovSub[U]], type[Cov[U]]))
static_assert(is_subtype_of(type[ContraSub[T]], type[Contra[T]]))
static_assert(not is_subtype_of(type[Contra[T]], type[ContraSub[T]]))
static_assert(not is_disjoint_from(type[ContraSub[T]], type[Contra[T]]))
static_assert(not is_disjoint_from(type[ContraSub[U]], type[Contra[T]]))
static_assert(not is_disjoint_from(type[ContraSub[U]], type[Contra[U]]))
static_assert(is_subtype_of(type[InvSub[T]], type[Inv[T]]))
static_assert(not is_subtype_of(type[Inv[T]], type[InvSub[T]]))
static_assert(not is_disjoint_from(type[InvSub[T]], type[Inv[T]]))
static_assert(not is_disjoint_from(type[InvSub[U]], type[Inv[T]]))
static_assert(not is_disjoint_from(type[InvSub[U]], type[Inv[U]]))
def _():
static_assert(is_subtype_of(type[BivSub[bool]], type[Biv[int]]))
static_assert(is_subtype_of(type[BivSub[int]], type[Biv[bool]]))
static_assert(not is_disjoint_from(type[BivSub[bool]], type[Biv[int]]))
static_assert(not is_disjoint_from(type[BivSub[int]], type[Biv[bool]]))
static_assert(is_subtype_of(type[CovSub[bool]], type[Cov[int]]))
static_assert(not is_subtype_of(type[CovSub[int]], type[Cov[bool]]))
static_assert(not is_disjoint_from(type[CovSub[bool]], type[Cov[int]]))
static_assert(not is_disjoint_from(type[CovSub[int]], type[Cov[bool]]))
static_assert(not is_subtype_of(type[ContraSub[bool]], type[Contra[int]]))
static_assert(is_subtype_of(type[ContraSub[int]], type[Contra[bool]]))
static_assert(not is_disjoint_from(type[ContraSub[int]], type[Contra[bool]]))
static_assert(not is_disjoint_from(type[ContraSub[bool]], type[Contra[int]]))
static_assert(not is_subtype_of(type[InvSub[bool]], type[Inv[int]]))
static_assert(not is_subtype_of(type[InvSub[int]], type[Inv[bool]]))
# TODO: These are disjoint.
static_assert(not is_disjoint_from(type[InvSub[bool]], type[Inv[int]]))
# TODO: These are disjoint.
static_assert(not is_disjoint_from(type[InvSub[int]], type[Inv[bool]]))
def _[T]():
static_assert(is_subtype_of(type[BivSub[T]], type[Biv[Any]]))
static_assert(is_subtype_of(type[BivSub[Any]], type[Biv[T]]))
static_assert(is_assignable_to(type[BivSub[T]], type[Biv[Any]]))
static_assert(is_assignable_to(type[BivSub[Any]], type[Biv[T]]))
static_assert(not is_disjoint_from(type[BivSub[T]], type[Biv[Any]]))
static_assert(not is_subtype_of(type[CovSub[T]], type[Cov[Any]]))
static_assert(not is_subtype_of(type[CovSub[Any]], type[Cov[T]]))
static_assert(is_assignable_to(type[CovSub[T]], type[Cov[Any]]))
static_assert(is_assignable_to(type[CovSub[Any]], type[Cov[T]]))
static_assert(not is_disjoint_from(type[CovSub[T]], type[Cov[Any]]))
static_assert(not is_subtype_of(type[ContraSub[T]], type[Contra[Any]]))
static_assert(not is_subtype_of(type[ContraSub[Any]], type[Contra[T]]))
static_assert(is_assignable_to(type[ContraSub[T]], type[Contra[Any]]))
static_assert(is_assignable_to(type[ContraSub[Any]], type[Contra[T]]))
static_assert(not is_disjoint_from(type[ContraSub[T]], type[Contra[Any]]))
static_assert(not is_subtype_of(type[InvSub[T]], type[Inv[Any]]))
static_assert(not is_subtype_of(type[InvSub[Any]], type[Inv[T]]))
static_assert(is_assignable_to(type[InvSub[T]], type[Inv[Any]]))
static_assert(is_assignable_to(type[InvSub[Any]], type[Inv[T]]))
static_assert(not is_disjoint_from(type[InvSub[T]], type[Inv[Any]]))
```

View File

@ -684,9 +684,8 @@ class GenericClass[T]:
x: T # invariant x: T # invariant
static_assert(not is_disjoint_from(TypeOf[GenericClass], type[GenericClass])) static_assert(not is_disjoint_from(TypeOf[GenericClass], type[GenericClass]))
# TODO: these should not error static_assert(not is_disjoint_from(TypeOf[GenericClass[int]], type[GenericClass]))
static_assert(not is_disjoint_from(TypeOf[GenericClass[int]], type[GenericClass])) # error: [static-assert-error] static_assert(not is_disjoint_from(TypeOf[GenericClass], type[GenericClass[int]]))
static_assert(not is_disjoint_from(TypeOf[GenericClass], type[GenericClass[int]])) # error: [static-assert-error]
static_assert(not is_disjoint_from(TypeOf[GenericClass[int]], type[GenericClass[int]])) static_assert(not is_disjoint_from(TypeOf[GenericClass[int]], type[GenericClass[int]]))
static_assert(is_disjoint_from(TypeOf[GenericClass[str]], type[GenericClass[int]])) static_assert(is_disjoint_from(TypeOf[GenericClass[str]], type[GenericClass[int]]))
@ -694,19 +693,17 @@ class GenericClassIntBound[T: int]:
x: T # invariant x: T # invariant
static_assert(not is_disjoint_from(TypeOf[GenericClassIntBound], type[GenericClassIntBound])) static_assert(not is_disjoint_from(TypeOf[GenericClassIntBound], type[GenericClassIntBound]))
# TODO: these should not error static_assert(not is_disjoint_from(TypeOf[GenericClassIntBound[int]], type[GenericClassIntBound]))
static_assert(not is_disjoint_from(TypeOf[GenericClassIntBound[int]], type[GenericClassIntBound])) # error: [static-assert-error] static_assert(not is_disjoint_from(TypeOf[GenericClassIntBound], type[GenericClassIntBound[int]]))
static_assert(not is_disjoint_from(TypeOf[GenericClassIntBound], type[GenericClassIntBound[int]])) # error: [static-assert-error]
static_assert(not is_disjoint_from(TypeOf[GenericClassIntBound[int]], type[GenericClassIntBound[int]])) static_assert(not is_disjoint_from(TypeOf[GenericClassIntBound[int]], type[GenericClassIntBound[int]]))
@final @final
class GenericFinalClass[T]: class GenericFinalClass[T]:
x: T # invariant x: T # invariant
# TODO: these should not error static_assert(not is_disjoint_from(TypeOf[GenericFinalClass], type[GenericFinalClass]))
static_assert(not is_disjoint_from(TypeOf[GenericFinalClass], type[GenericFinalClass])) # error: [static-assert-error] static_assert(not is_disjoint_from(TypeOf[GenericFinalClass[int]], type[GenericFinalClass]))
static_assert(not is_disjoint_from(TypeOf[GenericFinalClass[int]], type[GenericFinalClass])) # error: [static-assert-error] static_assert(not is_disjoint_from(TypeOf[GenericFinalClass], type[GenericFinalClass[int]]))
static_assert(not is_disjoint_from(TypeOf[GenericFinalClass], type[GenericFinalClass[int]])) # error: [static-assert-error]
static_assert(not is_disjoint_from(TypeOf[GenericFinalClass[int]], type[GenericFinalClass[int]])) static_assert(not is_disjoint_from(TypeOf[GenericFinalClass[int]], type[GenericFinalClass[int]]))
static_assert(is_disjoint_from(TypeOf[GenericFinalClass[str]], type[GenericFinalClass[int]])) static_assert(is_disjoint_from(TypeOf[GenericFinalClass[str]], type[GenericFinalClass[int]]))
``` ```

View File

@ -3353,7 +3353,6 @@ impl<'db> Type<'db> {
| Type::WrapperDescriptor(..) | Type::WrapperDescriptor(..)
| Type::ModuleLiteral(..) | Type::ModuleLiteral(..)
| Type::ClassLiteral(..) | Type::ClassLiteral(..)
| Type::GenericAlias(..)
| Type::SpecialForm(..) | Type::SpecialForm(..)
| Type::KnownInstance(..)), | Type::KnownInstance(..)),
right @ (Type::BooleanLiteral(..) right @ (Type::BooleanLiteral(..)
@ -3367,7 +3366,6 @@ impl<'db> Type<'db> {
| Type::WrapperDescriptor(..) | Type::WrapperDescriptor(..)
| Type::ModuleLiteral(..) | Type::ModuleLiteral(..)
| Type::ClassLiteral(..) | Type::ClassLiteral(..)
| Type::GenericAlias(..)
| Type::SpecialForm(..) | Type::SpecialForm(..)
| Type::KnownInstance(..)), | Type::KnownInstance(..)),
) => ConstraintSet::from(left != right), ) => ConstraintSet::from(left != right),
@ -3550,13 +3548,39 @@ impl<'db> Type<'db> {
ConstraintSet::from(true) ConstraintSet::from(true)
} }
(Type::GenericAlias(left_alias), Type::GenericAlias(right_alias)) => {
ConstraintSet::from(left_alias.origin(db) != right_alias.origin(db)).or(db, || {
left_alias.specialization(db).is_disjoint_from_impl(
db,
right_alias.specialization(db),
inferable,
disjointness_visitor,
relation_visitor,
)
})
}
(Type::ClassLiteral(class_literal), other @ Type::GenericAlias(_))
| (other @ Type::GenericAlias(_), Type::ClassLiteral(class_literal)) => class_literal
.default_specialization(db)
.into_generic_alias()
.when_none_or(|alias| {
other.is_disjoint_from_impl(
db,
Type::GenericAlias(alias),
inferable,
disjointness_visitor,
relation_visitor,
)
}),
(Type::SubclassOf(subclass_of_ty), Type::ClassLiteral(class_b)) (Type::SubclassOf(subclass_of_ty), Type::ClassLiteral(class_b))
| (Type::ClassLiteral(class_b), Type::SubclassOf(subclass_of_ty)) => { | (Type::ClassLiteral(class_b), Type::SubclassOf(subclass_of_ty)) => {
match subclass_of_ty.subclass_of() { match subclass_of_ty.subclass_of() {
SubclassOfInner::Dynamic(_) => ConstraintSet::from(false), SubclassOfInner::Dynamic(_) => ConstraintSet::from(false),
SubclassOfInner::Class(class_a) => { SubclassOfInner::Class(class_a) => ConstraintSet::from(
class_b.when_subclass_of(db, None, class_a).negate(db) !class_a.could_exist_in_mro_of(db, ClassType::NonGeneric(class_b)),
} ),
SubclassOfInner::TypeVar(_) => unreachable!(), SubclassOfInner::TypeVar(_) => unreachable!(),
} }
} }
@ -3565,9 +3589,9 @@ impl<'db> Type<'db> {
| (Type::GenericAlias(alias_b), Type::SubclassOf(subclass_of_ty)) => { | (Type::GenericAlias(alias_b), Type::SubclassOf(subclass_of_ty)) => {
match subclass_of_ty.subclass_of() { match subclass_of_ty.subclass_of() {
SubclassOfInner::Dynamic(_) => ConstraintSet::from(false), SubclassOfInner::Dynamic(_) => ConstraintSet::from(false),
SubclassOfInner::Class(class_a) => ClassType::from(alias_b) SubclassOfInner::Class(class_a) => ConstraintSet::from(
.when_subclass_of(db, class_a, inferable) !class_a.could_exist_in_mro_of(db, ClassType::Generic(alias_b)),
.negate(db), ),
SubclassOfInner::TypeVar(_) => unreachable!(), SubclassOfInner::TypeVar(_) => unreachable!(),
} }
} }
@ -3861,6 +3885,8 @@ impl<'db> Type<'db> {
relation_visitor, relation_visitor,
) )
} }
(Type::GenericAlias(_), _) | (_, Type::GenericAlias(_)) => ConstraintSet::from(true),
} }
} }

View File

@ -1911,15 +1911,6 @@ impl<'db> ClassLiteral<'db> {
.contains(&ClassBase::Class(other)) .contains(&ClassBase::Class(other))
} }
pub(super) fn when_subclass_of(
self,
db: &'db dyn Db,
specialization: Option<Specialization<'db>>,
other: ClassType<'db>,
) -> ConstraintSet<'db> {
ConstraintSet::from(self.is_subclass_of(db, specialization, other))
}
/// Return `true` if this class constitutes a typed dict specification (inherits from /// Return `true` if this class constitutes a typed dict specification (inherits from
/// `typing.TypedDict`, either directly or indirectly). /// `typing.TypedDict`, either directly or indirectly).
#[salsa::tracked(cycle_initial=is_typed_dict_cycle_initial, #[salsa::tracked(cycle_initial=is_typed_dict_cycle_initial,