diff --git a/crates/ty_python_semantic/resources/mdtest/literal/collections/generator_expressions.md b/crates/ty_python_semantic/resources/mdtest/literal/collections/generator_expressions.md index 1e266d92a8..bd2ccf14b0 100644 --- a/crates/ty_python_semantic/resources/mdtest/literal/collections/generator_expressions.md +++ b/crates/ty_python_semantic/resources/mdtest/literal/collections/generator_expressions.md @@ -1,6 +1,62 @@ # Generator expressions +## Basic + +We infer specialized `GeneratorType` instance types for generator expressions: + ```py -# revealed: GeneratorType[@Todo(generator expression yield type), @Todo(generator expression send type), @Todo(generator expression return type)] -reveal_type((x for x in range(42))) +# revealed: GeneratorType[int, None, None] +reveal_type(x for x in range(10)) + +# revealed: GeneratorType[tuple[int, str], None, None] +reveal_type((x, str(y)) for x in range(3) for y in range(3)) +``` + +When used in a loop, the yielded type can be inferred: + +```py +squares = (x**2 for x in range(10)) + +for s in squares: + reveal_type(s) # revealed: int +``` + +`GeneratorType` is covariant in its yielded type, so it can be used where a wider yielded type is +expected: + +```py +from typing import Iterator + +def process_numbers(x: Iterator[float]): ... + +numbers = (x for x in range(10)) +reveal_type(numbers) # revealed: GeneratorType[int, None, None] +process_numbers(numbers) +``` + +## Async generators + +For async generator expressions, we infer specialized `AsyncGeneratorType` instance types: + +```py +import asyncio +from typing import AsyncGenerator + +async def slow_numbers() -> AsyncGenerator[int, None]: + current = 0 + while True: + await asyncio.sleep(1) + yield current + current += 1 + +async def main() -> None: + slow_squares = (x**2 async for x in slow_numbers()) + + reveal_type(slow_squares) # revealed: AsyncGeneratorType[int, None] + + async for s in slow_squares: + reveal_type(s) # revealed: int + print(s) + +asyncio.run(main()) ``` diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 05c6193060..3d573fa27a 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -7392,33 +7392,51 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } /// Infer the type of the `iter` expression of the first comprehension. - fn infer_first_comprehension_iter(&mut self, comprehensions: &[ast::Comprehension]) { + /// Returns the evaluation mode (async or sync) of the comprehension. + fn infer_first_comprehension_iter( + &mut self, + comprehensions: &[ast::Comprehension], + ) -> EvaluationMode { let mut comprehensions_iter = comprehensions.iter(); let Some(first_comprehension) = comprehensions_iter.next() else { unreachable!("Comprehension must contain at least one generator"); }; self.infer_standalone_expression(&first_comprehension.iter, TypeContext::default()); + + if first_comprehension.is_async { + EvaluationMode::Async + } else { + EvaluationMode::Sync + } } fn infer_generator_expression(&mut self, generator: &ast::ExprGenerator) -> Type<'db> { let ast::ExprGenerator { range: _, node_index: _, - elt: _, + elt, generators, parenthesized: _, } = generator; - self.infer_first_comprehension_iter(generators); + let evaluation_mode = self.infer_first_comprehension_iter(generators); - KnownClass::GeneratorType.to_specialized_instance( - self.db(), - [ - todo_type!("generator expression yield type"), - todo_type!("generator expression send type"), - todo_type!("generator expression return type"), - ], - ) + let scope_id = self + .index + .node_scope(NodeWithScopeRef::GeneratorExpression(generator)); + let scope = scope_id.to_scope_id(self.db(), self.file()); + let inference = infer_scope_types(self.db(), scope); + let yield_type = inference.expression_type(elt.as_ref()); + + if evaluation_mode.is_async() { + KnownClass::AsyncGeneratorType + .to_specialized_instance(self.db(), [yield_type, Type::none(self.db())]) + } else { + KnownClass::GeneratorType.to_specialized_instance( + self.db(), + [yield_type, Type::none(self.db()), Type::none(self.db())], + ) + } } /// Return a specialization of the collection class (list, dict, set) based on the type context and the inferred