diff --git a/crates/ty_python_semantic/resources/mdtest/external/sqlmodel.md b/crates/ty_python_semantic/resources/mdtest/external/sqlmodel.md index 7dafa336db..54ab9012c2 100644 --- a/crates/ty_python_semantic/resources/mdtest/external/sqlmodel.md +++ b/crates/ty_python_semantic/resources/mdtest/external/sqlmodel.md @@ -6,7 +6,11 @@ python-version = "3.13" python-platform = "linux" [project] -dependencies = ["sqlmodel==0.0.27"] +dependencies = [ + "sqlmodel==0.0.27", + # TODO: remove this pin, once we have a lockfile + "sqlalchemy==2.0.44" +] ``` ## Basic model diff --git a/crates/ty_python_semantic/resources/mdtest/with/async.md b/crates/ty_python_semantic/resources/mdtest/with/async.md index 2a0d7165de..9802c85c4e 100644 --- a/crates/ty_python_semantic/resources/mdtest/with/async.md +++ b/crates/ty_python_semantic/resources/mdtest/with/async.md @@ -212,13 +212,46 @@ class Session: ... async def connect() -> AsyncGenerator[Session]: yield Session() -# TODO: this should be `() -> _AsyncGeneratorContextManager[Session, None]` -reveal_type(connect) # revealed: () -> _AsyncGeneratorContextManager[Unknown, None] +# revealed: () -> _AsyncGeneratorContextManager[Session, None] +reveal_type(connect) async def main(): async with connect() as session: - # TODO: should be `Session` - reveal_type(session) # revealed: Unknown + reveal_type(session) # revealed: Session +``` + +This also works with `AsyncIterator` return types: + +```py +from typing import AsyncIterator + +@asynccontextmanager +async def connect_iterator() -> AsyncIterator[Session]: + yield Session() + +# revealed: () -> _AsyncGeneratorContextManager[Session, None] +reveal_type(connect_iterator) + +async def main_iterator(): + async with connect_iterator() as session: + reveal_type(session) # revealed: Session +``` + +And with `AsyncGeneratorType` return types: + +```py +from types import AsyncGeneratorType + +@asynccontextmanager +async def connect_async_generator() -> AsyncGeneratorType[Session]: + yield Session() + +# revealed: () -> _AsyncGeneratorContextManager[Session, None] +reveal_type(connect_async_generator) + +async def main_async_generator(): + async with connect_async_generator() as session: + reveal_type(session) # revealed: Session ``` ## `asyncio.timeout` diff --git a/crates/ty_python_semantic/src/module_resolver/module.rs b/crates/ty_python_semantic/src/module_resolver/module.rs index 118c2aff45..31c3ba6eef 100644 --- a/crates/ty_python_semantic/src/module_resolver/module.rs +++ b/crates/ty_python_semantic/src/module_resolver/module.rs @@ -319,6 +319,7 @@ pub enum KnownModule { Tempfile, Pathlib, Abc, + Contextlib, Dataclasses, Collections, Inspect, @@ -351,6 +352,7 @@ impl KnownModule { Self::Tempfile => "tempfile", Self::Pathlib => "pathlib", Self::Abc => "abc", + Self::Contextlib => "contextlib", Self::Dataclasses => "dataclasses", Self::Collections => "collections", Self::Inspect => "inspect", diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 72a4818578..29b176ec8a 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -24,7 +24,8 @@ use super::{Argument, CallArguments, CallError, CallErrorKind, InferContext, Sig use crate::Program; use crate::db::Db; use crate::dunder_all::dunder_all_names; -use crate::place::{Definedness, Place}; +use crate::module_resolver::KnownModule; +use crate::place::{Definedness, Place, known_module_symbol}; use crate::types::call::arguments::{Expansion, is_expandable_type}; use crate::types::constraints::ConstraintSet; use crate::types::diagnostic::{ @@ -43,13 +44,14 @@ use crate::types::generics::{ use crate::types::signatures::{Parameter, ParameterForm, ParameterKind, Parameters}; use crate::types::tuple::{TupleLength, TupleSpec, TupleType}; use crate::types::{ - BoundMethodType, BoundTypeVarIdentity, BoundTypeVarInstance, CallableSignature, + BoundMethodType, BoundTypeVarIdentity, BoundTypeVarInstance, CallableSignature, CallableType, CallableTypeKind, ClassLiteral, DATACLASS_FLAGS, DataclassFlags, DataclassParams, FieldInstance, KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy, NominalInstanceType, PropertyInstanceType, SpecialFormType, TrackedConstraintSet, TypeAliasType, TypeContext, TypeVarVariance, UnionBuilder, UnionType, WrapperDescriptorKind, enums, list_members, todo_type, }; +use crate::unpack::EvaluationMode; use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity}; use ruff_python_ast::{self as ast, ArgOrKeyword, PythonVersion}; @@ -941,6 +943,18 @@ impl<'db> Bindings<'db> { } } + // TODO: Remove this special handling once we have full support for + // generic protocols in the solver. + Some(KnownFunction::AsyncContextManager) => { + if let [Some(callable)] = overload.parameter_types() { + if let Some(return_ty) = + asynccontextmanager_return_type(db, *callable) + { + overload.set_return_type(return_ty); + } + } + } + Some(KnownFunction::IsProtocol) => { if let [Some(ty)] = overload.parameter_types() { // We evaluate this to `Literal[True]` only if the runtime function `typing.is_protocol` @@ -4622,3 +4636,55 @@ impl fmt::Display for FunctionKind { // An example of a routine with many many overloads: // https://github.com/henribru/google-api-python-client-stubs/blob/master/googleapiclient-stubs/discovery.pyi const MAXIMUM_OVERLOADS: usize = 50; + +/// Infer the return type for a call to `asynccontextmanager`. +/// +/// The `@asynccontextmanager` decorator transforms a function that returns (a subtype of) `AsyncIterator[T]` +/// into a function that returns `_AsyncGeneratorContextManager[T]`. +/// +/// TODO: This function only handles the most basic case. It should be removed once we have +/// full support for generic protocols in the solver. +fn asynccontextmanager_return_type<'db>(db: &'db dyn Db, func_ty: Type<'db>) -> Option> { + let bindings = func_ty.bindings(db); + let binding = bindings + .single_element()? + .overloads + .iter() + .exactly_one() + .ok()?; + let signature = &binding.signature; + let return_ty = signature.return_ty?; + + let yield_ty = return_ty + .try_iterate_with_mode(db, EvaluationMode::Async) + .ok()? + .homogeneous_element_type(db); + + if yield_ty.is_divergent() + || signature + .parameters() + .iter() + .any(|param| param.annotated_type().is_some_and(|ty| ty.is_divergent())) + { + return Some(yield_ty); + } + + let context_manager = + known_module_symbol(db, KnownModule::Contextlib, "_AsyncGeneratorContextManager") + .place + .ignore_possibly_undefined()? + .as_class_literal()?; + + let context_manager = context_manager.apply_specialization(db, |generic_context| { + generic_context.specialize_partial(db, [Some(yield_ty), None]) + }); + + let new_return_ty = Type::from(context_manager).to_instance(db)?; + let new_signature = Signature::new(signature.parameters().clone(), Some(new_return_ty)); + + Some(Type::Callable(CallableType::new( + db, + CallableSignature::single(new_signature), + CallableTypeKind::FunctionLike, + ))) +} diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index 6b6c798615..dae46bca03 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -1315,6 +1315,10 @@ pub enum KnownFunction { #[strum(serialize = "abstractmethod")] AbstractMethod, + /// `contextlib.asynccontextmanager` + #[strum(serialize = "asynccontextmanager")] + AsyncContextManager, + /// `dataclasses.dataclass` Dataclass, /// `dataclasses.field` @@ -1402,6 +1406,9 @@ impl KnownFunction { Self::AbstractMethod => { matches!(module, KnownModule::Abc) } + Self::AsyncContextManager => { + matches!(module, KnownModule::Contextlib) + } Self::Dataclass | Self::Field => { matches!(module, KnownModule::Dataclasses) } @@ -1926,6 +1933,8 @@ pub(crate) mod tests { KnownFunction::AbstractMethod => KnownModule::Abc, + KnownFunction::AsyncContextManager => KnownModule::Contextlib, + KnownFunction::Dataclass | KnownFunction::Field => KnownModule::Dataclasses, KnownFunction::GetattrStatic => KnownModule::Inspect, diff --git a/scripts/check_ecosystem.py b/scripts/check_ecosystem.py index fe2e73d9a1..8bc2dbea55 100755 --- a/scripts/check_ecosystem.py +++ b/scripts/check_ecosystem.py @@ -47,7 +47,7 @@ class Repository(NamedTuple): show_fixes: bool = False @asynccontextmanager - async def clone(self: Self, checkout_dir: Path) -> AsyncIterator[Path]: + async def clone(self: Self, checkout_dir: Path) -> AsyncIterator[str]: """Shallow clone this repository to a temporary directory.""" if checkout_dir.exists(): logger.debug(f"Reusing {self.org}:{self.repo}")