Fix stack overflow with recursive generic protocols (depth limit) (#21858)

## Summary

This fixes https://github.com/astral-sh/ty/issues/1736 where recursive
generic protocols with growing specializations caused a stack overflow.

The issue occurred with protocols like:
```python
class C[T](Protocol):
    a: 'C[set[T]]'
```

When checking `C[set[int]]` against e.g. `C[Unknown]`, member `a`
requires checking `C[set[set[int]]]`, which requires
`C[set[set[set[int]]]]`, etc. Each level has different type
specializations, so the existing cycle detection (using full types as
cache keys) didn't catch the infinite recursion.

This fix adds a simple recursion depth limit (64) to the CycleDetector.
When the depth exceeds the limit, we return the fallback value (assume
compatible) to safely terminate the recursion.

This is a blunt instrument, but it should be broadly effective at preventing
stack overflow in any nested-relation case, and it is hard to imagine that
non-recursive nested-relation comparisons deeper than 64 levels occur often
in real-world code.

## Test Plan

Added mdtest.
This commit is contained in:
Carl Meyer 2025-12-09 09:05:18 -08:00 committed by GitHub
parent 4e4d018344
commit 8727a7b179
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 69 additions and 1 deletions

View File

@ -3010,6 +3010,31 @@ class Bar(Protocol[S]):
z: S | Bar[S] z: S | Bar[S]
``` ```
### Recursive generic protocols with growing specializations
This snippet caused a stack overflow in <https://github.com/astral-sh/ty/issues/1736> because the
type parameter grows with each recursive call (`C[set[T]]` leads to `C[set[set[T]]]`, then
`C[set[set[set[T]]]]`, etc.):
```toml
[environment]
python-version = "3.12"
```
```py
from typing import Protocol
class C[T](Protocol):
a: "C[set[T]]"
def takes_c(c: C[set[int]]) -> None: ...
def f(c: C[int]) -> None:
# The key thing is that we don't stack overflow while checking this.
# The cycle detection assumes compatibility when it detects potential
# infinite recursion between protocol specializations.
takes_c(c)
```
### Recursive legacy generic protocol ### Recursive legacy generic protocol
```py ```py

View File

@ -19,7 +19,7 @@
//! of the Rust types implementing protocols also call `visitor.visit`. The best way to avoid this //! of the Rust types implementing protocols also call `visitor.visit`. The best way to avoid this
//! is to prefer always calling `visitor.visit` only in the main recursive method on `Type`. //! is to prefer always calling `visitor.visit` only in the main recursive method on `Type`.
use std::cell::RefCell; use std::cell::{Cell, RefCell};
use std::cmp::Eq; use std::cmp::Eq;
use std::hash::Hash; use std::hash::Hash;
use std::marker::PhantomData; use std::marker::PhantomData;
@ -29,6 +29,22 @@ use rustc_hash::FxHashMap;
use crate::FxIndexSet; use crate::FxIndexSet;
use crate::types::Type; use crate::types::Type;
/// Upper bound on the nesting depth that the cycle detector will follow.
///
/// Acts as a last-resort guard against stack overflow when recursive generic
/// protocols produce an ever-growing chain of distinct specializations, e.g.:
///
/// ```python
/// class C[T](Protocol):
///     a: 'C[set[T]]'
/// ```
///
/// Comparing `C[set[int]]` with e.g. `C[Unknown]` forces a check of member `a`,
/// i.e. of `C[set[set[int]]]`, then `C[set[set[set[int]]]]`, and so on. Every
/// step yields a brand-new cache key, so exact-match cycle detection never
/// fires; once this depth is exceeded we bail out instead of recursing further.
const MAX_RECURSION_DEPTH: u32 = 64;
pub(crate) type TypeTransformer<'db, Tag> = CycleDetector<Tag, Type<'db>, Type<'db>>; pub(crate) type TypeTransformer<'db, Tag> = CycleDetector<Tag, Type<'db>, Type<'db>>;
impl<Tag> Default for TypeTransformer<'_, Tag> { impl<Tag> Default for TypeTransformer<'_, Tag> {
@ -58,6 +74,10 @@ pub struct CycleDetector<Tag, T, R> {
/// sort-of defeat the point of a cache if we did!) /// sort-of defeat the point of a cache if we did!)
cache: RefCell<FxHashMap<T, R>>, cache: RefCell<FxHashMap<T, R>>,
/// Current recursion depth. Used to prevent stack overflow if recursive generic types create
/// infinitely growing type specializations that don't trigger exact-match cycle detection.
depth: Cell<u32>,
fallback: R, fallback: R,
_tag: PhantomData<Tag>, _tag: PhantomData<Tag>,
@ -68,6 +88,7 @@ impl<Tag, T: Hash + Eq + Clone, R: Clone> CycleDetector<Tag, T, R> {
CycleDetector { CycleDetector {
seen: RefCell::new(FxIndexSet::default()), seen: RefCell::new(FxIndexSet::default()),
cache: RefCell::new(FxHashMap::default()), cache: RefCell::new(FxHashMap::default()),
depth: Cell::new(0),
fallback, fallback,
_tag: PhantomData, _tag: PhantomData,
} }
@ -83,7 +104,18 @@ impl<Tag, T: Hash + Eq + Clone, R: Clone> CycleDetector<Tag, T, R> {
return self.fallback.clone(); return self.fallback.clone();
} }
// Check depth limit to prevent stack overflow from recursive generic types
// with growing specializations (e.g., C[set[T]] -> C[set[set[T]]] -> ...)
let current_depth = self.depth.get();
if current_depth >= MAX_RECURSION_DEPTH {
self.seen.borrow_mut().pop();
return self.fallback.clone();
}
self.depth.set(current_depth + 1);
let ret = func(); let ret = func();
self.depth.set(current_depth);
self.seen.borrow_mut().pop(); self.seen.borrow_mut().pop();
self.cache.borrow_mut().insert(item, ret.clone()); self.cache.borrow_mut().insert(item, ret.clone());
@ -100,7 +132,18 @@ impl<Tag, T: Hash + Eq + Clone, R: Clone> CycleDetector<Tag, T, R> {
return Some(self.fallback.clone()); return Some(self.fallback.clone());
} }
// Check depth limit to prevent stack overflow from recursive generic protocols
// with growing specializations (e.g., C[set[T]] -> C[set[set[T]]] -> ...)
let current_depth = self.depth.get();
if current_depth >= MAX_RECURSION_DEPTH {
self.seen.borrow_mut().pop();
return Some(self.fallback.clone());
}
self.depth.set(current_depth + 1);
let ret = func()?; let ret = func()?;
self.depth.set(current_depth);
self.seen.borrow_mut().pop(); self.seen.borrow_mut().pop();
self.cache.borrow_mut().insert(item, ret.clone()); self.cache.borrow_mut().insert(item, ret.clone());