diff --git a/crates/ty_python_semantic/resources/mdtest/protocols.md b/crates/ty_python_semantic/resources/mdtest/protocols.md index 28069bd07c..40b180fb35 100644 --- a/crates/ty_python_semantic/resources/mdtest/protocols.md +++ b/crates/ty_python_semantic/resources/mdtest/protocols.md @@ -3010,6 +3010,31 @@ class Bar(Protocol[S]): z: S | Bar[S] ``` +### Recursive generic protocols with growing specializations + +This snippet caused a stack overflow because the +type parameter grows with each recursive call (`C[set[T]]` leads to `C[set[set[T]]]`, then +`C[set[set[set[T]]]]`, etc.): + +```toml +[environment] +python-version = "3.12" +``` + +```py +from typing import Protocol + +class C[T](Protocol): + a: "C[set[T]]" + +def takes_c(c: C[set[int]]) -> None: ... +def f(c: C[int]) -> None: + # The key thing is that we don't stack overflow while checking this. + # The cycle detection assumes compatibility when it detects potential + # infinite recursion between protocol specializations. + takes_c(c) +``` + ### Recursive legacy generic protocol ```py diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index dfc932bc66..bd62d79eca 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -204,9 +204,48 @@ fn definition_expression_type<'db>( /// A [`TypeTransformer`] that is used in `apply_type_mapping` methods. pub(crate) type ApplyTypeMappingVisitor<'db> = TypeTransformer<'db, TypeMapping<'db, 'db>>; -/// A [`PairVisitor`] that is used in `has_relation_to` methods. -pub(crate) type HasRelationToVisitor<'db> = - CycleDetector<TypeRelation<'db>, (Type<'db>, Type<'db>, TypeRelation<'db>), ConstraintSet<'db>>; +/// Key type for the `has_relation_to` visitor. +/// +/// For most type comparisons, we use the full `Type` as the key. 
However, for protocol-to-protocol +/// comparisons, we use the underlying `ClassLiteral` (ignoring specialization) to detect infinite +/// recursion that occurs with recursive generic protocols. +/// +/// For example, with: +/// ```python +/// class C[T](Protocol): +/// a: 'C[set[T]]' +/// ``` +/// +/// Checking `C[set[int]] <: C[set[int]]` leads to checking `C[set[set[int]]] <: C[set[set[int]]]`, +/// then `C[set[set[set[int]]]] <: C[set[set[set[int]]]]`, etc. Each level has different type +/// specializations, so using full types as keys doesn't detect the cycle. By using `ClassLiteral` +/// as the key for protocol comparisons, we detect that we're comparing protocol `C` against itself +/// regardless of specialization. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub(crate) enum TypeRelationKey<'db> { + /// A regular type - used for most comparisons. + Type(Type<'db>), + /// A protocol class literal (without specialization) - used for protocol-to-protocol comparisons + /// to detect recursive generic protocols. + ProtocolClass(ClassLiteral<'db>), +} + +impl<'db> From<Type<'db>> for TypeRelationKey<'db> { + fn from(ty: Type<'db>) -> Self { + TypeRelationKey::Type(ty) + } +} + +/// A [`CycleDetector`] that is used in `has_relation_to` methods. 
+pub(crate) type HasRelationToVisitor<'db> = CycleDetector< + TypeRelation<'db>, + ( + TypeRelationKey<'db>, + TypeRelationKey<'db>, + TypeRelation<'db>, + ), + ConstraintSet<'db>, +>; impl Default for HasRelationToVisitor<'_> { fn default() -> Self { @@ -1973,7 +2012,7 @@ impl<'db> Type<'db> { } (Type::TypeAlias(self_alias), _) => { - relation_visitor.visit((self, target, relation), || { + relation_visitor.visit((self.into(), target.into(), relation), || { self_alias.value_type(db).has_relation_to_impl( db, target, @@ -1986,7 +2025,7 @@ impl<'db> Type<'db> { } (_, Type::TypeAlias(target_alias)) => { - relation_visitor.visit((self, target, relation), || { + relation_visitor.visit((self.into(), target.into(), relation), || { self.has_relation_to_impl( db, target_alias.value_type(db), @@ -2452,7 +2491,7 @@ impl<'db> Type<'db> { ) => ConstraintSet::from(false), (Type::Callable(self_callable), Type::Callable(other_callable)) => relation_visitor - .visit((self, target, relation), || { + .visit((self.into(), target.into(), relation), || { self_callable.has_relation_to_impl( db, other_callable, @@ -2464,7 +2503,7 @@ impl<'db> Type<'db> { }), (_, Type::Callable(other_callable)) => { - relation_visitor.visit((self, target, relation), || { + relation_visitor.visit((self.into(), target.into(), relation), || { self.try_upcast_to_callable(db).when_some_and(|callables| { callables.has_relation_to_impl( db, @@ -2499,7 +2538,26 @@ impl<'db> Type<'db> { } (_, Type::ProtocolInstance(protocol)) => { - relation_visitor.visit((self, target, relation), || { + // For protocol-to-protocol comparisons, use ClassLiteral keys to detect + // infinite recursion with recursive generic protocols (e.g., `class C[T](Protocol): a: C[set[T]]`). + // When both types are protocols of the same class, the types may differ due to + // different specializations, but comparing them would lead to infinite recursion. 
+ let (self_key, target_key) = if let Type::ProtocolInstance(self_protocol) = self { + // Both are protocol instances - try to use class literals as keys + // for detecting cycles in recursive generic protocols + match (self_protocol.class_literal(db), protocol.class_literal(db)) { + (Some(self_class), Some(target_class)) => ( + TypeRelationKey::ProtocolClass(self_class), + TypeRelationKey::ProtocolClass(target_class), + ), + // One or both are synthesized protocols - fall back to full types + _ => (self.into(), target.into()), + } + } else { + // Source is not a protocol - use full types + (self.into(), target.into()) + }; + relation_visitor.visit((self_key, target_key, relation), || { self.satisfies_protocol( db, protocol, @@ -2515,7 +2573,7 @@ impl<'db> Type<'db> { (Type::ProtocolInstance(_), _) => ConstraintSet::from(false), (Type::TypedDict(self_typeddict), Type::TypedDict(other_typeddict)) => relation_visitor - .visit((self, target, relation), || { + .visit((self.into(), target.into(), relation), || { self_typeddict.has_relation_to_impl( db, other_typeddict, @@ -2530,18 +2588,23 @@ impl<'db> Type<'db> { // compatible `Mapping`s. `extra_items` could also allow for some assignments to `dict`, as // long as `total=False`. (But then again, does anyone want a non-total `TypedDict` where all // key types are a supertype of the extra items type?) 
- (Type::TypedDict(_), _) => relation_visitor.visit((self, target, relation), || { - KnownClass::Mapping - .to_specialized_instance(db, [KnownClass::Str.to_instance(db), Type::object()]) - .has_relation_to_impl( - db, - target, - inferable, - relation, - relation_visitor, - disjointness_visitor, - ) - }), + (Type::TypedDict(_), _) => { + relation_visitor.visit((self.into(), target.into(), relation), || { + KnownClass::Mapping + .to_specialized_instance( + db, + [KnownClass::Str.to_instance(db), Type::object()], + ) + .has_relation_to_impl( + db, + target, + inferable, + relation, + relation_visitor, + disjointness_visitor, + ) + }) + } // A non-`TypedDict` cannot subtype a `TypedDict` (_, Type::TypedDict(_)) => ConstraintSet::from(false), @@ -2841,7 +2904,7 @@ impl<'db> Type<'db> { // `bool` is a subtype of `int`, because `bool` subclasses `int`, // which means that all instances of `bool` are also instances of `int` (Type::NominalInstance(self_instance), Type::NominalInstance(target_instance)) => { - relation_visitor.visit((self, target, relation), || { + relation_visitor.visit((self.into(), target.into(), relation), || { self_instance.has_relation_to_impl( db, target_instance, diff --git a/crates/ty_python_semantic/src/types/instance.rs b/crates/ty_python_semantic/src/types/instance.rs index fb53f10ef4..8dcebf2741 100644 --- a/crates/ty_python_semantic/src/types/instance.rs +++ b/crates/ty_python_semantic/src/types/instance.rs @@ -659,6 +659,16 @@ impl<'db> ProtocolInstanceType<'db> { } } + /// If this is a class-based protocol, return its class literal (without specialization). + /// + /// Returns `None` for synthesized protocols that don't correspond to a class definition. + pub(super) fn class_literal(self, db: &'db dyn Db) -> Option<ClassLiteral<'db>> { + match self.inner { + Protocol::FromClass(class) => Some(class.class_literal(db).0), + Protocol::Synthesized(_) => None, + } + } + /// Return the meta-type of this protocol-instance type. 
pub(super) fn to_meta_type(self, db: &'db dyn Db) -> Type<'db> { match self.inner {