mirror of https://github.com/astral-sh/ruff
Fix stack overflow with recursive generic protocols
This fixes https://github.com/astral-sh/ty/issues/1736, where recursive generic protocols with growing specializations caused a stack overflow. The issue occurred with protocols like:

```python
class C[T](Protocol):
    a: 'C[set[T]]'
```

When checking `C[set[int]]` against `C[Unknown]`, member `a` requires checking `C[set[set[int]]]`, which in turn requires `C[set[set[set[int]]]]`, and so on. Each level has a different type specialization, so the existing cycle detection (which used full types as cache keys) didn't catch the infinite recursion.

The fix introduces `TypeRelationKey`, an enum that can be either a full `Type` or a `ClassLiteral` (a protocol class without its specialization). For protocol-to-protocol comparisons, we use `ClassLiteral` keys, which detect when we're comparing the same protocol class regardless of specialization. When a cycle is detected, we return the fallback value (assume compatible) to safely terminate the recursion.
This commit is contained in:
parent
4e67a219bb
commit
c88e1e40ab
|
|
@ -3010,6 +3010,31 @@ class Bar(Protocol[S]):
|
||||||
z: S | Bar[S]
|
z: S | Bar[S]
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Recursive generic protocols with growing specializations
|
||||||
|
|
||||||
|
This snippet caused a stack overflow in <https://github.com/astral-sh/ty/issues/1736> because the
|
||||||
|
type parameter grows with each recursive call (`C[set[T]]` leads to `C[set[set[T]]]`, then
|
||||||
|
`C[set[set[set[T]]]]`, etc.):
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[environment]
|
||||||
|
python-version = "3.12"
|
||||||
|
```
|
||||||
|
|
||||||
|
```py
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
class C[T](Protocol):
|
||||||
|
a: "C[set[T]]"
|
||||||
|
|
||||||
|
def takes_c(c: C[set[int]]) -> None: ...
|
||||||
|
def f(c: C[int]) -> None:
|
||||||
|
# The key thing is that we don't stack overflow while checking this.
|
||||||
|
# The cycle detection assumes compatibility when it detects potential
|
||||||
|
# infinite recursion between protocol specializations.
|
||||||
|
takes_c(c)
|
||||||
|
```
|
||||||
|
|
||||||
### Recursive legacy generic protocol
|
### Recursive legacy generic protocol
|
||||||
|
|
||||||
```py
|
```py
|
||||||
|
|
|
||||||
|
|
@ -204,9 +204,48 @@ fn definition_expression_type<'db>(
|
||||||
/// A [`TypeTransformer`] that is used in `apply_type_mapping` methods.
|
/// A [`TypeTransformer`] that is used in `apply_type_mapping` methods.
|
||||||
pub(crate) type ApplyTypeMappingVisitor<'db> = TypeTransformer<'db, TypeMapping<'db, 'db>>;
|
pub(crate) type ApplyTypeMappingVisitor<'db> = TypeTransformer<'db, TypeMapping<'db, 'db>>;
|
||||||
|
|
||||||
/// A [`PairVisitor`] that is used in `has_relation_to` methods.
|
/// Key type for the `has_relation_to` visitor.
|
||||||
pub(crate) type HasRelationToVisitor<'db> =
|
///
|
||||||
CycleDetector<TypeRelation<'db>, (Type<'db>, Type<'db>, TypeRelation<'db>), ConstraintSet<'db>>;
|
/// For most type comparisons, we use the full `Type` as the key. However, for protocol-to-protocol
|
||||||
|
/// comparisons, we use the underlying `ClassLiteral` (ignoring specialization) to detect infinite
|
||||||
|
/// recursion that occurs with recursive generic protocols.
|
||||||
|
///
|
||||||
|
/// For example, with:
|
||||||
|
/// ```python
|
||||||
|
/// class C[T](Protocol):
|
||||||
|
/// a: 'C[set[T]]'
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// Checking `C[set[int]] <: C[set[int]]` leads to checking `C[set[set[int]]] <: C[set[set[int]]]`,
|
||||||
|
/// then `C[set[set[set[int]]]] <: C[set[set[set[int]]]]`, etc. Each level has different type
|
||||||
|
/// specializations, so using full types as keys doesn't detect the cycle. By using `ClassLiteral`
|
||||||
|
/// as the key for protocol comparisons, we detect that we're comparing protocol `C` against itself
|
||||||
|
/// regardless of specialization.
|
||||||
|
// NOTE(review): `PartialEq`/`Eq`/`Hash` are presumably required because this key is embedded in
// the `(key, key, relation)` tuple that `HasRelationToVisitor` uses for cycle detection — confirm
// against the bounds on `CycleDetector`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
|
pub(crate) enum TypeRelationKey<'db> {
|
||||||
|
/// A regular type - used for most comparisons.
/// This is the variant produced by the `From<Type>` conversion (i.e. `ty.into()`).
|
||||||
|
Type(Type<'db>),
|
||||||
|
/// A protocol class literal (without specialization) - used for protocol-to-protocol comparisons
|
||||||
|
/// to detect recursive generic protocols.
/// Two specializations of the same protocol class map to the same key, which is what lets
/// the growing `C[set[...]]` chain be recognised as a cycle.
|
||||||
|
ProtocolClass(ClassLiteral<'db>),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Wraps a full `Type` in the `TypeRelationKey::Type` variant, so that call sites which key on
/// the complete type can build their key with `.into()`.
impl<'db> From<Type<'db>> for TypeRelationKey<'db> {
|
||||||
|
fn from(ty: Type<'db>) -> Self {
|
||||||
|
// Always the full-type variant; `ProtocolClass` keys are never created through this impl.
TypeRelationKey::Type(ty)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A [`CycleDetector`] that is used in `has_relation_to` methods.
|
||||||
|
pub(crate) type HasRelationToVisitor<'db> = CycleDetector<
|
||||||
|
TypeRelation<'db>,
|
||||||
|
// Key handed to `visit`: (source key, target key, relation being checked).
(
|
||||||
|
TypeRelationKey<'db>,
|
||||||
|
TypeRelationKey<'db>,
|
||||||
|
TypeRelation<'db>,
|
||||||
|
),
|
||||||
|
// NOTE(review): presumably the result type of the closure passed to `visit` (the wrapped
// relation checks yield `ConstraintSet` values) — confirm against `CycleDetector`.
ConstraintSet<'db>,
|
||||||
|
>;
|
||||||
|
|
||||||
impl Default for HasRelationToVisitor<'_> {
|
impl Default for HasRelationToVisitor<'_> {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
|
|
@ -1973,7 +2012,7 @@ impl<'db> Type<'db> {
|
||||||
}
|
}
|
||||||
|
|
||||||
(Type::TypeAlias(self_alias), _) => {
|
(Type::TypeAlias(self_alias), _) => {
|
||||||
relation_visitor.visit((self, target, relation), || {
|
relation_visitor.visit((self.into(), target.into(), relation), || {
|
||||||
self_alias.value_type(db).has_relation_to_impl(
|
self_alias.value_type(db).has_relation_to_impl(
|
||||||
db,
|
db,
|
||||||
target,
|
target,
|
||||||
|
|
@ -1986,7 +2025,7 @@ impl<'db> Type<'db> {
|
||||||
}
|
}
|
||||||
|
|
||||||
(_, Type::TypeAlias(target_alias)) => {
|
(_, Type::TypeAlias(target_alias)) => {
|
||||||
relation_visitor.visit((self, target, relation), || {
|
relation_visitor.visit((self.into(), target.into(), relation), || {
|
||||||
self.has_relation_to_impl(
|
self.has_relation_to_impl(
|
||||||
db,
|
db,
|
||||||
target_alias.value_type(db),
|
target_alias.value_type(db),
|
||||||
|
|
@ -2452,7 +2491,7 @@ impl<'db> Type<'db> {
|
||||||
) => ConstraintSet::from(false),
|
) => ConstraintSet::from(false),
|
||||||
|
|
||||||
(Type::Callable(self_callable), Type::Callable(other_callable)) => relation_visitor
|
(Type::Callable(self_callable), Type::Callable(other_callable)) => relation_visitor
|
||||||
.visit((self, target, relation), || {
|
.visit((self.into(), target.into(), relation), || {
|
||||||
self_callable.has_relation_to_impl(
|
self_callable.has_relation_to_impl(
|
||||||
db,
|
db,
|
||||||
other_callable,
|
other_callable,
|
||||||
|
|
@ -2464,7 +2503,7 @@ impl<'db> Type<'db> {
|
||||||
}),
|
}),
|
||||||
|
|
||||||
(_, Type::Callable(other_callable)) => {
|
(_, Type::Callable(other_callable)) => {
|
||||||
relation_visitor.visit((self, target, relation), || {
|
relation_visitor.visit((self.into(), target.into(), relation), || {
|
||||||
self.try_upcast_to_callable(db).when_some_and(|callables| {
|
self.try_upcast_to_callable(db).when_some_and(|callables| {
|
||||||
callables.has_relation_to_impl(
|
callables.has_relation_to_impl(
|
||||||
db,
|
db,
|
||||||
|
|
@ -2499,7 +2538,26 @@ impl<'db> Type<'db> {
|
||||||
}
|
}
|
||||||
|
|
||||||
(_, Type::ProtocolInstance(protocol)) => {
|
(_, Type::ProtocolInstance(protocol)) => {
|
||||||
relation_visitor.visit((self, target, relation), || {
|
// For protocol-to-protocol comparisons, use ClassLiteral keys to detect
|
||||||
|
// infinite recursion with recursive generic protocols (e.g., `class C[T](Protocol): a: C[set[T]]`).
|
||||||
|
// When both types are protocols of the same class, the types may differ due to
|
||||||
|
// different specializations, but comparing them would lead to infinite recursion.
|
||||||
|
let (self_key, target_key) = if let Type::ProtocolInstance(self_protocol) = self {
|
||||||
|
// Both are protocol instances - try to use class literals as keys
|
||||||
|
// for detecting cycles in recursive generic protocols
|
||||||
|
match (self_protocol.class_literal(db), protocol.class_literal(db)) {
|
||||||
|
(Some(self_class), Some(target_class)) => (
|
||||||
|
TypeRelationKey::ProtocolClass(self_class),
|
||||||
|
TypeRelationKey::ProtocolClass(target_class),
|
||||||
|
),
|
||||||
|
// One or both are synthesized protocols - fall back to full types
|
||||||
|
_ => (self.into(), target.into()),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Source is not a protocol - use full types
|
||||||
|
(self.into(), target.into())
|
||||||
|
};
|
||||||
|
relation_visitor.visit((self_key, target_key, relation), || {
|
||||||
self.satisfies_protocol(
|
self.satisfies_protocol(
|
||||||
db,
|
db,
|
||||||
protocol,
|
protocol,
|
||||||
|
|
@ -2515,7 +2573,7 @@ impl<'db> Type<'db> {
|
||||||
(Type::ProtocolInstance(_), _) => ConstraintSet::from(false),
|
(Type::ProtocolInstance(_), _) => ConstraintSet::from(false),
|
||||||
|
|
||||||
(Type::TypedDict(self_typeddict), Type::TypedDict(other_typeddict)) => relation_visitor
|
(Type::TypedDict(self_typeddict), Type::TypedDict(other_typeddict)) => relation_visitor
|
||||||
.visit((self, target, relation), || {
|
.visit((self.into(), target.into(), relation), || {
|
||||||
self_typeddict.has_relation_to_impl(
|
self_typeddict.has_relation_to_impl(
|
||||||
db,
|
db,
|
||||||
other_typeddict,
|
other_typeddict,
|
||||||
|
|
@ -2530,18 +2588,23 @@ impl<'db> Type<'db> {
|
||||||
// compatible `Mapping`s. `extra_items` could also allow for some assignments to `dict`, as
|
// compatible `Mapping`s. `extra_items` could also allow for some assignments to `dict`, as
|
||||||
// long as `total=False`. (But then again, does anyone want a non-total `TypedDict` where all
|
// long as `total=False`. (But then again, does anyone want a non-total `TypedDict` where all
|
||||||
// key types are a supertype of the extra items type?)
|
// key types are a supertype of the extra items type?)
|
||||||
(Type::TypedDict(_), _) => relation_visitor.visit((self, target, relation), || {
|
(Type::TypedDict(_), _) => {
|
||||||
KnownClass::Mapping
|
relation_visitor.visit((self.into(), target.into(), relation), || {
|
||||||
.to_specialized_instance(db, [KnownClass::Str.to_instance(db), Type::object()])
|
KnownClass::Mapping
|
||||||
.has_relation_to_impl(
|
.to_specialized_instance(
|
||||||
db,
|
db,
|
||||||
target,
|
[KnownClass::Str.to_instance(db), Type::object()],
|
||||||
inferable,
|
)
|
||||||
relation,
|
.has_relation_to_impl(
|
||||||
relation_visitor,
|
db,
|
||||||
disjointness_visitor,
|
target,
|
||||||
)
|
inferable,
|
||||||
}),
|
relation,
|
||||||
|
relation_visitor,
|
||||||
|
disjointness_visitor,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// A non-`TypedDict` cannot subtype a `TypedDict`
|
// A non-`TypedDict` cannot subtype a `TypedDict`
|
||||||
(_, Type::TypedDict(_)) => ConstraintSet::from(false),
|
(_, Type::TypedDict(_)) => ConstraintSet::from(false),
|
||||||
|
|
@ -2841,7 +2904,7 @@ impl<'db> Type<'db> {
|
||||||
// `bool` is a subtype of `int`, because `bool` subclasses `int`,
|
// `bool` is a subtype of `int`, because `bool` subclasses `int`,
|
||||||
// which means that all instances of `bool` are also instances of `int`
|
// which means that all instances of `bool` are also instances of `int`
|
||||||
(Type::NominalInstance(self_instance), Type::NominalInstance(target_instance)) => {
|
(Type::NominalInstance(self_instance), Type::NominalInstance(target_instance)) => {
|
||||||
relation_visitor.visit((self, target, relation), || {
|
relation_visitor.visit((self.into(), target.into(), relation), || {
|
||||||
self_instance.has_relation_to_impl(
|
self_instance.has_relation_to_impl(
|
||||||
db,
|
db,
|
||||||
target_instance,
|
target_instance,
|
||||||
|
|
|
||||||
|
|
@ -659,6 +659,16 @@ impl<'db> ProtocolInstanceType<'db> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// If this is a class-based protocol, return its class literal (without specialization).
|
||||||
|
///
|
||||||
|
/// Returns `None` for synthesized protocols that don't correspond to a class definition.
|
||||||
|
pub(super) fn class_literal(self, db: &'db dyn Db) -> Option<ClassLiteral<'db>> {
|
||||||
|
match self.inner {
|
||||||
|
// Class-based protocol: keep only the first element of what `class_literal(db)` returns.
// NOTE(review): the discarded second element is presumably the specialization — confirm.
Protocol::FromClass(class) => Some(class.class_literal(db).0),
|
||||||
|
// Synthesized protocols have no class definition to key on.
Protocol::Synthesized(_) => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Return the meta-type of this protocol-instance type.
|
/// Return the meta-type of this protocol-instance type.
|
||||||
pub(super) fn to_meta_type(self, db: &'db dyn Db) -> Type<'db> {
|
pub(super) fn to_meta_type(self, db: &'db dyn Db) -> Type<'db> {
|
||||||
match self.inner {
|
match self.inner {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue