From 0f18a08a0ae848c92bc5524cb17077a9968bbc69 Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Fri, 19 Dec 2025 12:36:37 -0500 Subject: [PATCH] [ty] Respect intersections in iterations (#21965) ## Summary This PR implements the strategy described in https://github.com/astral-sh/ty/issues/1871: we iterate over the positive types, resolve them, then intersect the results. --- .../resources/mdtest/loops/for.md | 124 ++++++++++++++++++ crates/ty_python_semantic/src/types.rs | 20 ++- crates/ty_python_semantic/src/types/tuple.rs | 80 ++++++++++- 3 files changed, 221 insertions(+), 3 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/loops/for.md b/crates/ty_python_semantic/resources/mdtest/loops/for.md index ed51e51c56..3916fa884a 100644 --- a/crates/ty_python_semantic/resources/mdtest/loops/for.md +++ b/crates/ty_python_semantic/resources/mdtest/loops/for.md @@ -337,6 +337,130 @@ for x in Test(): reveal_type(x) # revealed: int ``` +## Intersection type via isinstance narrowing + +When we have an intersection type via `isinstance` narrowing, we should be able to infer the +iterable element type precisely: + +```py +from typing import Sequence + +def _(x: Sequence[int], y: object): + reveal_type(x) # revealed: Sequence[int] + for item in x: + reveal_type(item) # revealed: int + + if isinstance(y, list): + reveal_type(y) # revealed: Top[list[Unknown]] + for item in y: + reveal_type(item) # revealed: object + + if isinstance(x, list): + reveal_type(x) # revealed: Sequence[int] & Top[list[Unknown]] + for item in x: + # int & object simplifies to int + reveal_type(item) # revealed: int +``` + +## Intersection where some elements are not iterable + +When iterating over an intersection type, we should only fail if all positive elements fail to +iterate. If some elements are iterable and some are not, we should iterate over the iterable ones +and intersect their element types. + +```py +from ty_extensions import Intersection + +class NotIterable: + pass + +def _(x: Intersection[list[int], NotIterable]): + # `list[int]` is iterable (yielding `int`), but `NotIterable` is not. + # We should still be able to iterate over the intersection. + for item in x: + reveal_type(item) # revealed: int +``` + +## Intersection where all elements are not iterable + +When iterating over an intersection type where all positive elements are not iterable, we should +fail to iterate. + +```py +from ty_extensions import Intersection + +class NotIterable1: + pass + +class NotIterable2: + pass + +def _(x: Intersection[NotIterable1, NotIterable2]): + # error: [not-iterable] + for item in x: + reveal_type(item) # revealed: Unknown +``` + +## Intersection of fixed-length tuples + +When iterating over an intersection of two fixed-length tuples with the same length, we should +intersect the element types position-by-position. + +```py +from ty_extensions import Intersection + +def _(x: Intersection[tuple[int, str], tuple[object, object]]): + # `tuple[int, str]` yields `int | str` when iterated. + # `tuple[object, object]` yields `object` when iterated. + # The intersection should yield `(int & object) | (str & object)` = `int | str`. + for item in x: + reveal_type(item) # revealed: int | str +``` + +## Intersection of fixed-length tuple with homogeneous iterable + +When iterating over an intersection of a fixed-length tuple with a class that implements `__iter__` +returning a homogeneous iterator, we should preserve the fixed-length structure and intersect each +element type with the iterator's element type. + +```py +from collections.abc import Iterator + +class Foo: + def __iter__(self) -> Iterator[object]: + raise NotImplementedError + +def _(x: tuple[int, str, bytes]): + if isinstance(x, Foo): + # The intersection `tuple[int, str, bytes] & Foo` should iterate as + # `tuple[int & object, str & object, bytes & object]` = `tuple[int, str, bytes]` + a, b, c = x + reveal_type(a) # revealed: int + reveal_type(b) # revealed: str + reveal_type(c) # revealed: bytes + reveal_type(tuple(x)) # revealed: tuple[int, str, bytes] +``` + +## Intersection of homogeneous iterables + +When iterating over an intersection of two types that both yield homogeneous variable-length tuple +specs, we should intersect their element types. + +```py +from collections.abc import Iterator + +class Foo: + def __iter__(self) -> Iterator[object]: + raise NotImplementedError + +def _(x: list[int]): + if isinstance(x, Foo): + # `list[int]` yields `int`, `Foo` yields `object`. + # The intersection should yield `int & object` = `int`. + for item in x: + reveal_type(item) # revealed: int +``` + ## Possibly-not-callable `__iter__` method ```py diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index aa70d9bde6..40451a0736 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -6718,6 +6718,25 @@ impl<'db> Type<'db> { None } } + Type::Intersection(intersection) => { + // For intersections, we iterate over each positive element and intersect + // the resulting element types. Negative elements don't affect iteration. + // We only fail if all elements fail to iterate; as long as at least one + // element can be iterated over, we can produce a result. + let mut specs_iter = intersection + .positive_elements_or_object(db) + .filter_map(|element| { + element + .try_iterate_with_mode(db, EvaluationMode::Sync) + .ok() + }); + let first_spec = specs_iter.next()?; + let mut builder = TupleSpecBuilder::from(&*first_spec); + for spec in specs_iter { + builder = builder.intersect(db, &spec); + } + Some(Cow::Owned(builder.build())) + } // N.B. These special cases aren't strictly necessary, they're just obvious optimizations Type::LiteralString | Type::Dynamic(_) => Some(Cow::Owned(TupleSpec::homogeneous(ty))), @@ -6740,7 +6759,6 @@ impl<'db> Type<'db> { | Type::SpecialForm(_) | Type::KnownInstance(_) | Type::PropertyInstance(_) - | Type::Intersection(_) | Type::AlwaysTruthy | Type::AlwaysFalsy | Type::IntLiteral(_) diff --git a/crates/ty_python_semantic/src/types/tuple.rs b/crates/ty_python_semantic/src/types/tuple.rs index f0cc3bedc8..1a84ec0ea3 100644 --- a/crates/ty_python_semantic/src/types/tuple.rs +++ b/crates/ty_python_semantic/src/types/tuple.rs @@ -29,8 +29,8 @@ use crate::types::constraints::{ConstraintSet, IteratorConstraintsExtension}; use crate::types::generics::InferableTypeVars; use crate::types::{ ApplyTypeMappingVisitor, BoundTypeVarInstance, FindLegacyTypeVarsVisitor, HasRelationToVisitor, - IsDisjointVisitor, IsEquivalentVisitor, NormalizedVisitor, Type, TypeMapping, TypeRelation, - UnionBuilder, UnionType, + IntersectionType, IsDisjointVisitor, IsEquivalentVisitor, NormalizedVisitor, Type, TypeMapping, + TypeRelation, UnionBuilder, UnionType, }; use crate::types::{Truthiness, TypeContext}; use crate::{Db, FxOrderSet, Program}; @@ -1649,6 +1649,7 @@ pub(crate) enum ResizeTupleError { } /// A builder for creating a new [`TupleSpec`] +#[derive(Clone)] pub(crate) enum TupleSpecBuilder<'db> { Fixed(Vec>), Variable { @@ -1787,6 +1788,81 @@ impl<'db> TupleSpecBuilder<'db> { } } + /// Return a new tuple-spec builder that reflects the intersection of this tuple and another tuple. + /// + /// For example, if `self` is a tuple-spec builder for `tuple[int, str]` and `other` is a + /// tuple-spec for `tuple[object, object]`, the result will be a tuple-spec builder for + /// `tuple[int, str]` (since `int & object` simplifies to `int`, and `str & object` to `str`). + pub(crate) fn intersect(mut self, db: &'db dyn Db, other: &TupleSpec<'db>) -> Self { + match (&mut self, other) { + // Both fixed-length with the same length: element-wise intersection. + (TupleSpecBuilder::Fixed(our_elements), TupleSpec::Fixed(new_elements)) + if our_elements.len() == new_elements.len() => + { + for (existing, new) in our_elements.iter_mut().zip(new_elements.elements()) { + *existing = IntersectionType::from_elements(db, [*existing, *new]); + } + return self; + } + + (TupleSpecBuilder::Fixed(our_elements), TupleSpec::Variable(var)) => { + if let Ok(tuple) = var.resize(db, TupleLength::Fixed(our_elements.len())) { + return self.intersect(db, &tuple); + } + } + + (TupleSpecBuilder::Variable { .. }, TupleSpec::Fixed(fixed)) => { + if let Ok(tuple) = self + .clone() + .build() + .resize(db, TupleLength::Fixed(fixed.len())) + { + return TupleSpecBuilder::from(&tuple).intersect(db, other); + } + } + + ( + TupleSpecBuilder::Variable { + prefix, + variable, + suffix, + }, + TupleSpec::Variable(var), + ) => { + if prefix.len() == var.prefix.len() && suffix.len() == var.suffix.len() { + for (existing, new) in prefix.iter_mut().zip(var.prefix_elements()) { + *existing = IntersectionType::from_elements(db, [*existing, *new]); + } + *variable = IntersectionType::from_elements(db, [*variable, var.variable]); + for (existing, new) in suffix.iter_mut().zip(var.suffix_elements()) { + *existing = IntersectionType::from_elements(db, [*existing, *new]); + } + return self; + } + + let self_built = self.clone().build(); + let self_len = self_built.len(); + if let Ok(resized) = var.resize(db, self_len) { + return self.intersect(db, &resized); + } else if let Ok(resized) = self_built.resize(db, var.len()) { + return TupleSpecBuilder::from(&resized).intersect(db, other); + } + } + + _ => {} + } + + // TODO: probably incorrect? `tuple[int, str] & tuple[int, str, bytes]` should resolve to `Never`. + // So maybe this function should be fallible (return an `Option`)? + let intersected = + IntersectionType::from_elements(db, self.all_elements().chain(other.all_elements())); + TupleSpecBuilder::Variable { + prefix: vec![], + variable: intersected, + suffix: vec![], + } + } + pub(super) fn build(self) -> TupleSpec<'db> { match self { TupleSpecBuilder::Fixed(elements) => {