Merge branch 'main' into brent/indent-lambda-params

This commit is contained in:
Brent Westbrook 2025-12-09 18:19:58 -05:00
commit 05ff8f5cb1
No known key found for this signature in database
6 changed files with 122 additions and 8 deletions

View File

@ -6,7 +6,11 @@ python-version = "3.13"
python-platform = "linux"
[project]
dependencies = ["sqlmodel==0.0.27"]
dependencies = [
"sqlmodel==0.0.27",
# TODO: remove this pin, once we have a lockfile
"sqlalchemy==2.0.44"
]
```
## Basic model

View File

@ -212,13 +212,46 @@ class Session: ...
async def connect() -> AsyncGenerator[Session]:
yield Session()
# TODO: this should be `() -> _AsyncGeneratorContextManager[Session, None]`
reveal_type(connect) # revealed: () -> _AsyncGeneratorContextManager[Unknown, None]
# revealed: () -> _AsyncGeneratorContextManager[Session, None]
reveal_type(connect)
async def main():
async with connect() as session:
# TODO: should be `Session`
reveal_type(session) # revealed: Unknown
reveal_type(session) # revealed: Session
```
This also works with `AsyncIterator` return types:
```py
from typing import AsyncIterator
@asynccontextmanager
async def connect_iterator() -> AsyncIterator[Session]:
yield Session()
# revealed: () -> _AsyncGeneratorContextManager[Session, None]
reveal_type(connect_iterator)
async def main_iterator():
async with connect_iterator() as session:
reveal_type(session) # revealed: Session
```
And with `AsyncGeneratorType` return types:
```py
from types import AsyncGeneratorType
@asynccontextmanager
async def connect_async_generator() -> AsyncGeneratorType[Session]:
yield Session()
# revealed: () -> _AsyncGeneratorContextManager[Session, None]
reveal_type(connect_async_generator)
async def main_async_generator():
async with connect_async_generator() as session:
reveal_type(session) # revealed: Session
```
## `asyncio.timeout`

View File

@ -319,6 +319,7 @@ pub enum KnownModule {
Tempfile,
Pathlib,
Abc,
Contextlib,
Dataclasses,
Collections,
Inspect,
@ -351,6 +352,7 @@ impl KnownModule {
Self::Tempfile => "tempfile",
Self::Pathlib => "pathlib",
Self::Abc => "abc",
Self::Contextlib => "contextlib",
Self::Dataclasses => "dataclasses",
Self::Collections => "collections",
Self::Inspect => "inspect",

View File

@ -24,7 +24,8 @@ use super::{Argument, CallArguments, CallError, CallErrorKind, InferContext, Sig
use crate::Program;
use crate::db::Db;
use crate::dunder_all::dunder_all_names;
use crate::place::{Definedness, Place};
use crate::module_resolver::KnownModule;
use crate::place::{Definedness, Place, known_module_symbol};
use crate::types::call::arguments::{Expansion, is_expandable_type};
use crate::types::constraints::ConstraintSet;
use crate::types::diagnostic::{
@ -43,13 +44,14 @@ use crate::types::generics::{
use crate::types::signatures::{Parameter, ParameterForm, ParameterKind, Parameters};
use crate::types::tuple::{TupleLength, TupleSpec, TupleType};
use crate::types::{
BoundMethodType, BoundTypeVarIdentity, BoundTypeVarInstance, CallableSignature,
BoundMethodType, BoundTypeVarIdentity, BoundTypeVarInstance, CallableSignature, CallableType,
CallableTypeKind, ClassLiteral, DATACLASS_FLAGS, DataclassFlags, DataclassParams,
FieldInstance, KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy,
NominalInstanceType, PropertyInstanceType, SpecialFormType, TrackedConstraintSet,
TypeAliasType, TypeContext, TypeVarVariance, UnionBuilder, UnionType, WrapperDescriptorKind,
enums, list_members, todo_type,
};
use crate::unpack::EvaluationMode;
use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity};
use ruff_python_ast::{self as ast, ArgOrKeyword, PythonVersion};
@ -941,6 +943,18 @@ impl<'db> Bindings<'db> {
}
}
// TODO: Remove this special handling once we have full support for
// generic protocols in the solver.
Some(KnownFunction::AsyncContextManager) => {
if let [Some(callable)] = overload.parameter_types() {
if let Some(return_ty) =
asynccontextmanager_return_type(db, *callable)
{
overload.set_return_type(return_ty);
}
}
}
Some(KnownFunction::IsProtocol) => {
if let [Some(ty)] = overload.parameter_types() {
// We evaluate this to `Literal[True]` only if the runtime function `typing.is_protocol`
@ -4622,3 +4636,55 @@ impl fmt::Display for FunctionKind {
// An example of a routine with many many overloads:
// https://github.com/henribru/google-api-python-client-stubs/blob/master/googleapiclient-stubs/discovery.pyi
const MAXIMUM_OVERLOADS: usize = 50;
/// Infer the return type for a call to `asynccontextmanager`.
///
/// The `@asynccontextmanager` decorator transforms a function that returns (a subtype of) `AsyncIterator[T]`
/// into a function that returns `_AsyncGeneratorContextManager[T]`.
///
/// TODO: This function only handles the most basic case. It should be removed once we have
/// full support for generic protocols in the solver.
fn asynccontextmanager_return_type<'db>(db: &'db dyn Db, func_ty: Type<'db>) -> Option<Type<'db>> {
let bindings = func_ty.bindings(db);
let binding = bindings
.single_element()?
.overloads
.iter()
.exactly_one()
.ok()?;
let signature = &binding.signature;
let return_ty = signature.return_ty?;
let yield_ty = return_ty
.try_iterate_with_mode(db, EvaluationMode::Async)
.ok()?
.homogeneous_element_type(db);
if yield_ty.is_divergent()
|| signature
.parameters()
.iter()
.any(|param| param.annotated_type().is_some_and(|ty| ty.is_divergent()))
{
return Some(yield_ty);
}
let context_manager =
known_module_symbol(db, KnownModule::Contextlib, "_AsyncGeneratorContextManager")
.place
.ignore_possibly_undefined()?
.as_class_literal()?;
let context_manager = context_manager.apply_specialization(db, |generic_context| {
generic_context.specialize_partial(db, [Some(yield_ty), None])
});
let new_return_ty = Type::from(context_manager).to_instance(db)?;
let new_signature = Signature::new(signature.parameters().clone(), Some(new_return_ty));
Some(Type::Callable(CallableType::new(
db,
CallableSignature::single(new_signature),
CallableTypeKind::FunctionLike,
)))
}

View File

@ -1315,6 +1315,10 @@ pub enum KnownFunction {
#[strum(serialize = "abstractmethod")]
AbstractMethod,
/// `contextlib.asynccontextmanager`
#[strum(serialize = "asynccontextmanager")]
AsyncContextManager,
/// `dataclasses.dataclass`
Dataclass,
/// `dataclasses.field`
@ -1402,6 +1406,9 @@ impl KnownFunction {
Self::AbstractMethod => {
matches!(module, KnownModule::Abc)
}
Self::AsyncContextManager => {
matches!(module, KnownModule::Contextlib)
}
Self::Dataclass | Self::Field => {
matches!(module, KnownModule::Dataclasses)
}
@ -1926,6 +1933,8 @@ pub(crate) mod tests {
KnownFunction::AbstractMethod => KnownModule::Abc,
KnownFunction::AsyncContextManager => KnownModule::Contextlib,
KnownFunction::Dataclass | KnownFunction::Field => KnownModule::Dataclasses,
KnownFunction::GetattrStatic => KnownModule::Inspect,

View File

@ -47,7 +47,7 @@ class Repository(NamedTuple):
show_fixes: bool = False
@asynccontextmanager
async def clone(self: Self, checkout_dir: Path) -> AsyncIterator[Path]:
async def clone(self: Self, checkout_dir: Path) -> AsyncIterator[str]:
"""Shallow clone this repository to a temporary directory."""
if checkout_dir.exists():
logger.debug(f"Reusing {self.org}:{self.repo}")