diff --git a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md index d97092720b..3865572726 100644 --- a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md +++ b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md @@ -310,6 +310,65 @@ reveal_type(s) # revealed: list[Literal[1]] reveal_type(s) # revealed: list[Literal[1]] ``` +## Generic constructor annotations are understood + +```toml +[environment] +python-version = "3.12" +``` + +```py +from typing import Any + +class X[T]: + def __init__(self, value: T): + self.value = value + +a: X[int] = X(1) +reveal_type(a) # revealed: X[int] + +b: X[int | None] = X(1) +reveal_type(b) # revealed: X[int | None] + +c: X[int | None] | None = X(1) +reveal_type(c) # revealed: X[int | None] + +def _[T](a: X[T]): + b: X[T | int] = X(a.value) + reveal_type(b) # revealed: X[T@_ | int] + +d: X[Any] = X(1) +reveal_type(d) # revealed: X[Any] + +def _(flag: bool): + # TODO: Handle unions correctly. + # error: [invalid-assignment] "Object of type `X[int]` is not assignable to `X[int | None]`" + a: X[int | None] = X(1) if flag else X(2) + reveal_type(a) # revealed: X[int | None] +``` + +```py +from dataclasses import dataclass + +@dataclass +class Y[T]: + value: T + +y1: Y[Any] = Y(value=1) +# TODO: This should reveal `Y[Any]`. +reveal_type(y1) # revealed: Y[int] +``` + +```py +class Z[T]: + def __new__(cls, value: T): + return super().__new__(cls) + +z1: Z[Any] = Z(1) +# TODO: This should reveal `Z[Any]`. +reveal_type(z1) # revealed: Z[int] +``` + ## PEP-604 annotations are supported ```py diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/assignment.md b/crates/ty_python_semantic/resources/mdtest/narrow/assignment.md index f692c59835..5786b465cd 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/assignment.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/assignment.md @@ -206,7 +206,7 @@ dd: defaultdict[int, int] = defaultdict(int) dd[0] = 0 cm: ChainMap[int, int] = ChainMap({1: 1}, {0: 0}) cm[0] = 0 -reveal_type(cm) # revealed: ChainMap[Unknown | int, Unknown | int] +reveal_type(cm) # revealed: ChainMap[int | Unknown, int | Unknown] reveal_type(l[0]) # revealed: Literal[0] reveal_type(d[0]) # revealed: Literal[0] diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index d49cf0087d..c9641c4e34 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -1111,6 +1111,22 @@ impl<'db> Type<'db> { } } + /// If the type is a generic class constructor, returns the class instance type. + pub(crate) fn synthesized_constructor_return_ty(self, db: &'db dyn Db) -> Option> { + // TODO: This does not correctly handle unions or intersections. It also does not handle + // constructors that are not represented as bound methods, e.g. `__new__`, or synthesized + // dataclass initializers. + if let Type::BoundMethod(method) = self + && let Type::NominalInstance(instance) = method.self_instance(db) + && method.function(db).name(db).as_str() == "__init__" + { + let class_ty = instance.class_literal(db).identity_specialization(db); + Some(Type::instance(db, class_ty)) + } else { + None + } + } + pub const fn is_property_instance(&self) -> bool { matches!(self, Type::PropertyInstance(..)) } diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index db1d06c32e..423783b420 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -2687,6 +2687,7 @@ struct ArgumentTypeChecker<'a, 'db> { arguments: &'a CallArguments<'a, 'db>, argument_matches: &'a [MatchedArgument<'db>], parameter_tys: &'a mut [Option>], + callable_type: Type<'db>, call_expression_tcx: TypeContext<'db>, return_ty: Type<'db>, errors: &'a mut Vec>, @@ -2703,6 +2704,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { arguments: &'a CallArguments<'a, 'db>, argument_matches: &'a [MatchedArgument<'db>], parameter_tys: &'a mut [Option>], + callable_type: Type<'db>, call_expression_tcx: TypeContext<'db>, return_ty: Type<'db>, errors: &'a mut Vec>, @@ -2713,6 +2715,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { arguments, argument_matches, parameter_tys, + callable_type, call_expression_tcx, return_ty, errors, @@ -2754,8 +2757,9 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { }; let return_with_tcx = self - .signature - .return_ty + .callable_type + .synthesized_constructor_return_ty(self.db) + .or(self.signature.return_ty) .zip(self.call_expression_tcx.annotation); self.inferable_typevars = generic_context.inferable_typevars(self.db); @@ -2763,7 +2767,9 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { // Prefer the declared type of generic classes. let preferred_type_mappings = return_with_tcx.and_then(|(return_ty, tcx)| { - tcx.class_specialization(self.db)?; + tcx.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some()) + .class_specialization(self.db)?; + builder.infer(return_ty, tcx).ok()?; Some(builder.type_mappings().clone()) }); @@ -3196,6 +3202,7 @@ impl<'db> Binding<'db> { arguments, &self.argument_matches, &mut self.parameter_tys, + self.callable_type, call_expression_tcx, self.return_ty, &mut self.errors, diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 74e8aca604..c227412752 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -6025,9 +6025,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // TODO: Checking assignability against the full declared type could help avoid // cases where the constraint solver is not smart enough to solve complex unions. // We should see revisit this after the new constraint solver is implemented. - if !speculated_bindings - .return_type(db) - .is_assignable_to(db, narrowed_ty) + if speculated_bindings + .callable_type() + .synthesized_constructor_return_ty(db) + .is_none() + && !speculated_bindings + .return_type(db) + .is_assignable_to(db, narrowed_ty) { return None; }