From a9be810c383f6de00e2c4ad0b4456f6db2bd3f09 Mon Sep 17 00:00:00 2001 From: David Peter Date: Tue, 9 Dec 2025 22:49:00 +0100 Subject: [PATCH 1/2] [ty] Type inference for `@asynccontextmanager` (#21876) ## Summary This PR adds special handling for `asynccontextmanager` calls as a temporary solution for https://github.com/astral-sh/ty/issues/1804. We will be able to remove this soon once we have support for generic protocols in the solver. closes https://github.com/astral-sh/ty/issues/1804 ## Ecosystem ```diff + tests/test_downloadermiddleware.py:305:56: error[invalid-argument-type] Argument to bound method `download` is incorrect: Expected `Spider`, found `Unknown | Spider | None` + tests/test_downloadermiddleware.py:305:56: warning[possibly-missing-attribute] Attribute `spider` may be missing on object of type `Crawler | None` ``` These look like true positives ```diff + pymongo/asynchronous/database.py:1021:35: error[invalid-assignment] Object of type `(AsyncClientSession & ~AlwaysTruthy & ~AlwaysFalsy) | (_ServerMode & ~AlwaysFalsy) | Unknown | Primary` is not assignable to `_ServerMode | None` + pymongo/asynchronous/database.py:1025:17: error[invalid-argument-type] Argument to bound method `_conn_for_reads` is incorrect: Expected `_ServerMode`, found `_ServerMode | None` ``` Known problems or true positives, just caused by the new type for `session` ```diff - src/integrations/prefect-sqlalchemy/prefect_sqlalchemy/database.py:269:16: error[invalid-return-type] Return type does not match returned value: expected `Connection | AsyncConnection`, found `_GeneratorContextManager[Unknown, None, None] | _AsyncGeneratorContextManager[Unknown, None] | Connection | AsyncConnection` + src/integrations/prefect-sqlalchemy/prefect_sqlalchemy/database.py:269:16: error[invalid-return-type] Return type does not match returned value: expected `Connection | AsyncConnection`, found `_GeneratorContextManager[Unknown, None, None] | _AsyncGeneratorContextManager[AsyncConnection, None] | Connection | AsyncConnection` ``` Just a more concrete type ```diff - src/prefect/flow_engine.py:1277:24: error[missing-argument] No argument provided for required parameter `cls` - src/prefect/server/api/server.py:696:49: error[missing-argument] No argument provided for required parameter `cls` - src/prefect/task_engine.py:1426:24: error[missing-argument] No argument provided for required parameter `cls` ``` Good ## Test Plan * Adapted and newly added Markdown tests * Tested on internal codebase --- .../resources/mdtest/external/sqlmodel.md | 6 +- .../resources/mdtest/with/async.md | 41 +++++++++-- .../src/module_resolver/module.rs | 2 + .../ty_python_semantic/src/types/call/bind.rs | 70 ++++++++++++++++++- .../ty_python_semantic/src/types/function.rs | 9 +++ scripts/check_ecosystem.py | 2 +- 6 files changed, 122 insertions(+), 8 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/external/sqlmodel.md b/crates/ty_python_semantic/resources/mdtest/external/sqlmodel.md index 7dafa336db..54ab9012c2 100644 --- a/crates/ty_python_semantic/resources/mdtest/external/sqlmodel.md +++ b/crates/ty_python_semantic/resources/mdtest/external/sqlmodel.md @@ -6,7 +6,11 @@ python-version = "3.13" python-platform = "linux" [project] -dependencies = ["sqlmodel==0.0.27"] +dependencies = [ + "sqlmodel==0.0.27", + # TODO: remove this pin, once we have a lockfile + "sqlalchemy==2.0.44" +] ``` ## Basic model diff --git a/crates/ty_python_semantic/resources/mdtest/with/async.md b/crates/ty_python_semantic/resources/mdtest/with/async.md index 2a0d7165de..9802c85c4e 100644 --- a/crates/ty_python_semantic/resources/mdtest/with/async.md +++ b/crates/ty_python_semantic/resources/mdtest/with/async.md @@ -212,13 +212,46 @@ class Session: ... async def connect() -> AsyncGenerator[Session]: yield Session() -# TODO: this should be `() -> _AsyncGeneratorContextManager[Session, None]` -reveal_type(connect) # revealed: () -> _AsyncGeneratorContextManager[Unknown, None] +# revealed: () -> _AsyncGeneratorContextManager[Session, None] +reveal_type(connect) async def main(): async with connect() as session: - # TODO: should be `Session` - reveal_type(session) # revealed: Unknown + reveal_type(session) # revealed: Session +``` + +This also works with `AsyncIterator` return types: + +```py +from typing import AsyncIterator + +@asynccontextmanager +async def connect_iterator() -> AsyncIterator[Session]: + yield Session() + +# revealed: () -> _AsyncGeneratorContextManager[Session, None] +reveal_type(connect_iterator) + +async def main_iterator(): + async with connect_iterator() as session: + reveal_type(session) # revealed: Session +``` + +And with `AsyncGeneratorType` return types: + +```py +from types import AsyncGeneratorType + +@asynccontextmanager +async def connect_async_generator() -> AsyncGeneratorType[Session]: + yield Session() + +# revealed: () -> _AsyncGeneratorContextManager[Session, None] +reveal_type(connect_async_generator) + +async def main_async_generator(): + async with connect_async_generator() as session: + reveal_type(session) # revealed: Session ``` ## `asyncio.timeout` diff --git a/crates/ty_python_semantic/src/module_resolver/module.rs b/crates/ty_python_semantic/src/module_resolver/module.rs index 118c2aff45..31c3ba6eef 100644 --- a/crates/ty_python_semantic/src/module_resolver/module.rs +++ b/crates/ty_python_semantic/src/module_resolver/module.rs @@ -319,6 +319,7 @@ pub enum KnownModule { Tempfile, Pathlib, Abc, + Contextlib, Dataclasses, Collections, Inspect, @@ -351,6 +352,7 @@ impl KnownModule { Self::Tempfile => "tempfile", Self::Pathlib => "pathlib", Self::Abc => "abc", + Self::Contextlib => "contextlib", Self::Dataclasses => "dataclasses", Self::Collections => "collections", Self::Inspect => "inspect", diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 72a4818578..29b176ec8a 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -24,7 +24,8 @@ use super::{Argument, CallArguments, CallError, CallErrorKind, InferContext, Sig use crate::Program; use crate::db::Db; use crate::dunder_all::dunder_all_names; -use crate::place::{Definedness, Place}; +use crate::module_resolver::KnownModule; +use crate::place::{Definedness, Place, known_module_symbol}; use crate::types::call::arguments::{Expansion, is_expandable_type}; use crate::types::constraints::ConstraintSet; use crate::types::diagnostic::{ @@ -43,13 +44,14 @@ use crate::types::generics::{ use crate::types::signatures::{Parameter, ParameterForm, ParameterKind, Parameters}; use crate::types::tuple::{TupleLength, TupleSpec, TupleType}; use crate::types::{ - BoundMethodType, BoundTypeVarIdentity, BoundTypeVarInstance, CallableSignature, + BoundMethodType, BoundTypeVarIdentity, BoundTypeVarInstance, CallableSignature, CallableType, CallableTypeKind, ClassLiteral, DATACLASS_FLAGS, DataclassFlags, DataclassParams, FieldInstance, KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy, NominalInstanceType, PropertyInstanceType, SpecialFormType, TrackedConstraintSet, TypeAliasType, TypeContext, TypeVarVariance, UnionBuilder, UnionType, WrapperDescriptorKind, enums, list_members, todo_type, }; +use crate::unpack::EvaluationMode; use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity}; use ruff_python_ast::{self as ast, ArgOrKeyword, PythonVersion}; @@ -941,6 +943,18 @@ impl<'db> Bindings<'db> { } } + // TODO: Remove this special handling once we have full support for + // generic protocols in the solver. + Some(KnownFunction::AsyncContextManager) => { + if let [Some(callable)] = overload.parameter_types() { + if let Some(return_ty) = + asynccontextmanager_return_type(db, *callable) + { + overload.set_return_type(return_ty); + } + } + } + Some(KnownFunction::IsProtocol) => { if let [Some(ty)] = overload.parameter_types() { // We evaluate this to `Literal[True]` only if the runtime function `typing.is_protocol` @@ -4622,3 +4636,55 @@ impl fmt::Display for FunctionKind { // An example of a routine with many many overloads: // https://github.com/henribru/google-api-python-client-stubs/blob/master/googleapiclient-stubs/discovery.pyi const MAXIMUM_OVERLOADS: usize = 50; + +/// Infer the return type for a call to `asynccontextmanager`. +/// +/// The `@asynccontextmanager` decorator transforms a function that returns (a subtype of) `AsyncIterator[T]` +/// into a function that returns `_AsyncGeneratorContextManager[T]`. +/// +/// TODO: This function only handles the most basic case. It should be removed once we have +/// full support for generic protocols in the solver. +fn asynccontextmanager_return_type<'db>(db: &'db dyn Db, func_ty: Type<'db>) -> Option> { + let bindings = func_ty.bindings(db); + let binding = bindings + .single_element()? + .overloads + .iter() + .exactly_one() + .ok()?; + let signature = &binding.signature; + let return_ty = signature.return_ty?; + + let yield_ty = return_ty + .try_iterate_with_mode(db, EvaluationMode::Async) + .ok()? + .homogeneous_element_type(db); + + if yield_ty.is_divergent() + || signature + .parameters() + .iter() + .any(|param| param.annotated_type().is_some_and(|ty| ty.is_divergent())) + { + return Some(yield_ty); + } + + let context_manager = + known_module_symbol(db, KnownModule::Contextlib, "_AsyncGeneratorContextManager") + .place + .ignore_possibly_undefined()? + .as_class_literal()?; + + let context_manager = context_manager.apply_specialization(db, |generic_context| { + generic_context.specialize_partial(db, [Some(yield_ty), None]) + }); + + let new_return_ty = Type::from(context_manager).to_instance(db)?; + let new_signature = Signature::new(signature.parameters().clone(), Some(new_return_ty)); + + Some(Type::Callable(CallableType::new( + db, + CallableSignature::single(new_signature), + CallableTypeKind::FunctionLike, + ))) +} diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index 6b6c798615..dae46bca03 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -1315,6 +1315,10 @@ pub enum KnownFunction { #[strum(serialize = "abstractmethod")] AbstractMethod, + /// `contextlib.asynccontextmanager` + #[strum(serialize = "asynccontextmanager")] + AsyncContextManager, + /// `dataclasses.dataclass` Dataclass, /// `dataclasses.field` @@ -1402,6 +1406,9 @@ impl KnownFunction { Self::AbstractMethod => { matches!(module, KnownModule::Abc) } + Self::AsyncContextManager => { + matches!(module, KnownModule::Contextlib) + } Self::Dataclass | Self::Field => { matches!(module, KnownModule::Dataclasses) } @@ -1926,6 +1933,8 @@ pub(crate) mod tests { KnownFunction::AbstractMethod => KnownModule::Abc, + KnownFunction::AsyncContextManager => KnownModule::Contextlib, + KnownFunction::Dataclass | KnownFunction::Field => KnownModule::Dataclasses, KnownFunction::GetattrStatic => KnownModule::Inspect, diff --git a/scripts/check_ecosystem.py b/scripts/check_ecosystem.py index fe2e73d9a1..8bc2dbea55 100755 --- a/scripts/check_ecosystem.py +++ b/scripts/check_ecosystem.py @@ -47,7 +47,7 @@ class Repository(NamedTuple): show_fixes: bool = False @asynccontextmanager - async def clone(self: Self, checkout_dir: Path) -> AsyncIterator[Path]: + async def clone(self: Self, checkout_dir: Path) -> AsyncIterator[str]: """Shallow clone this repository to a temporary directory.""" if checkout_dir.exists(): logger.debug(f"Reusing {self.org}:{self.repo}") From f3714fd3c12b2338854b423ae7a5676c13d7becc Mon Sep 17 00:00:00 2001 From: Brent Westbrook <36778786+ntBre@users.noreply.github.com> Date: Tue, 9 Dec 2025 18:15:12 -0500 Subject: [PATCH 2/2] Fix leading comment formatting for lambdas with multiple parameters (#21879) ## Summary This is a follow-up to #21868. As soon as I started merging #21868 into #21385, I realized that I had missed a test case with `**kwargs` after the `*args` parameter. Such a case is supposed to be formatted on one line like: ```py # input ( lambda # comment *x, **y: x ) # output ( lambda # comment *x, **y: x ) ``` which you can still see on the [playground](https://play.ruff.rs/bd88d339-1358-40d2-819f-865bfcb23aef?secondary=Format), but on `main` after #21868, this was formatted as: ```py ( lambda # comment *x, **y: x ) ``` because the leading comment on the first parameter caused the whole group around the parameters to break. Instead of making these comments leading comments on the first parameter, this PR makes them leading comments on the parameters list as a whole. ## Test Plan New tests, and I will also try merging this into #21385 _before_ opening it for review this time.
(labeling `internal` since #21868 should not be released before some kind of fix) --- .../test/fixtures/ruff/expression/lambda.py | 22 ++++++++++ .../src/comments/placement.rs | 21 +++++++--- .../src/expression/expr_lambda.rs | 10 ++--- .../format@expression__lambda.py.snap | 40 +++++++++++++++++++ 4 files changed, 81 insertions(+), 12 deletions(-) diff --git a/crates/ruff_python_formatter/resources/test/fixtures/ruff/expression/lambda.py b/crates/ruff_python_formatter/resources/test/fixtures/ruff/expression/lambda.py index 660d5644e9..1b1c1ee3c2 100644 --- a/crates/ruff_python_formatter/resources/test/fixtures/ruff/expression/lambda.py +++ b/crates/ruff_python_formatter/resources/test/fixtures/ruff/expression/lambda.py @@ -249,3 +249,25 @@ def a(): x: x ) + +( + lambda + # comment + *x, + **y: x +) + +( + lambda + * # comment 2 + x, + **y: + x +) + +( + lambda + ** # comment 1 + x: + x +) diff --git a/crates/ruff_python_formatter/src/comments/placement.rs b/crates/ruff_python_formatter/src/comments/placement.rs index 28397b6dcf..76449285be 100644 --- a/crates/ruff_python_formatter/src/comments/placement.rs +++ b/crates/ruff_python_formatter/src/comments/placement.rs @@ -871,7 +871,20 @@ fn handle_parameter_comment<'a>( CommentPlacement::Default(comment) } } else if comment.start() < parameter.name.start() { - CommentPlacement::leading(parameter, comment) + // For lambdas, where the parameters cannot be parenthesized and the first parameter thus + // starts at the same position as the parent parameters, mark a comment before the first + // parameter as leading on the parameters rather than the individual parameter to prevent + // the whole parameter list from breaking. + // + // Note that this check is not needed above because lambda parameters cannot have + // annotations. + if let Some(AnyNodeRef::Parameters(parameters)) = comment.enclosing_parent() + && parameters.start() == parameter.start() + { + CommentPlacement::leading(parameters, comment) + } else { + CommentPlacement::leading(parameter, comment) + } } else { CommentPlacement::Default(comment) } @@ -1835,10 +1848,8 @@ fn handle_lambda_comment<'a>( // ) // ``` if comment.start() < parameters.start() { - return if let Some(first) = parameters.iter().next() - && comment.line_position().is_own_line() - { - CommentPlacement::leading(first.as_parameter(), comment) + return if comment.line_position().is_own_line() { + CommentPlacement::leading(parameters, comment) } else { CommentPlacement::dangling(comment.enclosing_node(), comment) }; diff --git a/crates/ruff_python_formatter/src/expression/expr_lambda.rs b/crates/ruff_python_formatter/src/expression/expr_lambda.rs index 335f112323..f91666ecf7 100644 --- a/crates/ruff_python_formatter/src/expression/expr_lambda.rs +++ b/crates/ruff_python_formatter/src/expression/expr_lambda.rs @@ -32,8 +32,8 @@ impl FormatNodeRule for FormatExprLambda { .split_at(dangling.partition_point(|comment| comment.end() < parameters.start())); if dangling_before_parameters.is_empty() { - // If the first parameter has a leading comment, insert a hard line break. This - // comment is associated as a leading comment on the first parameter: + // If the parameters have a leading comment, insert a hard line break. This + // comment is associated as a leading comment on the parameters: // // ```py // ( @@ -86,11 +86,7 @@ impl FormatNodeRule for FormatExprLambda { // *x: x // ) // ``` - if parameters - .iter() - .next() - .is_some_and(|parameter| comments.has_leading(parameter.as_parameter())) - { + if comments.has_leading(&**parameters) { hard_line_break().fmt(f)?; } else { write!(f, [space()])?; diff --git a/crates/ruff_python_formatter/tests/snapshots/format@expression__lambda.py.snap b/crates/ruff_python_formatter/tests/snapshots/format@expression__lambda.py.snap index 3009dfaefc..5997ff539a 100644 --- a/crates/ruff_python_formatter/tests/snapshots/format@expression__lambda.py.snap +++ b/crates/ruff_python_formatter/tests/snapshots/format@expression__lambda.py.snap @@ -255,6 +255,28 @@ def a(): x: x ) + +( + lambda + # comment + *x, + **y: x +) + +( + lambda + * # comment 2 + x, + **y: + x +) + +( + lambda + ** # comment 1 + x: + x +) ``` ## Output @@ -513,4 +535,22 @@ def a(): # comment 2 *x: x ) + +( + lambda + # comment + *x, **y: x +) + +( + lambda + # comment 2 + *x, **y: x +) + +( + lambda + # comment 1 + **x: x +) ```