mirror of https://github.com/astral-sh/ruff
Fix stack overflow with recursive generic protocols (depth limit) (#21858)
## Summary This fixes https://github.com/astral-sh/ty/issues/1736 where recursive generic protocols with growing specializations caused a stack overflow. The issue occurred with protocols like: ```python class C[T](Protocol): a: 'C[set[T]]' ``` When checking `C[set[int]]` against e.g. `C[Unknown]`, member `a` requires checking `C[set[set[int]]]`, which requires `C[set[set[set[int]]]]`, etc. Each level has different type specializations, so the existing cycle detection (using full types as cache keys) didn't catch the infinite recursion. This fix adds a simple recursion depth limit (64) to the CycleDetector. When the depth exceeds the limit, we return the fallback value (assume compatible) to safely terminate the recursion. This is a bit of a blunt hammer, but it should be broadly effective to prevent stack overflow in any nested-relation case, and it's hard to imagine that non-recursive nested relation comparisons of depth > 64 exist much in the wild. ## Test Plan Added mdtest.
This commit is contained in:
parent
4e4d018344
commit
8727a7b179
|
|
@ -3010,6 +3010,31 @@ class Bar(Protocol[S]):
|
||||||
z: S | Bar[S]
|
z: S | Bar[S]
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Recursive generic protocols with growing specializations
|
||||||
|
|
||||||
|
This snippet caused a stack overflow in <https://github.com/astral-sh/ty/issues/1736> because the
|
||||||
|
type parameter grows with each recursive call (`C[set[T]]` leads to `C[set[set[T]]]`, then
|
||||||
|
`C[set[set[set[T]]]]`, etc.):
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[environment]
|
||||||
|
python-version = "3.12"
|
||||||
|
```
|
||||||
|
|
||||||
|
```py
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
class C[T](Protocol):
|
||||||
|
a: "C[set[T]]"
|
||||||
|
|
||||||
|
def takes_c(c: C[set[int]]) -> None: ...
|
||||||
|
def f(c: C[int]) -> None:
|
||||||
|
# The key thing is that we don't stack overflow while checking this.
|
||||||
|
# The cycle detection assumes compatibility when it detects potential
|
||||||
|
# infinite recursion between protocol specializations.
|
||||||
|
takes_c(c)
|
||||||
|
```
|
||||||
|
|
||||||
### Recursive legacy generic protocol
|
### Recursive legacy generic protocol
|
||||||
|
|
||||||
```py
|
```py
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,7 @@
|
||||||
//! of the Rust types implementing protocols also call `visitor.visit`. The best way to avoid this
|
//! of the Rust types implementing protocols also call `visitor.visit`. The best way to avoid this
|
||||||
//! is to prefer always calling `visitor.visit` only in the main recursive method on `Type`.
|
//! is to prefer always calling `visitor.visit` only in the main recursive method on `Type`.
|
||||||
|
|
||||||
use std::cell::RefCell;
|
use std::cell::{Cell, RefCell};
|
||||||
use std::cmp::Eq;
|
use std::cmp::Eq;
|
||||||
use std::hash::Hash;
|
use std::hash::Hash;
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
@ -29,6 +29,22 @@ use rustc_hash::FxHashMap;
|
||||||
use crate::FxIndexSet;
|
use crate::FxIndexSet;
|
||||||
use crate::types::Type;
|
use crate::types::Type;
|
||||||
|
|
||||||
|
/// Maximum recursion depth for cycle detection.
|
||||||
|
///
|
||||||
|
/// This is a safety limit to prevent stack overflow when checking recursive generic protocols
|
||||||
|
/// that create infinitely growing type specializations. For example:
|
||||||
|
///
|
||||||
|
/// ```python
|
||||||
|
/// class C[T](Protocol):
|
||||||
|
/// a: 'C[set[T]]'
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// When checking `C[set[int]]` against e.g. `C[Unknown]`, member `a` requires checking
|
||||||
|
/// `C[set[set[int]]]`, which in turn requires checking `C[set[set[set[int]]]]`, etc. Each level
|
||||||
|
/// creates a unique cache key, so the standard cycle detection doesn't catch it. The depth limit
|
||||||
|
/// ensures we bail out before hitting a stack overflow.
|
||||||
|
const MAX_RECURSION_DEPTH: u32 = 64;
|
||||||
|
|
||||||
pub(crate) type TypeTransformer<'db, Tag> = CycleDetector<Tag, Type<'db>, Type<'db>>;
|
pub(crate) type TypeTransformer<'db, Tag> = CycleDetector<Tag, Type<'db>, Type<'db>>;
|
||||||
|
|
||||||
impl<Tag> Default for TypeTransformer<'_, Tag> {
|
impl<Tag> Default for TypeTransformer<'_, Tag> {
|
||||||
|
|
@ -58,6 +74,10 @@ pub struct CycleDetector<Tag, T, R> {
|
||||||
/// sort-of defeat the point of a cache if we did!)
|
/// sort-of defeat the point of a cache if we did!)
|
||||||
cache: RefCell<FxHashMap<T, R>>,
|
cache: RefCell<FxHashMap<T, R>>,
|
||||||
|
|
||||||
|
/// Current recursion depth. Used to prevent stack overflow if recursive generic types create
|
||||||
|
/// infinitely growing type specializations that don't trigger exact-match cycle detection.
|
||||||
|
depth: Cell<u32>,
|
||||||
|
|
||||||
fallback: R,
|
fallback: R,
|
||||||
|
|
||||||
_tag: PhantomData<Tag>,
|
_tag: PhantomData<Tag>,
|
||||||
|
|
@ -68,6 +88,7 @@ impl<Tag, T: Hash + Eq + Clone, R: Clone> CycleDetector<Tag, T, R> {
|
||||||
CycleDetector {
|
CycleDetector {
|
||||||
seen: RefCell::new(FxIndexSet::default()),
|
seen: RefCell::new(FxIndexSet::default()),
|
||||||
cache: RefCell::new(FxHashMap::default()),
|
cache: RefCell::new(FxHashMap::default()),
|
||||||
|
depth: Cell::new(0),
|
||||||
fallback,
|
fallback,
|
||||||
_tag: PhantomData,
|
_tag: PhantomData,
|
||||||
}
|
}
|
||||||
|
|
@ -83,7 +104,18 @@ impl<Tag, T: Hash + Eq + Clone, R: Clone> CycleDetector<Tag, T, R> {
|
||||||
return self.fallback.clone();
|
return self.fallback.clone();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check depth limit to prevent stack overflow from recursive generic types
|
||||||
|
// with growing specializations (e.g., C[set[T]] -> C[set[set[T]]] -> ...)
|
||||||
|
let current_depth = self.depth.get();
|
||||||
|
if current_depth >= MAX_RECURSION_DEPTH {
|
||||||
|
self.seen.borrow_mut().pop();
|
||||||
|
return self.fallback.clone();
|
||||||
|
}
|
||||||
|
self.depth.set(current_depth + 1);
|
||||||
|
|
||||||
let ret = func();
|
let ret = func();
|
||||||
|
|
||||||
|
self.depth.set(current_depth);
|
||||||
self.seen.borrow_mut().pop();
|
self.seen.borrow_mut().pop();
|
||||||
self.cache.borrow_mut().insert(item, ret.clone());
|
self.cache.borrow_mut().insert(item, ret.clone());
|
||||||
|
|
||||||
|
|
@ -100,7 +132,18 @@ impl<Tag, T: Hash + Eq + Clone, R: Clone> CycleDetector<Tag, T, R> {
|
||||||
return Some(self.fallback.clone());
|
return Some(self.fallback.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check depth limit to prevent stack overflow from recursive generic protocols
|
||||||
|
// with growing specializations (e.g., C[set[T]] -> C[set[set[T]]] -> ...)
|
||||||
|
let current_depth = self.depth.get();
|
||||||
|
if current_depth >= MAX_RECURSION_DEPTH {
|
||||||
|
self.seen.borrow_mut().pop();
|
||||||
|
return Some(self.fallback.clone());
|
||||||
|
}
|
||||||
|
self.depth.set(current_depth + 1);
|
||||||
|
|
||||||
let ret = func()?;
|
let ret = func()?;
|
||||||
|
|
||||||
|
self.depth.set(current_depth);
|
||||||
self.seen.borrow_mut().pop();
|
self.seen.borrow_mut().pop();
|
||||||
self.cache.borrow_mut().insert(item, ret.clone());
|
self.cache.borrow_mut().insert(item, ret.clone());
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue