mirror of https://github.com/astral-sh/ruff
Fix stack overflow with recursive generic protocols
This fixes https://github.com/astral-sh/ty/issues/1736 where recursive generic protocols with growing specializations caused a stack overflow.

The issue occurred with protocols like:

```python
class C[T](Protocol):
    a: 'C[set[T]]'
```

When checking `C[set[int]]` against `C[Unknown]`, member `a` requires checking `C[set[set[int]]]`, which requires `C[set[set[set[int]]]]`, etc. Each level has different type specializations, so the existing cycle detection (using full types as cache keys) didn't catch the infinite recursion.

The fix introduces `TypeRelationKey`, an enum that can be either a full `Type` or a `ClassLiteral` (protocol class without specialization). For protocol-to-protocol comparisons, we use `ClassLiteral` keys, which detects when we're comparing the same protocol class regardless of specialization. When a cycle is detected, we return the fallback value (assume compatible) to safely terminate the recursion.
This commit is contained in:
parent
4e67a219bb
commit
c88e1e40ab
|
|
@ -3010,6 +3010,31 @@ class Bar(Protocol[S]):
|
|||
z: S | Bar[S]
|
||||
```
|
||||
|
||||
### Recursive generic protocols with growing specializations
|
||||
|
||||
This snippet caused a stack overflow in <https://github.com/astral-sh/ty/issues/1736> because the
|
||||
type parameter grows with each recursive call (`C[set[T]]` leads to `C[set[set[T]]]`, then
|
||||
`C[set[set[set[T]]]]`, etc.):
|
||||
|
||||
```toml
|
||||
[environment]
|
||||
python-version = "3.12"
|
||||
```
|
||||
|
||||
```py
|
||||
from typing import Protocol
|
||||
|
||||
class C[T](Protocol):
|
||||
a: "C[set[T]]"
|
||||
|
||||
def takes_c(c: C[set[int]]) -> None: ...
|
||||
def f(c: C[int]) -> None:
|
||||
# The key thing is that we don't stack overflow while checking this.
|
||||
# The cycle detection assumes compatibility when it detects potential
|
||||
# infinite recursion between protocol specializations.
|
||||
takes_c(c)
|
||||
```
|
||||
|
||||
### Recursive legacy generic protocol
|
||||
|
||||
```py
|
||||
|
|
|
|||
|
|
@ -204,9 +204,48 @@ fn definition_expression_type<'db>(
|
|||
/// A [`TypeTransformer`] that is used in `apply_type_mapping` methods.
|
||||
pub(crate) type ApplyTypeMappingVisitor<'db> = TypeTransformer<'db, TypeMapping<'db, 'db>>;
|
||||
|
||||
/// A [`PairVisitor`] that is used in `has_relation_to` methods.
|
||||
pub(crate) type HasRelationToVisitor<'db> =
|
||||
CycleDetector<TypeRelation<'db>, (Type<'db>, Type<'db>, TypeRelation<'db>), ConstraintSet<'db>>;
|
||||
/// Key type for the `has_relation_to` visitor.
|
||||
///
|
||||
/// For most type comparisons, we use the full `Type` as the key. However, for protocol-to-protocol
|
||||
/// comparisons, we use the underlying `ClassLiteral` (ignoring specialization) to detect infinite
|
||||
/// recursion that occurs with recursive generic protocols.
|
||||
///
|
||||
/// For example, with:
|
||||
/// ```python
|
||||
/// class C[T](Protocol):
|
||||
/// a: 'C[set[T]]'
|
||||
/// ```
|
||||
///
|
||||
/// Checking `C[set[int]] <: C[set[int]]` leads to checking `C[set[set[int]]] <: C[set[set[int]]]`,
|
||||
/// then `C[set[set[set[int]]]] <: C[set[set[set[int]]]]`, etc. Each level has different type
|
||||
/// specializations, so using full types as keys doesn't detect the cycle. By using `ClassLiteral`
|
||||
/// as the key for protocol comparisons, we detect that we're comparing protocol `C` against itself
|
||||
/// regardless of specialization.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub(crate) enum TypeRelationKey<'db> {
|
||||
/// A regular type - used for most comparisons.
|
||||
Type(Type<'db>),
|
||||
/// A protocol class literal (without specialization) - used for protocol-to-protocol comparisons
|
||||
/// to detect recursive generic protocols.
|
||||
ProtocolClass(ClassLiteral<'db>),
|
||||
}
|
||||
|
||||
impl<'db> From<Type<'db>> for TypeRelationKey<'db> {
|
||||
fn from(ty: Type<'db>) -> Self {
|
||||
TypeRelationKey::Type(ty)
|
||||
}
|
||||
}
|
||||
|
||||
/// A [`CycleDetector`] that is used in `has_relation_to` methods.
|
||||
pub(crate) type HasRelationToVisitor<'db> = CycleDetector<
|
||||
TypeRelation<'db>,
|
||||
(
|
||||
TypeRelationKey<'db>,
|
||||
TypeRelationKey<'db>,
|
||||
TypeRelation<'db>,
|
||||
),
|
||||
ConstraintSet<'db>,
|
||||
>;
|
||||
|
||||
impl Default for HasRelationToVisitor<'_> {
|
||||
fn default() -> Self {
|
||||
|
|
@ -1973,7 +2012,7 @@ impl<'db> Type<'db> {
|
|||
}
|
||||
|
||||
(Type::TypeAlias(self_alias), _) => {
|
||||
relation_visitor.visit((self, target, relation), || {
|
||||
relation_visitor.visit((self.into(), target.into(), relation), || {
|
||||
self_alias.value_type(db).has_relation_to_impl(
|
||||
db,
|
||||
target,
|
||||
|
|
@ -1986,7 +2025,7 @@ impl<'db> Type<'db> {
|
|||
}
|
||||
|
||||
(_, Type::TypeAlias(target_alias)) => {
|
||||
relation_visitor.visit((self, target, relation), || {
|
||||
relation_visitor.visit((self.into(), target.into(), relation), || {
|
||||
self.has_relation_to_impl(
|
||||
db,
|
||||
target_alias.value_type(db),
|
||||
|
|
@ -2452,7 +2491,7 @@ impl<'db> Type<'db> {
|
|||
) => ConstraintSet::from(false),
|
||||
|
||||
(Type::Callable(self_callable), Type::Callable(other_callable)) => relation_visitor
|
||||
.visit((self, target, relation), || {
|
||||
.visit((self.into(), target.into(), relation), || {
|
||||
self_callable.has_relation_to_impl(
|
||||
db,
|
||||
other_callable,
|
||||
|
|
@ -2464,7 +2503,7 @@ impl<'db> Type<'db> {
|
|||
}),
|
||||
|
||||
(_, Type::Callable(other_callable)) => {
|
||||
relation_visitor.visit((self, target, relation), || {
|
||||
relation_visitor.visit((self.into(), target.into(), relation), || {
|
||||
self.try_upcast_to_callable(db).when_some_and(|callables| {
|
||||
callables.has_relation_to_impl(
|
||||
db,
|
||||
|
|
@ -2499,7 +2538,26 @@ impl<'db> Type<'db> {
|
|||
}
|
||||
|
||||
(_, Type::ProtocolInstance(protocol)) => {
|
||||
relation_visitor.visit((self, target, relation), || {
|
||||
// For protocol-to-protocol comparisons, use ClassLiteral keys to detect
|
||||
// infinite recursion with recursive generic protocols (e.g., `class C[T](Protocol): a: C[set[T]]`).
|
||||
// When both sides are class-based protocols, the full types may differ due to
|
||||
// different specializations, but comparing them would lead to infinite recursion.
|
||||
let (self_key, target_key) = if let Type::ProtocolInstance(self_protocol) = self {
|
||||
// Both are protocol instances - try to use class literals as keys
|
||||
// for detecting cycles in recursive generic protocols
|
||||
match (self_protocol.class_literal(db), protocol.class_literal(db)) {
|
||||
(Some(self_class), Some(target_class)) => (
|
||||
TypeRelationKey::ProtocolClass(self_class),
|
||||
TypeRelationKey::ProtocolClass(target_class),
|
||||
),
|
||||
// One or both are synthesized protocols - fall back to full types
|
||||
_ => (self.into(), target.into()),
|
||||
}
|
||||
} else {
|
||||
// Source is not a protocol - use full types
|
||||
(self.into(), target.into())
|
||||
};
|
||||
relation_visitor.visit((self_key, target_key, relation), || {
|
||||
self.satisfies_protocol(
|
||||
db,
|
||||
protocol,
|
||||
|
|
@ -2515,7 +2573,7 @@ impl<'db> Type<'db> {
|
|||
(Type::ProtocolInstance(_), _) => ConstraintSet::from(false),
|
||||
|
||||
(Type::TypedDict(self_typeddict), Type::TypedDict(other_typeddict)) => relation_visitor
|
||||
.visit((self, target, relation), || {
|
||||
.visit((self.into(), target.into(), relation), || {
|
||||
self_typeddict.has_relation_to_impl(
|
||||
db,
|
||||
other_typeddict,
|
||||
|
|
@ -2530,18 +2588,23 @@ impl<'db> Type<'db> {
|
|||
// compatible `Mapping`s. `extra_items` could also allow for some assignments to `dict`, as
|
||||
// long as `total=False`. (But then again, does anyone want a non-total `TypedDict` where all
|
||||
// key types are a supertype of the extra items type?)
|
||||
(Type::TypedDict(_), _) => relation_visitor.visit((self, target, relation), || {
|
||||
KnownClass::Mapping
|
||||
.to_specialized_instance(db, [KnownClass::Str.to_instance(db), Type::object()])
|
||||
.has_relation_to_impl(
|
||||
db,
|
||||
target,
|
||||
inferable,
|
||||
relation,
|
||||
relation_visitor,
|
||||
disjointness_visitor,
|
||||
)
|
||||
}),
|
||||
(Type::TypedDict(_), _) => {
|
||||
relation_visitor.visit((self.into(), target.into(), relation), || {
|
||||
KnownClass::Mapping
|
||||
.to_specialized_instance(
|
||||
db,
|
||||
[KnownClass::Str.to_instance(db), Type::object()],
|
||||
)
|
||||
.has_relation_to_impl(
|
||||
db,
|
||||
target,
|
||||
inferable,
|
||||
relation,
|
||||
relation_visitor,
|
||||
disjointness_visitor,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// A non-`TypedDict` cannot subtype a `TypedDict`
|
||||
(_, Type::TypedDict(_)) => ConstraintSet::from(false),
|
||||
|
|
@ -2841,7 +2904,7 @@ impl<'db> Type<'db> {
|
|||
// `bool` is a subtype of `int`, because `bool` subclasses `int`,
|
||||
// which means that all instances of `bool` are also instances of `int`
|
||||
(Type::NominalInstance(self_instance), Type::NominalInstance(target_instance)) => {
|
||||
relation_visitor.visit((self, target, relation), || {
|
||||
relation_visitor.visit((self.into(), target.into(), relation), || {
|
||||
self_instance.has_relation_to_impl(
|
||||
db,
|
||||
target_instance,
|
||||
|
|
|
|||
|
|
@ -659,6 +659,16 @@ impl<'db> ProtocolInstanceType<'db> {
|
|||
}
|
||||
}
|
||||
|
||||
/// If this is a class-based protocol, return its class literal (without specialization).
|
||||
///
|
||||
/// Returns `None` for synthesized protocols that don't correspond to a class definition.
|
||||
pub(super) fn class_literal(self, db: &'db dyn Db) -> Option<ClassLiteral<'db>> {
|
||||
match self.inner {
|
||||
Protocol::FromClass(class) => Some(class.class_literal(db).0),
|
||||
Protocol::Synthesized(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the meta-type of this protocol-instance type.
|
||||
pub(super) fn to_meta_type(self, db: &'db dyn Db) -> Type<'db> {
|
||||
match self.inner {
|
||||
|
|
|
|||
Loading…
Reference in New Issue