Fix stack overflow with recursive generic protocols

This fixes https://github.com/astral-sh/ty/issues/1736 where recursive
generic protocols with growing specializations caused a stack overflow.

The issue occurred with protocols like:
```python
class C[T](Protocol):
    a: 'C[set[T]]'
```

When checking `C[set[int]]` against `C[Unknown]`, member `a` requires
checking `C[set[set[int]]]`, which requires `C[set[set[set[int]]]]`,
etc. Each level has different type specializations, so the existing
cycle detection (using full types as cache keys) didn't catch the
infinite recursion.

The fix introduces `TypeRelationKey`, an enum that can be either a full
`Type` or a `ClassLiteral` (protocol class without specialization). For
protocol-to-protocol comparisons, we use `ClassLiteral` keys, which
let us detect when we're comparing the same protocol class regardless of
specialization. When a cycle is detected, we return the fallback value
(assume compatible) to safely terminate the recursion.
This commit is contained in:
Carl Meyer 2025-12-08 18:25:11 -08:00
parent 4e67a219bb
commit c88e1e40ab
No known key found for this signature in database
GPG Key ID: 2D1FB7916A52E121
3 changed files with 120 additions and 22 deletions

View File

@ -3010,6 +3010,31 @@ class Bar(Protocol[S]):
z: S | Bar[S]
```
### Recursive generic protocols with growing specializations
This snippet caused a stack overflow in <https://github.com/astral-sh/ty/issues/1736> because the
type parameter grows with each recursive call (`C[set[T]]` leads to `C[set[set[T]]]`, then
`C[set[set[set[T]]]]`, etc.):
```toml
[environment]
python-version = "3.12"
```
```py
from typing import Protocol
class C[T](Protocol):
a: "C[set[T]]"
def takes_c(c: C[set[int]]) -> None: ...
def f(c: C[int]) -> None:
# The key thing is that we don't stack overflow while checking this.
# The cycle detection assumes compatibility when it detects potential
# infinite recursion between protocol specializations.
takes_c(c)
```
### Recursive legacy generic protocol
```py

View File

@ -204,9 +204,48 @@ fn definition_expression_type<'db>(
/// A [`TypeTransformer`] that is used in `apply_type_mapping` methods.
pub(crate) type ApplyTypeMappingVisitor<'db> = TypeTransformer<'db, TypeMapping<'db, 'db>>;
/// A [`PairVisitor`] that is used in `has_relation_to` methods.
pub(crate) type HasRelationToVisitor<'db> =
CycleDetector<TypeRelation<'db>, (Type<'db>, Type<'db>, TypeRelation<'db>), ConstraintSet<'db>>;
/// Key type for the `has_relation_to` visitor.
///
/// For most type comparisons, we use the full `Type` as the key. However, for protocol-to-protocol
/// comparisons, we use the underlying `ClassLiteral` (ignoring specialization) to detect infinite
/// recursion that occurs with recursive generic protocols.
///
/// For example, with:
/// ```python
/// class C[T](Protocol):
/// a: 'C[set[T]]'
/// ```
///
/// Checking `C[set[int]] <: C[set[int]]` leads to checking `C[set[set[int]]] <: C[set[set[int]]]`,
/// then `C[set[set[set[int]]]] <: C[set[set[set[int]]]]`, etc. Each level has different type
/// specializations, so using full types as keys doesn't detect the cycle. By using `ClassLiteral`
/// as the key for protocol comparisons, we detect that we're comparing protocol `C` against itself
/// regardless of specialization.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) enum TypeRelationKey<'db> {
/// A regular type - used for most comparisons.
Type(Type<'db>),
/// A protocol class literal (without specialization) - used for protocol-to-protocol comparisons
/// to detect recursive generic protocols.
ProtocolClass(ClassLiteral<'db>),
}
impl<'db> From<Type<'db>> for TypeRelationKey<'db> {
fn from(ty: Type<'db>) -> Self {
TypeRelationKey::Type(ty)
}
}
/// The [`CycleDetector`] threaded through the `has_relation_to` methods.
///
/// Its key is the triple `(source key, target key, relation)`. Both type positions are
/// [`TypeRelationKey`]s rather than plain `Type`s so that protocol-to-protocol comparisons can be
/// keyed on the protocol class alone.
pub(crate) type HasRelationToVisitor<'db> = CycleDetector<
    TypeRelation<'db>,
    (
        TypeRelationKey<'db>,
        TypeRelationKey<'db>,
        TypeRelation<'db>,
    ),
    ConstraintSet<'db>,
>;
impl Default for HasRelationToVisitor<'_> {
fn default() -> Self {
@ -1973,7 +2012,7 @@ impl<'db> Type<'db> {
}
(Type::TypeAlias(self_alias), _) => {
relation_visitor.visit((self, target, relation), || {
relation_visitor.visit((self.into(), target.into(), relation), || {
self_alias.value_type(db).has_relation_to_impl(
db,
target,
@ -1986,7 +2025,7 @@ impl<'db> Type<'db> {
}
(_, Type::TypeAlias(target_alias)) => {
relation_visitor.visit((self, target, relation), || {
relation_visitor.visit((self.into(), target.into(), relation), || {
self.has_relation_to_impl(
db,
target_alias.value_type(db),
@ -2452,7 +2491,7 @@ impl<'db> Type<'db> {
) => ConstraintSet::from(false),
(Type::Callable(self_callable), Type::Callable(other_callable)) => relation_visitor
.visit((self, target, relation), || {
.visit((self.into(), target.into(), relation), || {
self_callable.has_relation_to_impl(
db,
other_callable,
@ -2464,7 +2503,7 @@ impl<'db> Type<'db> {
}),
(_, Type::Callable(other_callable)) => {
relation_visitor.visit((self, target, relation), || {
relation_visitor.visit((self.into(), target.into(), relation), || {
self.try_upcast_to_callable(db).when_some_and(|callables| {
callables.has_relation_to_impl(
db,
@ -2499,7 +2538,26 @@ impl<'db> Type<'db> {
}
(_, Type::ProtocolInstance(protocol)) => {
relation_visitor.visit((self, target, relation), || {
// For protocol-to-protocol comparisons, use ClassLiteral keys to detect
// infinite recursion with recursive generic protocols (e.g., `class C[T](Protocol): a: C[set[T]]`).
// When both types are protocols of the same class, the types may differ due to
// different specializations, but comparing them would lead to infinite recursion.
let (self_key, target_key) = if let Type::ProtocolInstance(self_protocol) = self {
// Both are protocol instances - try to use class literals as keys
// for detecting cycles in recursive generic protocols
match (self_protocol.class_literal(db), protocol.class_literal(db)) {
(Some(self_class), Some(target_class)) => (
TypeRelationKey::ProtocolClass(self_class),
TypeRelationKey::ProtocolClass(target_class),
),
// One or both are synthesized protocols - fall back to full types
_ => (self.into(), target.into()),
}
} else {
// Source is not a protocol - use full types
(self.into(), target.into())
};
relation_visitor.visit((self_key, target_key, relation), || {
self.satisfies_protocol(
db,
protocol,
@ -2515,7 +2573,7 @@ impl<'db> Type<'db> {
(Type::ProtocolInstance(_), _) => ConstraintSet::from(false),
(Type::TypedDict(self_typeddict), Type::TypedDict(other_typeddict)) => relation_visitor
.visit((self, target, relation), || {
.visit((self.into(), target.into(), relation), || {
self_typeddict.has_relation_to_impl(
db,
other_typeddict,
@ -2530,18 +2588,23 @@ impl<'db> Type<'db> {
// compatible `Mapping`s. `extra_items` could also allow for some assignments to `dict`, as
// long as `total=False`. (But then again, does anyone want a non-total `TypedDict` where all
// key types are a supertype of the extra items type?)
(Type::TypedDict(_), _) => relation_visitor.visit((self, target, relation), || {
KnownClass::Mapping
.to_specialized_instance(db, [KnownClass::Str.to_instance(db), Type::object()])
.has_relation_to_impl(
db,
target,
inferable,
relation,
relation_visitor,
disjointness_visitor,
)
}),
(Type::TypedDict(_), _) => {
relation_visitor.visit((self.into(), target.into(), relation), || {
KnownClass::Mapping
.to_specialized_instance(
db,
[KnownClass::Str.to_instance(db), Type::object()],
)
.has_relation_to_impl(
db,
target,
inferable,
relation,
relation_visitor,
disjointness_visitor,
)
})
}
// A non-`TypedDict` cannot subtype a `TypedDict`
(_, Type::TypedDict(_)) => ConstraintSet::from(false),
@ -2841,7 +2904,7 @@ impl<'db> Type<'db> {
// `bool` is a subtype of `int`, because `bool` subclasses `int`,
// which means that all instances of `bool` are also instances of `int`
(Type::NominalInstance(self_instance), Type::NominalInstance(target_instance)) => {
relation_visitor.visit((self, target, relation), || {
relation_visitor.visit((self.into(), target.into(), relation), || {
self_instance.has_relation_to_impl(
db,
target_instance,

View File

@ -659,6 +659,16 @@ impl<'db> ProtocolInstanceType<'db> {
}
}
/// Return the unspecialized class literal behind this protocol instance, if it has one.
///
/// A protocol that comes from a class definition yields `Some` with that class's literal; the
/// specialization is discarded. A synthesized protocol has no class definition and yields
/// `None`.
pub(super) fn class_literal(self, db: &'db dyn Db) -> Option<ClassLiteral<'db>> {
    match self.inner {
        Protocol::Synthesized(_) => None,
        Protocol::FromClass(protocol_class) => {
            // Keep only the first component (the literal) of what the class reports.
            let literal = protocol_class.class_literal(db).0;
            Some(literal)
        }
    }
}
/// Return the meta-type of this protocol-instance type.
pub(super) fn to_meta_type(self, db: &'db dyn Db) -> Type<'db> {
match self.inner {