mirror of https://github.com/astral-sh/ruff
Fix stack overflow with recursive generic protocols (depth limit) (#21858)
## Summary This fixes https://github.com/astral-sh/ty/issues/1736 where recursive generic protocols with growing specializations caused a stack overflow. The issue occurred with protocols like: ```python class C[T](Protocol): a: 'C[set[T]]' ``` When checking `C[set[int]]` against e.g. `C[Unknown]`, member `a` requires checking `C[set[set[int]]]`, which requires `C[set[set[set[int]]]]`, etc. Each level has different type specializations, so the existing cycle detection (using full types as cache keys) didn't catch the infinite recursion. This fix adds a simple recursion depth limit (64) to the CycleDetector. When the depth exceeds the limit, we return the fallback value (assume compatible) to safely terminate the recursion. This is a bit of a blunt hammer, but it should be broadly effective to prevent stack overflow in any nested-relation case, and it's hard to imagine that non-recursive nested relation comparisons of depth > 64 exist much in the wild. ## Test Plan Added mdtest.
This commit is contained in:
parent
4e4d018344
commit
8727a7b179
|
|
@ -3010,6 +3010,31 @@ class Bar(Protocol[S]):
|
|||
z: S | Bar[S]
|
||||
```
|
||||
|
||||
### Recursive generic protocols with growing specializations
|
||||
|
||||
This snippet caused a stack overflow in <https://github.com/astral-sh/ty/issues/1736> because the
|
||||
type parameter grows with each recursive call (`C[set[T]]` leads to `C[set[set[T]]]`, then
|
||||
`C[set[set[set[T]]]]`, etc.):
|
||||
|
||||
```toml
|
||||
[environment]
|
||||
python-version = "3.12"
|
||||
```
|
||||
|
||||
```py
|
||||
from typing import Protocol
|
||||
|
||||
class C[T](Protocol):
|
||||
a: "C[set[T]]"
|
||||
|
||||
def takes_c(c: C[set[int]]) -> None: ...
|
||||
def f(c: C[int]) -> None:
|
||||
# The key thing is that we don't stack overflow while checking this.
|
||||
# The cycle detection assumes compatibility when it detects potential
|
||||
# infinite recursion between protocol specializations.
|
||||
takes_c(c)
|
||||
```
|
||||
|
||||
### Recursive legacy generic protocol
|
||||
|
||||
```py
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@
|
|||
//! of the Rust types implementing protocols also call `visitor.visit`. The best way to avoid this
|
||||
//! is to prefer always calling `visitor.visit` only in the main recursive method on `Type`.
|
||||
|
||||
use std::cell::RefCell;
|
||||
use std::cell::{Cell, RefCell};
|
||||
use std::cmp::Eq;
|
||||
use std::hash::Hash;
|
||||
use std::marker::PhantomData;
|
||||
|
|
@ -29,6 +29,22 @@ use rustc_hash::FxHashMap;
|
|||
use crate::FxIndexSet;
|
||||
use crate::types::Type;
|
||||
|
||||
/// Maximum recursion depth for cycle detection.
|
||||
///
|
||||
/// This is a safety limit to prevent stack overflow when checking recursive generic protocols
|
||||
/// that create infinitely growing type specializations. For example:
|
||||
///
|
||||
/// ```python
|
||||
/// class C[T](Protocol):
|
||||
/// a: 'C[set[T]]'
|
||||
/// ```
|
||||
///
|
||||
/// When checking `C[set[int]]` against e.g. `C[Unknown]`, member `a` requires checking
|
||||
/// `C[set[set[int]]]`, which in turn requires checking `C[set[set[set[int]]]]`, etc. Each level
|
||||
/// creates a unique cache key, so the standard cycle detection doesn't catch it. The depth limit
|
||||
/// ensures we bail out before hitting a stack overflow.
|
||||
const MAX_RECURSION_DEPTH: u32 = 64;
|
||||
|
||||
pub(crate) type TypeTransformer<'db, Tag> = CycleDetector<Tag, Type<'db>, Type<'db>>;
|
||||
|
||||
impl<Tag> Default for TypeTransformer<'_, Tag> {
|
||||
|
|
@ -58,6 +74,10 @@ pub struct CycleDetector<Tag, T, R> {
|
|||
/// sort-of defeat the point of a cache if we did!)
|
||||
cache: RefCell<FxHashMap<T, R>>,
|
||||
|
||||
/// Current recursion depth. Used to prevent stack overflow if recursive generic types create
|
||||
/// infinitely growing type specializations that don't trigger exact-match cycle detection.
|
||||
depth: Cell<u32>,
|
||||
|
||||
fallback: R,
|
||||
|
||||
_tag: PhantomData<Tag>,
|
||||
|
|
@ -68,6 +88,7 @@ impl<Tag, T: Hash + Eq + Clone, R: Clone> CycleDetector<Tag, T, R> {
|
|||
CycleDetector {
|
||||
seen: RefCell::new(FxIndexSet::default()),
|
||||
cache: RefCell::new(FxHashMap::default()),
|
||||
depth: Cell::new(0),
|
||||
fallback,
|
||||
_tag: PhantomData,
|
||||
}
|
||||
|
|
@ -83,7 +104,18 @@ impl<Tag, T: Hash + Eq + Clone, R: Clone> CycleDetector<Tag, T, R> {
|
|||
return self.fallback.clone();
|
||||
}
|
||||
|
||||
// Check depth limit to prevent stack overflow from recursive generic types
|
||||
// with growing specializations (e.g., C[set[T]] -> C[set[set[T]]] -> ...)
|
||||
let current_depth = self.depth.get();
|
||||
if current_depth >= MAX_RECURSION_DEPTH {
|
||||
self.seen.borrow_mut().pop();
|
||||
return self.fallback.clone();
|
||||
}
|
||||
self.depth.set(current_depth + 1);
|
||||
|
||||
let ret = func();
|
||||
|
||||
self.depth.set(current_depth);
|
||||
self.seen.borrow_mut().pop();
|
||||
self.cache.borrow_mut().insert(item, ret.clone());
|
||||
|
||||
|
|
@ -100,7 +132,18 @@ impl<Tag, T: Hash + Eq + Clone, R: Clone> CycleDetector<Tag, T, R> {
|
|||
return Some(self.fallback.clone());
|
||||
}
|
||||
|
||||
// Check depth limit to prevent stack overflow from recursive generic protocols
|
||||
// with growing specializations (e.g., C[set[T]] -> C[set[set[T]]] -> ...)
|
||||
let current_depth = self.depth.get();
|
||||
if current_depth >= MAX_RECURSION_DEPTH {
|
||||
self.seen.borrow_mut().pop();
|
||||
return Some(self.fallback.clone());
|
||||
}
|
||||
self.depth.set(current_depth + 1);
|
||||
|
||||
let ret = func()?;
|
||||
|
||||
self.depth.set(current_depth);
|
||||
self.seen.borrow_mut().pop();
|
||||
self.cache.borrow_mut().insert(item, ret.clone());
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue