diff --git a/crates/ty_python_semantic/resources/mdtest/bidirectional.md b/crates/ty_python_semantic/resources/mdtest/bidirectional.md index 3485304b6b..627492855f 100644 --- a/crates/ty_python_semantic/resources/mdtest/bidirectional.md +++ b/crates/ty_python_semantic/resources/mdtest/bidirectional.md @@ -145,3 +145,84 @@ def h[T](x: T, cond: bool) -> T | list[T]: def i[T](x: T, cond: bool) -> T | list[T]: return x if cond else [x] ``` + +## Type context sources + +Type context is sourced from various places, including annotated assignments: + +```py +from typing import Literal + +a: list[Literal[1]] = [1] +``` + +Function parameter annotations: + +```py +def b(x: list[Literal[1]]): ... + +b([1]) +``` + +Bound method parameter annotations: + +```py +class C: + def __init__(self, x: list[Literal[1]]): ... + def foo(self, x: list[Literal[1]]): ... + +C([1]).foo([1]) +``` + +Declared variable types: + +```py +d: list[Literal[1]] +d = [1] +``` + +Declared attribute types: + +```py +class E: + e: list[Literal[1]] + +def _(e: E): + # TODO: Implement attribute type context. + # error: [invalid-assignment] "Object of type `list[Unknown | int]` is not assignable to attribute `e` of type `list[Literal[1]]`" + e.e = [1] +``` + +Function return types: + +```py +def f() -> list[Literal[1]]: + return [1] +``` + +## Class constructor parameters + +```toml +[environment] +python-version = "3.12" +``` + +The parameters of both `__init__` and `__new__` are used as type context sources for constructor +calls: + +```py +def f[T](x: T) -> list[T]: + return [x] + +class A: + def __new__(cls, value: list[int | str]): + return super().__new__(cls, value) + + def __init__(self, value: list[int | None]): ... + +A(f(1)) + +# error: [invalid-argument-type] "Argument to function `__new__` is incorrect: Expected `list[int | str]`, found `list[list[Unknown]]`" +# error: [invalid-argument-type] "Argument to bound method `__init__` is incorrect: Expected `list[int | None]`, found `list[list[Unknown]]`" +A(f([])) +``` diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 0d2d7ac866..8946861894 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -6007,6 +6007,9 @@ impl<'db> Type<'db> { /// Given a class literal or non-dynamic `SubclassOf` type, try calling it (creating an instance) /// and return the resulting instance type. /// + /// The `infer_argument_types` closure should be invoked with the signatures of `__new__` and + /// `__init__`, such that the argument types can be inferred with the correct type context. + /// /// Models `type.__call__` behavior. /// TODO: model metaclass `__call__`. /// @@ -6017,10 +6020,10 @@ impl<'db> Type<'db> { /// /// Foo() /// ``` - fn try_call_constructor( + fn try_call_constructor<'ast>( self, db: &'db dyn Db, - argument_types: CallArguments<'_, 'db>, + infer_argument_types: impl FnOnce(Option>) -> CallArguments<'ast, 'db>, tcx: TypeContext<'db>, ) -> Result, ConstructorCallError<'db>> { debug_assert!(matches!( @@ -6076,11 +6079,63 @@ impl<'db> Type<'db> { // easy to check if that's the one we found? // Note that `__new__` is a static method, so we must inject the `cls` argument. let new_method = self_type.lookup_dunder_new(db, ()); + + // Construct an instance type that we can use to look up the `__init__` instance method. + // This performs the same logic as `Type::to_instance`, except for generic class literals. + // TODO: we should use the actual return type of `__new__` to determine the instance type + let init_ty = self_type + .to_instance(db) + .expect("type should be convertible to instance type"); + + // Lookup the `__init__` instance method in the MRO. + let init_method = init_ty.member_lookup_with_policy( + db, + "__init__".into(), + MemberLookupPolicy::NO_INSTANCE_FALLBACK | MemberLookupPolicy::MRO_NO_OBJECT_FALLBACK, + ); + + // Infer the call argument types, using both `__new__` and `__init__` for type-context. + let bindings = match ( + new_method.as_ref().map(|method| &method.place), + &init_method.place, + ) { + (Some(Place::Defined(new_method, ..)), Place::Undefined) => Some( + new_method + .bindings(db) + .map(|binding| binding.with_bound_type(self_type)), + ), + + (Some(Place::Undefined) | None, Place::Defined(init_method, ..)) => { + Some(init_method.bindings(db)) + } + + (Some(Place::Defined(new_method, ..)), Place::Defined(init_method, ..)) => { + let callable = UnionBuilder::new(db) + .add(*new_method) + .add(*init_method) + .build(); + + let new_method_bindings = new_method + .bindings(db) + .map(|binding| binding.with_bound_type(self_type)); + + Some(Bindings::from_union( + callable, + [new_method_bindings, init_method.bindings(db)], + )) + } + + _ => None, + }; + + let argument_types = infer_argument_types(bindings); + let new_call_outcome = new_method.and_then(|new_method| { match new_method.place.try_call_dunder_get(db, self_type) { Place::Defined(new_method, _, boundness) => { let result = new_method.try_call(db, argument_types.with_self(Some(self_type)).as_ref()); + if boundness == Definedness::PossiblyUndefined { Some(Err(DunderNewCallError::PossiblyUnbound(result.err()))) } else { @@ -6091,24 +6146,7 @@ impl<'db> Type<'db> { } }); - // Construct an instance type that we can use to look up the `__init__` instance method. - // This performs the same logic as `Type::to_instance`, except for generic class literals. - // TODO: we should use the actual return type of `__new__` to determine the instance type - let init_ty = self_type - .to_instance(db) - .expect("type should be convertible to instance type"); - - let init_call_outcome = if new_call_outcome.is_none() - || !init_ty - .member_lookup_with_policy( - db, - "__init__".into(), - MemberLookupPolicy::NO_INSTANCE_FALLBACK - | MemberLookupPolicy::MRO_NO_OBJECT_FALLBACK, - ) - .place - .is_undefined() - { + let init_call_outcome = if new_call_outcome.is_none() || !init_method.is_undefined() { Some(init_ty.try_call_dunder(db, "__init__", argument_types, tcx)) } else { None diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index e6121da055..e72d798dd6 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -100,6 +100,14 @@ impl<'db> Bindings<'db> { self.elements.iter() } + pub(crate) fn map(self, f: impl Fn(CallableBinding<'db>) -> CallableBinding<'db>) -> Self { + Self { + callable_type: self.callable_type, + argument_forms: self.argument_forms, + elements: self.elements.into_iter().map(f).collect(), + } + } + /// Match the arguments of a call site against the parameters of a collection of possibly /// unioned, possibly overloaded signatures. /// diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 126f1ad557..9a5091f837 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -6798,9 +6798,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { .to_class_type(self.db()) .is_none_or(|enum_class| !class.is_subclass_of(self.db(), enum_class)) { - let argument_forms = vec![Some(ParameterForm::Value); call_arguments.len()]; - self.infer_argument_types(arguments, &mut call_arguments, &argument_forms); - if matches!( class.known(self.db()), Some(KnownClass::TypeVar | KnownClass::ExtensionsTypeVar) @@ -6819,8 +6816,21 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } + let db = self.db(); + let infer_call_arguments = |bindings: Option>| { + if let Some(bindings) = bindings { + let bindings = bindings.match_parameters(self.db(), &call_arguments); + self.infer_all_argument_types(arguments, &mut call_arguments, &bindings); + } else { + let argument_forms = vec![Some(ParameterForm::Value); call_arguments.len()]; + self.infer_argument_types(arguments, &mut call_arguments, &argument_forms); + } + + call_arguments + }; + return callable_type - .try_call_constructor(self.db(), call_arguments, tcx) + .try_call_constructor(db, infer_call_arguments, tcx) .unwrap_or_else(|err| { err.report_diagnostic(&self.context, callable_type, call_expression.into()); err.return_type()