[ty] Respect intersections in iterations (#21965)

## Summary

This PR implements the strategy described in
https://github.com/astral-sh/ty/issues/1871: we iterate over the
positive types, resolve them, then intersect the results.
This commit is contained in:
Charlie Marsh 2025-12-19 12:36:37 -05:00 committed by GitHub
parent b63b3c13fb
commit 0f18a08a0a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 221 additions and 3 deletions

View File

@ -337,6 +337,130 @@ for x in Test():
reveal_type(x) # revealed: int
```
## Intersection type via isinstance narrowing
When we have an intersection type via `isinstance` narrowing, we should be able to infer the
iterable element type precisely:
```py
from typing import Sequence
def _(x: Sequence[int], y: object):
reveal_type(x) # revealed: Sequence[int]
for item in x:
reveal_type(item) # revealed: int
if isinstance(y, list):
reveal_type(y) # revealed: Top[list[Unknown]]
for item in y:
reveal_type(item) # revealed: object
if isinstance(x, list):
reveal_type(x) # revealed: Sequence[int] & Top[list[Unknown]]
for item in x:
# int & object simplifies to int
reveal_type(item) # revealed: int
```
## Intersection where some elements are not iterable
When iterating over an intersection type, we should only fail if all positive elements fail to
iterate. If some elements are iterable and some are not, we should iterate over the iterable ones
and intersect their element types.
```py
from ty_extensions import Intersection
class NotIterable:
pass
def _(x: Intersection[list[int], NotIterable]):
# `list[int]` is iterable (yielding `int`), but `NotIterable` is not.
# We should still be able to iterate over the intersection.
for item in x:
reveal_type(item) # revealed: int
```
## Intersection where all elements are not iterable
When iterating over an intersection type where all positive elements are not iterable, we should
fail to iterate.
```py
from ty_extensions import Intersection
class NotIterable1:
pass
class NotIterable2:
pass
def _(x: Intersection[NotIterable1, NotIterable2]):
# error: [not-iterable]
for item in x:
reveal_type(item) # revealed: Unknown
```
## Intersection of fixed-length tuples
When iterating over an intersection of two fixed-length tuples with the same length, we should
intersect the element types position-by-position.
```py
from ty_extensions import Intersection
def _(x: Intersection[tuple[int, str], tuple[object, object]]):
# `tuple[int, str]` yields `int | str` when iterated.
# `tuple[object, object]` yields `object` when iterated.
# The intersection should yield `(int & object) | (str & object)` = `int | str`.
for item in x:
reveal_type(item) # revealed: int | str
```
## Intersection of fixed-length tuple with homogeneous iterable
When iterating over an intersection of a fixed-length tuple with a class that implements `__iter__`
returning a homogeneous iterator, we should preserve the fixed-length structure and intersect each
element type with the iterator's element type.
```py
from collections.abc import Iterator
class Foo:
def __iter__(self) -> Iterator[object]:
raise NotImplementedError
def _(x: tuple[int, str, bytes]):
if isinstance(x, Foo):
# The intersection `tuple[int, str, bytes] & Foo` should iterate as
# `tuple[int & object, str & object, bytes & object]` = `tuple[int, str, bytes]`
a, b, c = x
reveal_type(a) # revealed: int
reveal_type(b) # revealed: str
reveal_type(c) # revealed: bytes
reveal_type(tuple(x)) # revealed: tuple[int, str, bytes]
```
## Intersection of homogeneous iterables
When iterating over an intersection of two types that both yield homogeneous variable-length tuple
specs, we should intersect their element types.
```py
from collections.abc import Iterator
class Foo:
def __iter__(self) -> Iterator[object]:
raise NotImplementedError
def _(x: list[int]):
if isinstance(x, Foo):
# `list[int]` yields `int`, `Foo` yields `object`.
# The intersection should yield `int & object` = `int`.
for item in x:
reveal_type(item) # revealed: int
```
## Possibly-not-callable `__iter__` method
```py

View File

@ -6718,6 +6718,25 @@ impl<'db> Type<'db> {
None
}
}
Type::Intersection(intersection) => {
// For intersections, we iterate over each positive element and intersect
// the resulting element types. Negative elements don't affect iteration.
// We only fail if all elements fail to iterate; as long as at least one
// element can be iterated over, we can produce a result.
let mut specs_iter = intersection
.positive_elements_or_object(db)
.filter_map(|element| {
element
.try_iterate_with_mode(db, EvaluationMode::Sync)
.ok()
});
let first_spec = specs_iter.next()?;
let mut builder = TupleSpecBuilder::from(&*first_spec);
for spec in specs_iter {
builder = builder.intersect(db, &spec);
}
Some(Cow::Owned(builder.build()))
}
// N.B. These special cases aren't strictly necessary, they're just obvious optimizations
Type::LiteralString | Type::Dynamic(_) => Some(Cow::Owned(TupleSpec::homogeneous(ty))),
@ -6740,7 +6759,6 @@ impl<'db> Type<'db> {
| Type::SpecialForm(_)
| Type::KnownInstance(_)
| Type::PropertyInstance(_)
| Type::Intersection(_)
| Type::AlwaysTruthy
| Type::AlwaysFalsy
| Type::IntLiteral(_)

View File

@ -29,8 +29,8 @@ use crate::types::constraints::{ConstraintSet, IteratorConstraintsExtension};
use crate::types::generics::InferableTypeVars;
use crate::types::{
ApplyTypeMappingVisitor, BoundTypeVarInstance, FindLegacyTypeVarsVisitor, HasRelationToVisitor,
IsDisjointVisitor, IsEquivalentVisitor, NormalizedVisitor, Type, TypeMapping, TypeRelation,
UnionBuilder, UnionType,
IntersectionType, IsDisjointVisitor, IsEquivalentVisitor, NormalizedVisitor, Type, TypeMapping,
TypeRelation, UnionBuilder, UnionType,
};
use crate::types::{Truthiness, TypeContext};
use crate::{Db, FxOrderSet, Program};
@ -1649,6 +1649,7 @@ pub(crate) enum ResizeTupleError {
}
/// A builder for creating a new [`TupleSpec`]
#[derive(Clone)]
pub(crate) enum TupleSpecBuilder<'db> {
Fixed(Vec<Type<'db>>),
Variable {
@ -1787,6 +1788,81 @@ impl<'db> TupleSpecBuilder<'db> {
}
}
/// Return a new tuple-spec builder that reflects the intersection of this tuple and another tuple.
///
/// For example, if `self` is a tuple-spec builder for `tuple[int, str]` and `other` is a
/// tuple-spec for `tuple[object, object]`, the result will be a tuple-spec builder for
/// `tuple[int, str]` (since `int & object` simplifies to `int`, and `str & object` to `str`).
pub(crate) fn intersect(mut self, db: &'db dyn Db, other: &TupleSpec<'db>) -> Self {
match (&mut self, other) {
// Both fixed-length with the same length: element-wise intersection.
(TupleSpecBuilder::Fixed(our_elements), TupleSpec::Fixed(new_elements))
if our_elements.len() == new_elements.len() =>
{
for (existing, new) in our_elements.iter_mut().zip(new_elements.elements()) {
*existing = IntersectionType::from_elements(db, [*existing, *new]);
}
return self;
}
(TupleSpecBuilder::Fixed(our_elements), TupleSpec::Variable(var)) => {
if let Ok(tuple) = var.resize(db, TupleLength::Fixed(our_elements.len())) {
return self.intersect(db, &tuple);
}
}
(TupleSpecBuilder::Variable { .. }, TupleSpec::Fixed(fixed)) => {
if let Ok(tuple) = self
.clone()
.build()
.resize(db, TupleLength::Fixed(fixed.len()))
{
return TupleSpecBuilder::from(&tuple).intersect(db, other);
}
}
(
TupleSpecBuilder::Variable {
prefix,
variable,
suffix,
},
TupleSpec::Variable(var),
) => {
if prefix.len() == var.prefix.len() && suffix.len() == var.suffix.len() {
for (existing, new) in prefix.iter_mut().zip(var.prefix_elements()) {
*existing = IntersectionType::from_elements(db, [*existing, *new]);
}
*variable = IntersectionType::from_elements(db, [*variable, var.variable]);
for (existing, new) in suffix.iter_mut().zip(var.suffix_elements()) {
*existing = IntersectionType::from_elements(db, [*existing, *new]);
}
return self;
}
let self_built = self.clone().build();
let self_len = self_built.len();
if let Ok(resized) = var.resize(db, self_len) {
return self.intersect(db, &resized);
} else if let Ok(resized) = self_built.resize(db, var.len()) {
return TupleSpecBuilder::from(&resized).intersect(db, other);
}
}
_ => {}
}
// TODO: probably incorrect? `tuple[int, str] & tuple[int, str, bytes]` should resolve to `Never`.
// So maybe this function should be fallible (return an `Option`)?
let intersected =
IntersectionType::from_elements(db, self.all_elements().chain(other.all_elements()));
TupleSpecBuilder::Variable {
prefix: vec![],
variable: intersected,
suffix: vec![],
}
}
pub(super) fn build(self) -> TupleSpec<'db> {
match self {
TupleSpecBuilder::Fixed(elements) => {