diff --git a/crates/red_knot_python_semantic/resources/mdtest/call/methods.md b/crates/red_knot_python_semantic/resources/mdtest/call/methods.md index aad3c01e84..0d3e94cd92 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/call/methods.md +++ b/crates/red_knot_python_semantic/resources/mdtest/call/methods.md @@ -410,29 +410,19 @@ def does_nothing[T](f: T) -> T: class C: @classmethod - # TODO: no error should be emitted here (needs support for generics) - # error: [invalid-argument-type] @does_nothing def f1(cls: type[C], x: int) -> str: return "a" - # TODO: no error should be emitted here (needs support for generics) - # error: [invalid-argument-type] + @does_nothing @classmethod def f2(cls: type[C], x: int) -> str: return "a" -# TODO: All of these should be `str` (and not emit an error), once we support generics - -# error: [call-non-callable] -reveal_type(C.f1(1)) # revealed: Unknown -# error: [call-non-callable] -reveal_type(C().f1(1)) # revealed: Unknown - -# error: [call-non-callable] -reveal_type(C.f2(1)) # revealed: Unknown -# error: [call-non-callable] -reveal_type(C().f2(1)) # revealed: Unknown +reveal_type(C.f1(1)) # revealed: str +reveal_type(C().f1(1)) # revealed: str +reveal_type(C.f2(1)) # revealed: str +reveal_type(C().f2(1)) # revealed: str ``` [functions and methods]: https://docs.python.org/3/howto/descriptor.html#functions-and-methods diff --git a/crates/red_knot_python_semantic/resources/mdtest/generics/classes.md b/crates/red_knot_python_semantic/resources/mdtest/generics/classes.md index ee767056e8..85574c615c 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/generics/classes.md +++ b/crates/red_knot_python_semantic/resources/mdtest/generics/classes.md @@ -149,23 +149,102 @@ If a typevar does not provide a default, we use `Unknown`: reveal_type(C()) # revealed: C[Unknown] ``` +## Inferring generic class parameters from constructors + If the type of a constructor parameter is a class typevar, we can use that to infer the type -parameter: +parameter. The types inferred from a type context and from a constructor parameter must be +consistent with each other. + +## `__new__` only ```py -class E[T]: - def __init__(self, x: T) -> None: ... +class C[T]: + def __new__(cls, x: T) -> "C"[T]: + return object.__new__(cls) -# TODO: revealed: E[int] or E[Literal[1]] -reveal_type(E(1)) # revealed: E[Unknown] +reveal_type(C(1)) # revealed: C[Literal[1]] + +# TODO: error: [invalid-argument-type] +wrong_innards: C[int] = C("five") ``` -The types inferred from a type context and from a constructor parameter must be consistent with each -other: +## `__init__` only ```py +class C[T]: + def __init__(self, x: T) -> None: ... + +reveal_type(C(1)) # revealed: C[Literal[1]] + # TODO: error: [invalid-argument-type] -wrong_innards: E[int] = E("five") +wrong_innards: C[int] = C("five") +``` + +## Identical `__new__` and `__init__` signatures + +```py +class C[T]: + def __new__(cls, x: T) -> "C"[T]: + return object.__new__(cls) + + def __init__(self, x: T) -> None: ... + +reveal_type(C(1)) # revealed: C[Literal[1]] + +# TODO: error: [invalid-argument-type] +wrong_innards: C[int] = C("five") +``` + +## Compatible `__new__` and `__init__` signatures + +```py +class C[T]: + def __new__(cls, *args, **kwargs) -> "C"[T]: + return object.__new__(cls) + + def __init__(self, x: T) -> None: ... + +reveal_type(C(1)) # revealed: C[Literal[1]] + +# TODO: error: [invalid-argument-type] +wrong_innards: C[int] = C("five") + +class D[T]: + def __new__(cls, x: T) -> "D"[T]: + return object.__new__(cls) + + def __init__(self, *args, **kwargs) -> None: ... + +reveal_type(D(1)) # revealed: D[Literal[1]] + +# TODO: error: [invalid-argument-type] +wrong_innards: D[int] = D("five") +``` + +## `__init__` is itself generic + +TODO: These do not currently work yet, because we don't correctly model the nested generic contexts. + +```py +class C[T]: + def __init__[S](self, x: T, y: S) -> None: ... + +# TODO: no error +# TODO: revealed: C[Literal[1]] +# error: [invalid-argument-type] +reveal_type(C(1, 1)) # revealed: C[Unknown] +# TODO: no error +# TODO: revealed: C[Literal[1]] +# error: [invalid-argument-type] +reveal_type(C(1, "string")) # revealed: C[Unknown] +# TODO: no error +# TODO: revealed: C[Literal[1]] +# error: [invalid-argument-type] +reveal_type(C(1, True)) # revealed: C[Unknown] + +# TODO: error for the correct reason +# error: [invalid-argument-type] "Argument to this function is incorrect: Expected `S`, found `Literal[1]`" +wrong_innards: C[int] = C("five", 1) ``` ## Generic subclass @@ -200,10 +279,7 @@ class C[T]: def cannot_shadow_class_typevar[T](self, t: T): ... c: C[int] = C[int]() -# TODO: no error -# TODO: revealed: str or Literal["string"] -# error: [invalid-argument-type] -reveal_type(c.method("string")) # revealed: U +reveal_type(c.method("string")) # revealed: Literal["string"] ``` ## Cyclic class definition diff --git a/crates/red_knot_python_semantic/resources/mdtest/generics/functions.md b/crates/red_knot_python_semantic/resources/mdtest/generics/functions.md index 935fc47fba..7fe3cf9ef0 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/generics/functions.md +++ b/crates/red_knot_python_semantic/resources/mdtest/generics/functions.md @@ -43,33 +43,14 @@ def absurd[T]() -> T: If the type of a generic function parameter is a typevar, then we can infer what type that typevar is bound to at each call site. -TODO: Note that some of the TODO revealed types have two options, since we haven't decided yet -whether we want to infer a more specific `Literal` type where possible, or use heuristics to weaken -the inferred type to e.g. `int`. - ```py def f[T](x: T) -> T: return x -# TODO: no error -# TODO: revealed: int or Literal[1] -# error: [invalid-argument-type] -reveal_type(f(1)) # revealed: T - -# TODO: no error -# TODO: revealed: float -# error: [invalid-argument-type] -reveal_type(f(1.0)) # revealed: T - -# TODO: no error -# TODO: revealed: bool or Literal[true] -# error: [invalid-argument-type] -reveal_type(f(True)) # revealed: T - -# TODO: no error -# TODO: revealed: str or Literal["string"] -# error: [invalid-argument-type] -reveal_type(f("string")) # revealed: T +reveal_type(f(1)) # revealed: Literal[1] +reveal_type(f(1.0)) # revealed: float +reveal_type(f(True)) # revealed: Literal[True] +reveal_type(f("string")) # revealed: Literal["string"] ``` ## Inferring “deep” generic parameter types @@ -82,7 +63,7 @@ def f[T](x: list[T]) -> T: return x[0] # TODO: revealed: float -reveal_type(f([1.0, 2.0])) # revealed: T +reveal_type(f([1.0, 2.0])) # revealed: Unknown ``` ## Typevar constraints @@ -93,7 +74,6 @@ in the function. ```py def good_param[T: int](x: T) -> None: - # TODO: revealed: T & int reveal_type(x) # revealed: T ``` @@ -162,61 +142,41 @@ parameters simultaneously. def two_params[T](x: T, y: T) -> T: return x -# TODO: no error -# TODO: revealed: str -# error: [invalid-argument-type] -# error: [invalid-argument-type] -reveal_type(two_params("a", "b")) # revealed: T +reveal_type(two_params("a", "b")) # revealed: Literal["a", "b"] +reveal_type(two_params("a", 1)) # revealed: Literal["a", 1] +``` -# TODO: no error -# TODO: revealed: str | int -# error: [invalid-argument-type] -# error: [invalid-argument-type] -reveal_type(two_params("a", 1)) # revealed: T +When one of the parameters is a union, we attempt to find the smallest specialization that satisfies +all of the constraints. + +```py +def union_param[T](x: T | None) -> T: + if x is None: + raise ValueError + return x + +reveal_type(union_param("a")) # revealed: Literal["a"] +reveal_type(union_param(1)) # revealed: Literal[1] +reveal_type(union_param(None)) # revealed: Unknown ``` ```py -def param_with_union[T](x: T | int, y: T) -> T: +def union_and_nonunion_params[T](x: T | int, y: T) -> T: return y -# TODO: no error -# TODO: revealed: str -# error: [invalid-argument-type] -reveal_type(param_with_union(1, "a")) # revealed: T - -# TODO: no error -# TODO: revealed: str -# error: [invalid-argument-type] -# error: [invalid-argument-type] -reveal_type(param_with_union("a", "a")) # revealed: T - -# TODO: no error -# TODO: revealed: int -# error: [invalid-argument-type] -reveal_type(param_with_union(1, 1)) # revealed: T - -# TODO: no error -# TODO: revealed: str | int -# error: [invalid-argument-type] -# error: [invalid-argument-type] -reveal_type(param_with_union("a", 1)) # revealed: T +reveal_type(union_and_nonunion_params(1, "a")) # revealed: Literal["a"] +reveal_type(union_and_nonunion_params("a", "a")) # revealed: Literal["a"] +reveal_type(union_and_nonunion_params(1, 1)) # revealed: Literal[1] +reveal_type(union_and_nonunion_params(3, 1)) # revealed: Literal[1] +reveal_type(union_and_nonunion_params("a", 1)) # revealed: Literal["a", 1] ``` ```py def tuple_param[T, S](x: T | S, y: tuple[T, S]) -> tuple[T, S]: return y -# TODO: no error -# TODO: revealed: tuple[str, int] -# error: [invalid-argument-type] -# error: [invalid-argument-type] -reveal_type(tuple_param("a", ("a", 1))) # revealed: tuple[T, S] - -# TODO: no error -# TODO: revealed: tuple[str, int] -# error: [invalid-argument-type] -# error: [invalid-argument-type] -reveal_type(tuple_param(1, ("a", 1))) # revealed: tuple[T, S] +reveal_type(tuple_param("a", ("a", 1))) # revealed: tuple[Literal["a"], Literal[1]] +reveal_type(tuple_param(1, ("a", 1))) # revealed: tuple[Literal["a"], Literal[1]] ``` ## Inferring nested generic function calls @@ -231,15 +191,6 @@ def f[T](x: T) -> tuple[T, int]: def g[T](x: T) -> T | None: return x -# TODO: no error -# TODO: revealed: tuple[str | None, int] -# error: [invalid-argument-type] -# error: [invalid-argument-type] -reveal_type(f(g("a"))) # revealed: tuple[T, int] - -# TODO: no error -# TODO: revealed: tuple[str, int] | None -# error: [invalid-argument-type] -# error: [invalid-argument-type] -reveal_type(g(f("a"))) # revealed: T | None +reveal_type(f(g("a"))) # revealed: tuple[Literal["a"] | None, int] +reveal_type(g(f("a"))) # revealed: tuple[Literal["a"], int] | None ``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/generics/scoping.md b/crates/red_knot_python_semantic/resources/mdtest/generics/scoping.md index 31561d442d..b4f9992a9a 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/generics/scoping.md +++ b/crates/red_knot_python_semantic/resources/mdtest/generics/scoping.md @@ -59,14 +59,8 @@ to a different type each time. def f[T](x: T) -> T: return x -# TODO: no error -# TODO: revealed: int or Literal[1] -# error: [invalid-argument-type] -reveal_type(f(1)) # revealed: T -# TODO: no error -# TODO: revealed: str or Literal["a"] -# error: [invalid-argument-type] -reveal_type(f("a")) # revealed: T +reveal_type(f(1)) # revealed: Literal[1] +reveal_type(f("a")) # revealed: Literal["a"] ``` ## Methods can mention class typevars @@ -157,10 +151,7 @@ class C[T]: return y c: C[int] = C() -# TODO: no errors -# TODO: revealed: str -# error: [invalid-argument-type] -reveal_type(c.m(1, "string")) # revealed: S +reveal_type(c.m(1, "string")) # revealed: Literal["string"] ``` ## Unbound typevars diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 9f3b17f3b6..ed0174d273 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -35,10 +35,10 @@ use crate::semantic_index::symbol::ScopeId; use crate::semantic_index::{imported_modules, semantic_index}; use crate::suppression::check_suppressions; use crate::symbol::{imported_symbol, Boundness, Symbol, SymbolAndQualifiers}; -use crate::types::call::{Bindings, CallArgumentTypes}; +use crate::types::call::{Bindings, CallArgumentTypes, CallableBinding}; pub(crate) use crate::types::class_base::ClassBase; use crate::types::diagnostic::{INVALID_TYPE_FORM, UNSUPPORTED_BOOL_CONVERSION}; -use crate::types::generics::Specialization; +use crate::types::generics::{GenericContext, Specialization}; use crate::types::infer::infer_unpack_types; use crate::types::mro::{Mro, MroError, MroIterator}; pub(crate) use crate::types::narrow::infer_narrowing_constraint; @@ -2150,7 +2150,7 @@ impl<'db> Type<'db> { "__get__" | "__set__" | "__delete__", ) => Some(Symbol::Unbound.into()), - _ => Some(class.class_member(db, None, name, policy)), + _ => Some(class.class_member(db, name, policy)), } } @@ -3731,7 +3731,11 @@ impl<'db> Type<'db> { _ => { let signature = CallableSignature::single( self, - Signature::new(Parameters::gradual_form(), self.to_instance(db)), + Signature::new_generic( + class.generic_context(db), + Parameters::gradual_form(), + self.to_instance(db), + ), ); Signatures::single(signature) } @@ -3827,9 +3831,14 @@ impl<'db> Type<'db> { self, db: &'db dyn Db, name: &str, - argument_types: CallArgumentTypes<'_, 'db>, + mut argument_types: CallArgumentTypes<'_, 'db>, ) -> Result, CallDunderError<'db>> { - self.try_call_dunder_with_policy(db, name, argument_types, MemberLookupPolicy::empty()) + self.try_call_dunder_with_policy( + db, + name, + &mut argument_types, + MemberLookupPolicy::NO_INSTANCE_FALLBACK, + ) } /// Same as `try_call_dunder`, but allows specifying a policy for the member lookup. In @@ -3840,21 +3849,17 @@ impl<'db> Type<'db> { self, db: &'db dyn Db, name: &str, - mut argument_types: CallArgumentTypes<'_, 'db>, + argument_types: &mut CallArgumentTypes<'_, 'db>, policy: MemberLookupPolicy, ) -> Result, CallDunderError<'db>> { match self - .member_lookup_with_policy( - db, - name.into(), - MemberLookupPolicy::NO_INSTANCE_FALLBACK | policy, - ) + .member_lookup_with_policy(db, name.into(), policy) .symbol { Symbol::Type(dunder_callable, boundness) => { let signatures = dunder_callable.signatures(db); - let bindings = Bindings::match_parameters(signatures, &mut argument_types) - .check_types(db, &mut argument_types)?; + let bindings = Bindings::match_parameters(signatures, argument_types) + .check_types(db, argument_types)?; if boundness == Boundness::PossiblyUnbound { return Err(CallDunderError::PossiblyUnbound(Box::new(bindings))); } @@ -4025,9 +4030,31 @@ impl<'db> Type<'db> { fn try_call_constructor( self, db: &'db dyn Db, - argument_types: CallArgumentTypes<'_, 'db>, + mut argument_types: CallArgumentTypes<'_, 'db>, ) -> Result, ConstructorCallError<'db>> { - debug_assert!(matches!(self, Type::ClassLiteral(_) | Type::SubclassOf(_))); + debug_assert!(matches!( + self, + Type::ClassLiteral(_) | Type::GenericAlias(_) | Type::SubclassOf(_) + )); + + // If we are trying to construct a non-specialized generic class, we should use the + // constructor parameters to try to infer the class specialization. To do this, we need to + // tweak our member lookup logic a bit. Normally, when looking up a class or instance + // member, we first apply the class's default specialization, and apply that specialization + // to the type of the member. To infer a specialization from the argument types, we need to + // have the class's typevars still in the method signature when we attempt to call it. To + // do this, we instead use the _identity_ specialization, which maps each of the class's + // generic typevars to itself. + let (generic_origin, self_type) = match self { + Type::ClassLiteral(ClassLiteralType::Generic(generic)) => { + let specialization = generic.generic_context(db).identity_specialization(db); + ( + Some(generic), + Type::GenericAlias(GenericAlias::new(db, generic, specialization)), + ) + } + _ => (None, self), + }; // As of now we do not model custom `__call__` on meta-classes, so the code below // only deals with interplay between `__new__` and `__init__` methods. @@ -4052,46 +4079,30 @@ impl<'db> Type<'db> { // `object` we would inadvertently unhide `__new__` on `type`, which is not what we want. // An alternative might be to not skip `object.__new__` but instead mark it such that it's // easy to check if that's the one we found? - let new_call_outcome: Option, CallDunderError<'db>>> = match self - .member_lookup_with_policy( + // Note that `__new__` is a static method, so we must inject the `cls` argument. + let new_call_outcome = argument_types.with_self(Some(self_type), |argument_types| { + let result = self_type.try_call_dunder_with_policy( db, - "__new__".into(), + "__new__", + argument_types, MemberLookupPolicy::MRO_NO_OBJECT_FALLBACK | MemberLookupPolicy::META_CLASS_NO_TYPE_FALLBACK, - ) - .symbol - { - Symbol::Type(dunder_callable, boundness) => { - let signatures = dunder_callable.signatures(db); - // `__new__` is a static method, so we must inject the `cls` argument. - let mut argument_types = argument_types.prepend_synthetic(self); - - Some( - match Bindings::match_parameters(signatures, &mut argument_types) - .check_types(db, &mut argument_types) - { - Ok(bindings) => { - if boundness == Boundness::PossiblyUnbound { - Err(CallDunderError::PossiblyUnbound(Box::new(bindings))) - } else { - Ok(bindings) - } - } - Err(err) => Err(err.into()), - }, - ) + ); + match result { + Err(CallDunderError::MethodNotAvailable) => None, + _ => Some(result), } - // No explicit `__new__` method found - Symbol::Unbound => None, - }; + }); + // Construct an instance type that we can use to look up the `__init__` instance method. + // This performs the same logic as `Type::to_instance`, except for generic class literals. // TODO: we should use the actual return type of `__new__` to determine the instance type - let instance_ty = self + let init_ty = self_type .to_instance(db) - .expect("Class literal type and subclass-of types should always be convertible to instance type"); + .expect("type should be convertible to instance type"); let init_call_outcome = if new_call_outcome.is_none() - || !instance_ty + || !init_ty .member_lookup_with_policy( db, "__init__".into(), @@ -4100,23 +4111,68 @@ impl<'db> Type<'db> { .symbol .is_unbound() { - Some(instance_ty.try_call_dunder(db, "__init__", argument_types)) + Some(init_ty.try_call_dunder(db, "__init__", argument_types)) } else { None }; - match (new_call_outcome, init_call_outcome) { + // Note that we use `self` here, not `self_type`, so that if constructor argument inference + // fails, we fail back to the default specialization. + let instance_ty = self + .to_instance(db) + .expect("type should be convertible to instance type"); + + match (generic_origin, new_call_outcome, init_call_outcome) { // All calls are successful or not called at all - (None | Some(Ok(_)), None | Some(Ok(_))) => Ok(instance_ty), - (None | Some(Ok(_)), Some(Err(error))) => { + ( + Some(generic_origin), + new_call_outcome @ (None | Some(Ok(_))), + init_call_outcome @ (None | Some(Ok(_))), + ) => { + let new_specialization = new_call_outcome + .and_then(Result::ok) + .as_ref() + .and_then(Bindings::single_element) + .and_then(CallableBinding::matching_overload) + .and_then(|(_, binding)| binding.specialization()); + let init_specialization = init_call_outcome + .and_then(Result::ok) + .as_ref() + .and_then(Bindings::single_element) + .and_then(CallableBinding::matching_overload) + .and_then(|(_, binding)| binding.specialization()); + let specialization = match (new_specialization, init_specialization) { + (None, None) => None, + (Some(specialization), None) | (None, Some(specialization)) => { + Some(specialization) + } + (Some(new_specialization), Some(init_specialization)) => { + Some(new_specialization.combine(db, init_specialization)) + } + }; + let specialized = specialization + .map(|specialization| { + Type::instance(ClassType::Generic(GenericAlias::new( + db, + generic_origin, + specialization, + ))) + }) + .unwrap_or(instance_ty); + Ok(specialized) + } + + (None, None | Some(Ok(_)), None | Some(Ok(_))) => Ok(instance_ty), + + (_, None | Some(Ok(_)), Some(Err(error))) => { // no custom `__new__` or it was called and succeeded, but `__init__` failed. Err(ConstructorCallError::Init(instance_ty, error)) } - (Some(Err(error)), None | Some(Ok(_))) => { + (_, Some(Err(error)), None | Some(Ok(_))) => { // custom `__new__` was called and failed, but init is ok Err(ConstructorCallError::New(instance_ty, error)) } - (Some(Err(new_error)), Some(Err(init_error))) => { + (_, Some(Err(new_error)), Some(Err(init_error))) => { // custom `__new__` was called and failed, and `__init__` is also not ok Err(ConstructorCallError::NewAndInit( instance_ty, @@ -5688,6 +5744,9 @@ pub struct FunctionType<'db> { /// A set of special decorators that were applied to this function decorators: FunctionDecorators, + /// The generic context of a generic function. + generic_context: Option>, + /// A specialization that should be applied to the function's parameter and return types, /// either because the function is itself generic, or because it appears in the body of a /// generic class. @@ -5769,13 +5828,25 @@ impl<'db> FunctionType<'db> { let scope = self.body_scope(db); let function_stmt_node = scope.node(db).expect_function(); let definition = self.definition(db); - Signature::from_function(db, definition, function_stmt_node) + Signature::from_function(db, self.generic_context(db), definition, function_stmt_node) } pub(crate) fn is_known(self, db: &'db dyn Db, known_function: KnownFunction) -> bool { self.known(db) == Some(known_function) } + fn with_generic_context(self, db: &'db dyn Db, generic_context: GenericContext<'db>) -> Self { + Self::new( + db, + self.name(db).clone(), + self.known(db), + self.body_scope(db), + self.decorators(db), + Some(generic_context), + self.specialization(db), + ) + } + fn apply_specialization(self, db: &'db dyn Db, specialization: Specialization<'db>) -> Self { let specialization = match self.specialization(db) { Some(existing) => existing.apply_specialization(db, specialization), @@ -5787,6 +5858,7 @@ impl<'db> FunctionType<'db> { self.known(db), self.body_scope(db), self.decorators(db), + self.generic_context(db), Some(specialization), ) } diff --git a/crates/red_knot_python_semantic/src/types/builder.rs b/crates/red_knot_python_semantic/src/types/builder.rs index aa3e64e903..411c080257 100644 --- a/crates/red_knot_python_semantic/src/types/builder.rs +++ b/crates/red_knot_python_semantic/src/types/builder.rs @@ -73,21 +73,26 @@ impl<'db> UnionBuilder<'db> { } /// Collapse the union to a single type: `object`. - fn collapse_to_object(mut self) -> Self { + fn collapse_to_object(&mut self) { self.elements.clear(); self.elements .push(UnionElement::Type(Type::object(self.db))); - self } /// Adds a type to this union. pub(crate) fn add(mut self, ty: Type<'db>) -> Self { + self.add_in_place(ty); + self + } + + /// Adds a type to this union. + pub(crate) fn add_in_place(&mut self, ty: Type<'db>) { match ty { Type::Union(union) => { let new_elements = union.elements(self.db); self.elements.reserve(new_elements.len()); for element in new_elements { - self = self.add(*element); + self.add_in_place(*element); } } // Adding `Never` to a union is a no-op. @@ -103,14 +108,15 @@ impl<'db> UnionBuilder<'db> { UnionElement::StringLiterals(literals) => { if literals.len() >= MAX_UNION_LITERALS { let replace_with = KnownClass::Str.to_instance(self.db); - return self.add(replace_with); + self.add_in_place(replace_with); + return; } literals.insert(literal); found = true; break; } UnionElement::Type(existing) if ty.is_subtype_of(self.db, *existing) => { - return self; + return; } _ => {} } @@ -130,14 +136,15 @@ impl<'db> UnionBuilder<'db> { UnionElement::BytesLiterals(literals) => { if literals.len() >= MAX_UNION_LITERALS { let replace_with = KnownClass::Bytes.to_instance(self.db); - return self.add(replace_with); + self.add_in_place(replace_with); + return; } literals.insert(literal); found = true; break; } UnionElement::Type(existing) if ty.is_subtype_of(self.db, *existing) => { - return self; + return; } _ => {} } @@ -157,14 +164,15 @@ impl<'db> UnionBuilder<'db> { UnionElement::IntLiterals(literals) => { if literals.len() >= MAX_UNION_LITERALS { let replace_with = KnownClass::Int.to_instance(self.db); - return self.add(replace_with); + self.add_in_place(replace_with); + return; } literals.insert(literal); found = true; break; } UnionElement::Type(existing) if ty.is_subtype_of(self.db, *existing) => { - return self; + return; } _ => {} } @@ -176,7 +184,7 @@ impl<'db> UnionBuilder<'db> { } // Adding `object` to a union results in `object`. ty if ty.is_object(self.db) => { - return self.collapse_to_object(); + self.collapse_to_object(); } _ => { let bool_pair = if let Type::BooleanLiteral(b) = ty { @@ -223,7 +231,7 @@ impl<'db> UnionBuilder<'db> { || ty.is_subtype_of(self.db, element) || element.is_object(self.db) { - return self; + return; } else if element.is_subtype_of(self.db, ty) { to_remove.push(index); } else if ty_negated.is_subtype_of(self.db, element) { @@ -234,7 +242,8 @@ impl<'db> UnionBuilder<'db> { // `element | ty` must be `object` (object has no other supertypes). This means we can simplify // the whole union to just `object`, since all other potential elements would also be subtypes of // `object`. - return self.collapse_to_object(); + self.collapse_to_object(); + return; } } if let Some((&first, rest)) = to_remove.split_first() { @@ -248,7 +257,6 @@ impl<'db> UnionBuilder<'db> { } } } - self } pub(crate) fn build(self) -> Type<'db> { diff --git a/crates/red_knot_python_semantic/src/types/call.rs b/crates/red_knot_python_semantic/src/types/call.rs index a2761f91cb..27c10e5432 100644 --- a/crates/red_knot_python_semantic/src/types/call.rs +++ b/crates/red_knot_python_semantic/src/types/call.rs @@ -5,7 +5,7 @@ use crate::Db; mod arguments; mod bind; pub(super) use arguments::{Argument, CallArgumentTypes, CallArguments}; -pub(super) use bind::Bindings; +pub(super) use bind::{Bindings, CallableBinding}; /// Wraps a [`Bindings`] for an unsuccessful call with information about why the call was /// unsuccessful. diff --git a/crates/red_knot_python_semantic/src/types/call/arguments.rs b/crates/red_knot_python_semantic/src/types/call/arguments.rs index fd2cd4a313..cce0c81c0b 100644 --- a/crates/red_knot_python_semantic/src/types/call/arguments.rs +++ b/crates/red_knot_python_semantic/src/types/call/arguments.rs @@ -109,21 +109,6 @@ impl<'a, 'db> CallArgumentTypes<'a, 'db> { result } - /// Create a new [`CallArgumentTypes`] by prepending a synthetic argument to the front of this - /// argument list. - pub(crate) fn prepend_synthetic(&self, synthetic: Type<'db>) -> Self { - Self { - arguments: CallArguments( - std::iter::once(Argument::Synthetic) - .chain(self.arguments.iter()) - .collect(), - ), - types: std::iter::once(synthetic) - .chain(self.types.iter().copied()) - .collect(), - } - } - pub(crate) fn iter(&self) -> impl Iterator, Type<'db>)> + '_ { self.arguments.iter().zip(self.types.iter().copied()) } diff --git a/crates/red_knot_python_semantic/src/types/call/bind.rs b/crates/red_knot_python_semantic/src/types/call/bind.rs index da918d2249..16f00029e5 100644 --- a/crates/red_knot_python_semantic/src/types/call/bind.rs +++ b/crates/red_knot_python_semantic/src/types/call/bind.rs @@ -16,6 +16,7 @@ use crate::types::diagnostic::{ NO_MATCHING_OVERLOAD, PARAMETER_ALREADY_ASSIGNED, TOO_MANY_POSITIONAL_ARGUMENTS, UNKNOWN_ARGUMENT, }; +use crate::types::generics::{Specialization, SpecializationBuilder}; use crate::types::signatures::{Parameter, ParameterForm}; use crate::types::{ todo_type, BoundMethodType, DataclassMetadata, FunctionDecorators, KnownClass, KnownFunction, @@ -147,6 +148,13 @@ impl<'db> Bindings<'db> { self.elements.len() == 1 } + pub(crate) fn single_element(&self) -> Option<&CallableBinding<'db>> { + match self.elements.as_slice() { + [element] => Some(element), + _ => None, + } + } + pub(crate) fn callable_type(&self) -> Type<'db> { self.signatures.callable_type } @@ -882,6 +890,9 @@ pub(crate) struct Binding<'db> { /// Return type of the call. return_ty: Type<'db>, + /// The specialization that was inferred from the argument types, if the callable is generic. + specialization: Option>, + /// The formal parameter that each argument is matched with, in argument source order, or /// `None` if the argument was not matched to any parameter. argument_parameters: Box<[Option]>, @@ -1017,6 +1028,7 @@ impl<'db> Binding<'db> { Self { return_ty: signature.return_ty.unwrap_or(Type::unknown()), + specialization: None, argument_parameters: argument_parameters.into_boxed_slice(), parameter_tys: vec![None; parameters.len()].into_boxed_slice(), errors, @@ -1029,7 +1041,26 @@ impl<'db> Binding<'db> { signature: &Signature<'db>, argument_types: &CallArgumentTypes<'_, 'db>, ) { + // If this overload is generic, first see if we can infer a specialization of the function + // from the arguments that were passed in. let parameters = signature.parameters(); + self.specialization = signature.generic_context.map(|generic_context| { + let mut builder = SpecializationBuilder::new(db, generic_context); + for (argument_index, (_, argument_type)) in argument_types.iter().enumerate() { + let Some(parameter_index) = self.argument_parameters[argument_index] else { + // There was an error with argument when matching parameters, so don't bother + // type-checking it. + continue; + }; + let parameter = ¶meters[parameter_index]; + let Some(expected_type) = parameter.annotated_type() else { + continue; + }; + builder.infer(expected_type, argument_type); + } + builder.build() + }); + let mut num_synthetic_args = 0; let get_argument_index = |argument_index: usize, num_synthetic_args: usize| { if argument_index >= num_synthetic_args { @@ -1052,7 +1083,10 @@ impl<'db> Binding<'db> { continue; }; let parameter = ¶meters[parameter_index]; - if let Some(expected_ty) = parameter.annotated_type() { + if let Some(mut expected_ty) = parameter.annotated_type() { + if let Some(specialization) = self.specialization { + expected_ty = expected_ty.apply_specialization(db, specialization); + } if !argument_type.is_assignable_to(db, expected_ty) { let positional = matches!(argument, Argument::Positional | Argument::Synthetic) && !parameter.is_variadic(); @@ -1074,6 +1108,10 @@ impl<'db> Binding<'db> { self.parameter_tys[parameter_index] = Some(union); } } + + if let Some(specialization) = self.specialization { + self.return_ty = self.return_ty.apply_specialization(db, specialization); + } } pub(crate) fn set_return_type(&mut self, return_ty: Type<'db>) { @@ -1084,6 +1122,10 @@ impl<'db> Binding<'db> { self.return_ty } + pub(crate) fn specialization(&self) -> Option> { + self.specialization + } + pub(crate) fn parameter_types(&self) -> &[Option>] { &self.parameter_tys } diff --git a/crates/red_knot_python_semantic/src/types/class.rs b/crates/red_knot_python_semantic/src/types/class.rs index f69149d55b..0ebdefcde8 100644 --- a/crates/red_knot_python_semantic/src/types/class.rs +++ b/crates/red_knot_python_semantic/src/types/class.rs @@ -287,7 +287,7 @@ impl<'db> ClassType<'db> { ) -> SymbolAndQualifiers<'db> { let (class_literal, specialization) = self.class_literal(db); class_literal - .class_member(db, specialization, name, policy) + .class_member_inner(db, specialization, name, policy) .map_type(|ty| self.specialize_type(db, ty)) } @@ -298,9 +298,9 @@ impl<'db> ClassType<'db> { /// directly. Use [`ClassType::class_member`] if you require a method that will /// traverse through the MRO until it finds the member. pub(super) fn own_class_member(self, db: &'db dyn Db, name: &str) -> SymbolAndQualifiers<'db> { - let (class_literal, _) = self.class_literal(db); + let (class_literal, specialization) = self.class_literal(db); class_literal - .own_class_member(db, name) + .own_class_member(db, specialization, name) .map_type(|ty| self.specialize_type(db, ty)) } @@ -378,6 +378,13 @@ impl<'db> ClassLiteralType<'db> { self.class(db).known == Some(known_class) } + pub(crate) fn generic_context(self, db: &'db dyn Db) -> Option> { + match self { + Self::NonGeneric(_) => None, + Self::Generic(generic) => Some(generic.generic_context(db)), + } + } + /// Return `true` if this class represents the builtin class `object` pub(crate) fn is_object(self, db: &'db dyn Db) -> bool { self.is_known(db, KnownClass::Object) @@ -696,6 +703,15 @@ impl<'db> ClassLiteralType<'db> { /// /// TODO: Should this be made private...? pub(super) fn class_member( + self, + db: &'db dyn Db, + name: &str, + policy: MemberLookupPolicy, + ) -> SymbolAndQualifiers<'db> { + self.class_member_inner(db, None, name, policy) + } + + fn class_member_inner( self, db: &'db dyn Db, specialization: Option>, @@ -800,7 +816,12 @@ impl<'db> ClassLiteralType<'db> { /// Returns [`Symbol::Unbound`] if `name` cannot be found in this class's scope /// directly. Use [`ClassLiteralType::class_member`] if you require a method that will /// traverse through the MRO until it finds the member. - pub(super) fn own_class_member(self, db: &'db dyn Db, name: &str) -> SymbolAndQualifiers<'db> { + pub(super) fn own_class_member( + self, + db: &'db dyn Db, + specialization: Option>, + name: &str, + ) -> SymbolAndQualifiers<'db> { if let Some(metadata) = self.dataclass_metadata(db) { if name == "__init__" { if metadata.contains(DataclassMetadata::INIT) { @@ -822,7 +843,30 @@ impl<'db> ClassLiteralType<'db> { } let body_scope = self.body_scope(db); - class_symbol(db, body_scope, name) + class_symbol(db, body_scope, name).map_type(|ty| { + // The `__new__` and `__init__` members of a non-specialized generic class are handled + // specially: they inherit the generic context of their class. That lets us treat them + // as generic functions when constructing the class, and infer the specialization of + // the class from the arguments that are passed in. + // + // We might decide to handle other class methods the same way, having them inherit the + // class's generic context, and performing type inference on calls to them to determine + // the specialization of the class. If we do that, we would update this to also apply + // to any method with a `@classmethod` decorator. (`__init__` would remain a special + // case, since it's an _instance_ method where we don't yet know the generic class's + // specialization.) + match (self, ty, specialization, name) { + ( + ClassLiteralType::Generic(origin), + Type::FunctionLiteral(function), + Some(_), + "__new__" | "__init__", + ) => Type::FunctionLiteral( + function.with_generic_context(db, origin.generic_context(db)), + ), + _ => ty, + } + }) } /// Returns the `name` attribute of an instance of this class. diff --git a/crates/red_knot_python_semantic/src/types/display.rs b/crates/red_knot_python_semantic/src/types/display.rs index 1314fabb6f..5a6da80132 100644 --- a/crates/red_knot_python_semantic/src/types/display.rs +++ b/crates/red_knot_python_semantic/src/types/display.rs @@ -8,11 +8,11 @@ use ruff_python_literal::escape::AsciiEscape; use crate::types::class::{ClassType, GenericAlias, GenericClass}; use crate::types::class_base::ClassBase; -use crate::types::generics::Specialization; +use crate::types::generics::{GenericContext, Specialization}; use crate::types::signatures::{Parameter, Parameters, Signature}; use crate::types::{ InstanceType, IntersectionType, KnownClass, MethodWrapperKind, StringLiteralType, Type, - TypeVarInstance, UnionType, WrapperDescriptorKind, + TypeVarBoundOrConstraints, TypeVarInstance, UnionType, WrapperDescriptorKind, }; use crate::Db; use rustc_hash::FxHashMap; @@ -256,6 +256,52 @@ impl Display for DisplayGenericAlias<'_> { } } +impl<'db> GenericContext<'db> { + pub fn display(&'db self, db: &'db dyn Db) -> DisplayGenericContext<'db> { + DisplayGenericContext { + typevars: self.variables(db), + db, + } + } +} + +pub struct DisplayGenericContext<'db> { + typevars: &'db [TypeVarInstance<'db>], + db: &'db dyn Db, +} + +impl Display for DisplayGenericContext<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.write_char('[')?; + for (idx, var) in self.typevars.iter().enumerate() { + if idx > 0 { + f.write_str(", ")?; + } + write!(f, "{}", var.name(self.db))?; + match var.bound_or_constraints(self.db) { + Some(TypeVarBoundOrConstraints::UpperBound(bound)) => { + write!(f, ": {}", bound.display(self.db))?; + } + Some(TypeVarBoundOrConstraints::Constraints(constraints)) => { + f.write_str(": (")?; + for (idx, constraint) in constraints.iter(self.db).enumerate() { + if idx > 0 { + f.write_str(", ")?; + } + write!(f, "{}", constraint.display(self.db))?; + } + f.write_char(')')?; + } + None => {} + } + if let Some(default_type) = var.default_ty(self.db) { + write!(f, " = {}", default_type.display(self.db))?; + } + } + f.write_char(']') + } +} + impl<'db> Specialization<'db> { /// Renders the specialization in full, e.g. `{T = int, U = str}`. pub fn display(&'db self, db: &'db dyn Db) -> DisplaySpecialization<'db> { diff --git a/crates/red_knot_python_semantic/src/types/generics.rs b/crates/red_knot_python_semantic/src/types/generics.rs index 977808c84b..70def100a5 100644 --- a/crates/red_knot_python_semantic/src/types/generics.rs +++ b/crates/red_knot_python_semantic/src/types/generics.rs @@ -1,14 +1,18 @@ use ruff_python_ast as ast; +use rustc_hash::FxHashMap; use crate::semantic_index::SemanticIndex; use crate::types::signatures::{Parameter, Parameters, Signature}; use crate::types::{ declaration_type, KnownInstanceType, Type, TypeVarBoundOrConstraints, TypeVarInstance, - UnionType, + UnionBuilder, UnionType, }; use crate::Db; /// A list of formal type variables for a generic function, class, or type alias. +/// +/// TODO: Handle nested generic contexts better, with actual parent links to the lexically +/// containing context. #[salsa::tracked(debug)] pub struct GenericContext<'db> { #[return_ref] @@ -85,6 +89,15 @@ impl<'db> GenericContext<'db> { self.specialize(db, types) } + pub(crate) fn identity_specialization(self, db: &'db dyn Db) -> Specialization<'db> { + let types = self + .variables(db) + .iter() + .map(|typevar| Type::TypeVar(*typevar)) + .collect(); + self.specialize(db, types) + } + pub(crate) fn unknown_specialization(self, db: &'db dyn Db) -> Specialization<'db> { let types = vec![Type::unknown(); self.variables(db).len()]; self.specialize(db, types.into()) @@ -100,6 +113,9 @@ impl<'db> GenericContext<'db> { } /// An assignment of a specific type to each type variable in a generic scope. +/// +/// TODO: Handle nested specializations better, with actual parent links to the specialization of +/// the lexically containing context. #[salsa::tracked(debug)] pub struct Specialization<'db> { pub(crate) generic_context: GenericContext<'db>, @@ -130,6 +146,26 @@ impl<'db> Specialization<'db> { Specialization::new(db, self.generic_context(db), types) } + /// Combines two specializations of the same generic context. If either specialization maps a + /// typevar to `Type::Unknown`, the other specialization's mapping is used. If both map the + /// typevar to a known type, those types are unioned together. + /// + /// Panics if the two specializations are not for the same generic context. + pub(crate) fn combine(self, db: &'db dyn Db, other: Self) -> Self { + let generic_context = self.generic_context(db); + assert!(other.generic_context(db) == generic_context); + let types = self + .types(db) + .into_iter() + .zip(other.types(db)) + .map(|(self_type, other_type)| match (self_type, other_type) { + (unknown, known) | (known, unknown) if unknown.is_unknown() => *known, + _ => UnionType::from_elements(db, [self_type, other_type]), + }) + .collect(); + Specialization::new(db, self.generic_context(db), types) + } + pub(crate) fn normalized(self, db: &'db dyn Db) -> Self { let types = self.types(db).iter().map(|ty| ty.normalized(db)).collect(); Self::new(db, self.generic_context(db), types) @@ -146,3 +182,114 @@ impl<'db> Specialization<'db> { .map(|(_, ty)| *ty) } } + +/// Performs type inference between parameter annotations and argument types, producing a +/// specialization of a generic function. +pub(crate) struct SpecializationBuilder<'db> { + db: &'db dyn Db, + generic_context: GenericContext<'db>, + types: FxHashMap, UnionBuilder<'db>>, +} + +impl<'db> SpecializationBuilder<'db> { + pub(crate) fn new(db: &'db dyn Db, generic_context: GenericContext<'db>) -> Self { + Self { + db, + generic_context, + types: FxHashMap::default(), + } + } + + pub(crate) fn build(mut self) -> Specialization<'db> { + let types = self + .generic_context + .variables(self.db) + .iter() + .map(|variable| { + self.types + .remove(variable) + .map(UnionBuilder::build) + .unwrap_or(variable.default_ty(self.db).unwrap_or(Type::unknown())) + }) + .collect(); + Specialization::new(self.db, self.generic_context, types) + } + + fn add_type_mapping(&mut self, typevar: TypeVarInstance<'db>, ty: Type<'db>) { + let builder = self + .types + .entry(typevar) + .or_insert_with(|| UnionBuilder::new(self.db)); + builder.add_in_place(ty); + } + + pub(crate) fn infer(&mut self, formal: Type<'db>, actual: Type<'db>) { + // If the actual type is already assignable to the formal type, then return without adding + // any new type mappings. (Note that if the formal type contains any typevars, this check + // will fail, since no non-typevar types are assignable to a typevar.) + // + // In particular, this handles a case like + // + // ```py + // def f[T](t: T | None): ... + // + // f(None) + // ``` + // + // without specializing `T` to `None`. + if actual.is_assignable_to(self.db, formal) { + return; + } + + match (formal, actual) { + (Type::TypeVar(typevar), _) => self.add_type_mapping(typevar, actual), + + (Type::Tuple(formal_tuple), Type::Tuple(actual_tuple)) => { + let formal_elements = formal_tuple.elements(self.db); + let actual_elements = actual_tuple.elements(self.db); + if formal_elements.len() == actual_elements.len() { + for (formal_element, actual_element) in + formal_elements.iter().zip(actual_elements) + { + self.infer(*formal_element, *actual_element); + } + } + } + + (Type::Union(formal), _) => { + // TODO: We haven't implemented a full unification solver yet. If typevars appear + // in multiple union elements, we ideally want to express that _only one_ of them + // needs to match, and that we should infer the smallest type mapping that allows + // that. + // + // For now, we punt on handling multiple typevar elements. Instead, if _precisely + // one_ union element _is_ a typevar (not _contains_ a typevar), then we go ahead + // and add a mapping between that typevar and the actual type. (Note that we've + // already handled above the case where the actual is assignable to a _non-typevar_ + // union element.) + let mut typevars = formal.iter(self.db).filter_map(|ty| match ty { + Type::TypeVar(typevar) => Some(*typevar), + _ => None, + }); + let typevar = typevars.next(); + let additional_typevars = typevars.next(); + if let (Some(typevar), None) = (typevar, additional_typevars) { + self.add_type_mapping(typevar, actual); + } + } + + (Type::Intersection(formal), _) => { + // The actual type must be assignable to every (positive) element of the + // formal intersection, so we must infer type mappings for each of them. (The + // actual type must also be disjoint from every negative element of the + // intersection, but that doesn't help us infer any type mappings.) + for positive in formal.iter_positive(self.db) { + self.infer(positive, actual); + } + } + + // TODO: Add more forms that we can structurally induct into: type[C], callables + _ => {} + } + } +} diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index b049ff57c0..59248f9f7b 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -82,13 +82,13 @@ use crate::types::mro::MroErrorKind; use crate::types::unpacker::{UnpackResult, Unpacker}; use crate::types::{ todo_type, CallDunderError, CallableSignature, CallableType, Class, ClassLiteralType, - DataclassMetadata, DynamicType, FunctionDecorators, FunctionType, GenericAlias, GenericClass, - IntersectionBuilder, IntersectionType, KnownClass, KnownFunction, KnownInstanceType, - MemberLookupPolicy, MetaclassCandidate, NonGenericClass, Parameter, ParameterForm, Parameters, - Signature, Signatures, SliceLiteralType, StringLiteralType, SubclassOfType, Symbol, - SymbolAndQualifiers, Truthiness, TupleType, Type, TypeAliasType, TypeAndQualifiers, - TypeArrayDisplay, TypeQualifiers, TypeVarBoundOrConstraints, TypeVarInstance, UnionBuilder, - UnionType, + ClassType, DataclassMetadata, DynamicType, FunctionDecorators, FunctionType, GenericAlias, + GenericClass, IntersectionBuilder, IntersectionType, KnownClass, KnownFunction, + KnownInstanceType, MemberLookupPolicy, MetaclassCandidate, NonGenericClass, Parameter, + ParameterForm, Parameters, Signature, Signatures, SliceLiteralType, StringLiteralType, + SubclassOfType, Symbol, SymbolAndQualifiers, Truthiness, TupleType, Type, TypeAliasType, + TypeAndQualifiers, TypeArrayDisplay, TypeQualifiers, TypeVarBoundOrConstraints, + TypeVarInstance, UnionBuilder, UnionType, }; use crate::unpack::{Unpack, UnpackPosition}; use crate::util::subscript::{PyIndex, PySlice}; @@ -1478,6 +1478,10 @@ impl<'db> TypeInferenceBuilder<'db> { } } + let generic_context = type_params.as_ref().map(|type_params| { + GenericContext::from_type_params(self.db(), self.index, type_params) + }); + let function_kind = KnownFunction::try_from_definition_and_name(self.db(), definition, name); @@ -1494,6 +1498,7 @@ impl<'db> TypeInferenceBuilder<'db> { function_kind, body_scope, function_decorators, + generic_context, specialization, )); @@ -2582,14 +2587,15 @@ impl<'db> TypeInferenceBuilder<'db> { let result = object_ty.try_call_dunder_with_policy( db, "__setattr__", - CallArgumentTypes::positional([ + &mut CallArgumentTypes::positional([ Type::StringLiteral(StringLiteralType::new( db, Box::from(attribute), )), value_ty, ]), - MemberLookupPolicy::MRO_NO_OBJECT_FALLBACK, + MemberLookupPolicy::NO_INSTANCE_FALLBACK + | MemberLookupPolicy::MRO_NO_OBJECT_FALLBACK, ); match result { @@ -4182,27 +4188,27 @@ impl<'db> TypeInferenceBuilder<'db> { let callable_type = self.infer_expression(func); // For class literals we model the entire class instantiation logic, so it is handled - // in a separate function. - let class = match callable_type { - Type::SubclassOf(subclass_of_type) => match subclass_of_type.subclass_of() { - ClassBase::Dynamic(_) => None, - ClassBase::Class(class) => { - let (class_literal, _) = class.class_literal(self.db()); - Some(class_literal) - } - }, - Type::ClassLiteral(class) => Some(class), - _ => None, + // in a separate function. For some known classes we have manual signatures defined and use + // the `try_call` path below. + // TODO: it should be possible to move these special cases into the `try_call_constructor` + // path instead, or even remove some entirely once we support overloads fully. + let (call_constructor, known_class) = match callable_type { + Type::ClassLiteral(class) => (true, class.known(self.db())), + Type::GenericAlias(generic) => (true, ClassType::Generic(generic).known(self.db())), + Type::SubclassOf(subclass) => ( + true, + subclass + .subclass_of() + .into_class() + .and_then(|class| class.known(self.db())), + ), + _ => (false, None), }; - if class.is_some_and(|class| { - // For some known classes we have manual signatures defined and use the `try_call` path - // below. TODO: it should be possible to move these special cases into the - // `try_call_constructor` path instead, or even remove some entirely once we support - // overloads fully. - class.known(self.db()).is_none_or(|class| { - !matches!( - class, + if call_constructor + && !matches!( + known_class, + Some( KnownClass::Bool | KnownClass::Str | KnownClass::Type @@ -4210,8 +4216,8 @@ impl<'db> TypeInferenceBuilder<'db> { | KnownClass::Property | KnownClass::Super ) - }) - }) { + ) + { let argument_forms = vec![Some(ParameterForm::Value); call_arguments.len()]; let call_argument_types = self.infer_argument_types(arguments, call_arguments, &argument_forms); diff --git a/crates/red_knot_python_semantic/src/types/signatures.rs b/crates/red_knot_python_semantic/src/types/signatures.rs index b2d6bbf783..50c022e8d6 100644 --- a/crates/red_knot_python_semantic/src/types/signatures.rs +++ b/crates/red_knot_python_semantic/src/types/signatures.rs @@ -17,7 +17,7 @@ use smallvec::{smallvec, SmallVec}; use super::{definition_expression_type, DynamicType, Type}; use crate::semantic_index::definition::Definition; -use crate::types::generics::Specialization; +use crate::types::generics::{GenericContext, Specialization}; use crate::types::todo_type; use crate::Db; use ruff_python_ast::{self as ast, name::Name}; @@ -165,6 +165,7 @@ impl<'db> CallableSignature<'db> { /// Return a signature for a dynamic callable pub(crate) fn dynamic(signature_type: Type<'db>) -> Self { let signature = Signature { + generic_context: None, parameters: Parameters::gradual_form(), return_ty: Some(signature_type), }; @@ -176,6 +177,7 @@ impl<'db> CallableSignature<'db> { pub(crate) fn todo(reason: &'static str) -> Self { let signature_type = todo_type!(reason); let signature = Signature { + generic_context: None, parameters: Parameters::todo(), return_ty: Some(signature_type), }; @@ -210,6 +212,9 @@ impl<'a, 'db> IntoIterator for &'a CallableSignature<'db> { /// The signature of one of the overloads of a callable. #[derive(Clone, Debug, PartialEq, Eq, Hash, salsa::Update)] pub struct Signature<'db> { + /// The generic context for this overload, if it is generic. + pub(crate) generic_context: Option>, + /// Parameters, in source order. /// /// The ordering of parameters in a valid signature must be: first positional-only parameters, @@ -227,6 +232,19 @@ pub struct Signature<'db> { impl<'db> Signature<'db> { pub(crate) fn new(parameters: Parameters<'db>, return_ty: Option>) -> Self { Self { + generic_context: None, + parameters, + return_ty, + } + } + + pub(crate) fn new_generic( + generic_context: Option>, + parameters: Parameters<'db>, + return_ty: Option>, + ) -> Self { + Self { + generic_context, parameters, return_ty, } @@ -236,6 +254,7 @@ impl<'db> Signature<'db> { #[allow(unused_variables)] // 'reason' only unused in debug builds pub(crate) fn todo(reason: &'static str) -> Self { Signature { + generic_context: None, parameters: Parameters::todo(), return_ty: Some(todo_type!(reason)), } @@ -244,6 +263,7 @@ impl<'db> Signature<'db> { /// Return a typed signature from a function definition. pub(super) fn from_function( db: &'db dyn Db, + generic_context: Option>, definition: Definition<'db>, function_node: &ast::StmtFunctionDef, ) -> Self { @@ -256,6 +276,7 @@ impl<'db> Signature<'db> { }); Self { + generic_context, parameters: Parameters::from_parameters( db, definition, @@ -283,6 +304,7 @@ impl<'db> Signature<'db> { pub(crate) fn bind_self(&self) -> Self { Self { + generic_context: self.generic_context, parameters: Parameters::new(self.parameters().iter().skip(1).cloned()), return_ty: self.return_ty, } diff --git a/crates/red_knot_python_semantic/src/types/slots.rs b/crates/red_knot_python_semantic/src/types/slots.rs index 2b265372ad..b86f2e2fa3 100644 --- a/crates/red_knot_python_semantic/src/types/slots.rs +++ b/crates/red_knot_python_semantic/src/types/slots.rs @@ -24,7 +24,8 @@ enum SlotsKind { impl SlotsKind { fn from(db: &dyn Db, base: ClassLiteralType) -> Self { - let Symbol::Type(slots_ty, bound) = base.own_class_member(db, "__slots__").symbol else { + let Symbol::Type(slots_ty, bound) = base.own_class_member(db, None, "__slots__").symbol + else { return Self::NotSpecified; };