From d6dcc377f768c7a6fbe55206ca32c50223d67d59 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Tue, 1 Apr 2025 19:30:06 +0100 Subject: [PATCH] [red-knot] Flatten `Type::Callable` into four `Type` variants (#17126) ## Summary Currently our `Type::Callable` wraps a four-variant `CallableType` enum. But as time has gone on, I think we've found that the four variants in `CallableType` are really more different to each other than they are similar to each other: - `GeneralCallableType` is a structural type describing all callable types with a certain signature, but the other three types are "literal types", more similar to the `FunctionLiteral` variant - `GeneralCallableType` is not a singleton or a single-valued type, but the other three are all single-valued types (`WrapperDescriptorDunderGet` is even a singleton type) - `GeneralCallableType` has (or should have) ambiguous truthiness, but all possible inhabitants of the other three types are always truthy. - As a structural type, `GeneralCallableType` can contain inner unions and intersections that must be sorted in some contexts in our internal model, but this is not true for the other three variants. This PR flattens `Type::Callable` into four distinct `Type::` variants. In the process, it fixes a number of latent bugs that were concealed by the current architecture but are laid bare by the refactor. Unit tests for these bugs are included in the PR. --- .../type_properties/is_equivalent_to.md | 11 + .../type_properties/is_single_valued.md | 10 +- .../mdtest/type_properties/is_singleton.md | 26 ++ .../mdtest/type_properties/truthiness.md | 23 + crates/red_knot_python_semantic/src/types.rs | 411 +++++++++--------- .../src/types/call/bind.rs | 76 ++-- .../src/types/class_base.rs | 3 + .../src/types/display.rs | 18 +- .../src/types/infer.rs | 29 +- .../types/property_tests/type_generation.rs | 8 +- .../src/types/signatures.rs | 33 ++ .../src/types/type_ordering.rs | 39 +- 12 files changed, 384 insertions(+), 303 deletions(-) diff --git a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md index 9a7538f0f5..b2606cf434 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md +++ b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md @@ -240,4 +240,15 @@ static_assert(not is_equivalent_to(CallableTypeOf[f12], CallableTypeOf[f13])) static_assert(not is_equivalent_to(CallableTypeOf[f13], CallableTypeOf[f12])) ``` +### Unions containing `Callable`s containing unions + +Differently ordered unions inside `Callable`s inside unions can still be equivalent: + +```py +from typing import Callable +from knot_extensions import is_equivalent_to, static_assert + +static_assert(is_equivalent_to(int | Callable[[int | str], None], Callable[[str | int], None] | int)) +``` + [the equivalence relation]: https://typing.readthedocs.io/en/latest/spec/glossary.html#term-equivalent diff --git a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_single_valued.md b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_single_valued.md index 645eade036..b6339cc6b0 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_single_valued.md +++ b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_single_valued.md @@ -3,8 +3,9 @@ A type is single-valued iff it is not empty and all inhabitants of it compare equal. ```py +import types from typing_extensions import Any, Literal, LiteralString, Never, Callable -from knot_extensions import is_single_valued, static_assert +from knot_extensions import is_single_valued, static_assert, TypeOf static_assert(is_single_valued(None)) static_assert(is_single_valued(Literal[True])) @@ -25,4 +26,11 @@ static_assert(not is_single_valued(tuple[None, int])) static_assert(not is_single_valued(Callable[..., None])) static_assert(not is_single_valued(Callable[[int, str], None])) + +class A: + def method(self): ... + +static_assert(is_single_valued(TypeOf[A().method])) +static_assert(is_single_valued(TypeOf[types.FunctionType.__get__])) +static_assert(is_single_valued(TypeOf[A.method.__get__])) ``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_singleton.md b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_singleton.md index b8c20badfa..a60efdfd9e 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_singleton.md +++ b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_singleton.md @@ -133,3 +133,29 @@ from knot_extensions import static_assert, is_singleton reveal_type(types.NotImplementedType) # revealed: Unknown | Literal[_NotImplementedType] static_assert(not is_singleton(types.NotImplementedType)) ``` + +### Callables + +We currently treat the type of `types.FunctionType.__get__` as a singleton type that has its own +dedicated variant in the `Type` enum. That variant should be understood as a singleton type, but the +similar variants `Type::BoundMethod` and `Type::MethodWrapperDunderGet` should not be; nor should +`Type::Callable` types. + +If we refactor `Type` in the future to get rid of some or all of these `Type` variants, the +assertion that the type of `types.FunctionType.__get__` is a singleton type does not necessarily +have to hold true; it's more of a unit test for our current implementation. + +```py +import types +from typing import Callable +from knot_extensions import static_assert, is_singleton, TypeOf + +class A: + def method(self): ... + +static_assert(is_singleton(TypeOf[types.FunctionType.__get__])) + +static_assert(not is_singleton(Callable[[], None])) +static_assert(not is_singleton(TypeOf[A().method])) +static_assert(not is_singleton(TypeOf[A.method.__get__])) +``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/type_properties/truthiness.md b/crates/red_knot_python_semantic/resources/mdtest/type_properties/truthiness.md index f9e741d33f..00e7c09539 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/type_properties/truthiness.md +++ b/crates/red_knot_python_semantic/resources/mdtest/type_properties/truthiness.md @@ -120,3 +120,26 @@ static_assert(is_subtype_of(typing.TypeAliasType, AlwaysTruthy)) static_assert(is_subtype_of(types.MethodWrapperType, AlwaysTruthy)) static_assert(is_subtype_of(types.WrapperDescriptorType, AlwaysTruthy)) ``` + +### `Callable` types always have ambiguous truthiness + +```py +from typing import Callable + +def f(x: Callable, y: Callable[[int], str]): + reveal_type(bool(x)) # revealed: bool + reveal_type(bool(y)) # revealed: bool +``` + +But certain callable single-valued types are known to be always truthy: + +```py +from types import FunctionType + +class A: + def method(self): ... + +reveal_type(bool(A().method)) # revealed: Literal[True] +reveal_type(bool(f.__get__)) # revealed: Literal[True] +reveal_type(bool(FunctionType.__get__)) # revealed: Literal[True] +``` diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index d76923592f..f81d92a6e6 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -236,7 +236,35 @@ pub enum Type<'db> { Never, /// A specific function object FunctionLiteral(FunctionType<'db>), - /// A callable object + /// Represents a callable `instance.method` where `instance` is an instance of a class + /// and `method` is a method (of that class). + /// + /// See [`BoundMethodType`] for more information. + /// + /// TODO: consider replacing this with `Callable & Instance(MethodType)`? + /// I.e. if we have a method `def f(self, x: int) -> str`, and see it being called as + /// `instance.f`, we could partially apply (and check) the `instance` argument against + /// the `self` parameter, and return a `MethodType & Callable[[int], str]`. + /// One drawback would be that we could not show the bound instance when that type is displayed. + BoundMethod(BoundMethodType<'db>), + /// Represents the callable `f.__get__` where `f` is a function. + /// + /// TODO: consider replacing this with `Callable & types.MethodWrapperType` type? + /// Requires `Callable` to be able to represent overloads, e.g. `types.FunctionType.__get__` has + /// this behaviour when a method is accessed on a class vs an instance: + /// + /// ```txt + /// * (None, type) -> Literal[function_on_which_it_was_called] + /// * (object, type | None) -> BoundMethod[instance, function_on_which_it_was_called] + /// ``` + MethodWrapperDunderGet(FunctionType<'db>), + /// Represents the callable `FunctionType.__get__`. + /// + /// TODO: Similar to above, this could eventually be replaced by a generic `Callable` + /// type. We currently add this as a separate variant because `FunctionType.__get__` + /// is an overloaded method and we do not support `@overload` yet. + WrapperDescriptorDunderGet, + /// The type of an arbitrary callable object with a certain specified signature. Callable(CallableType<'db>), /// A specific module object ModuleLiteral(ModuleLiteralType<'db>), @@ -339,13 +367,11 @@ impl<'db> Type<'db> { | Self::LiteralString | Self::SliceLiteral(_) | Self::Dynamic(DynamicType::Unknown | DynamicType::Any) - | Self::Callable( - CallableType::BoundMethod(_) - | CallableType::WrapperDescriptorDunderGet - | CallableType::MethodWrapperDunderGet(_), - ) => false, + | Self::BoundMethod(_) + | Self::WrapperDescriptorDunderGet + | Self::MethodWrapperDunderGet(_) => false, - Self::Callable(CallableType::General(callable)) => { + Self::Callable(callable) => { let signature = callable.signature(db); signature.parameters().iter().any(|param| { param @@ -565,6 +591,9 @@ impl<'db> Type<'db> { Type::Intersection(intersection.to_sorted_intersection(db)) } Type::Tuple(tuple) => Type::Tuple(tuple.with_sorted_unions_and_intersections(db)), + Type::Callable(callable) => { + Type::Callable(callable.with_sorted_unions_and_intersections(db)) + } Type::LiteralString | Type::Instance(_) | Type::AlwaysFalsy @@ -576,7 +605,9 @@ impl<'db> Type<'db> { | Type::Dynamic(_) | Type::Never | Type::FunctionLiteral(_) - | Type::Callable(_) + | Type::MethodWrapperDunderGet(_) + | Type::BoundMethod(_) + | Type::WrapperDescriptorDunderGet | Type::ModuleLiteral(_) | Type::ClassLiteral(_) | Type::KnownInstance(_) @@ -703,27 +734,22 @@ impl<'db> Type<'db> { .is_subtype_of(db, target), // The same reasoning applies for these special callable types: - (Type::Callable(CallableType::BoundMethod(_)), _) => KnownClass::MethodType + (Type::BoundMethod(_), _) => KnownClass::MethodType .to_instance(db) .is_subtype_of(db, target), - (Type::Callable(CallableType::MethodWrapperDunderGet(_)), _) => { - KnownClass::WrapperDescriptorType - .to_instance(db) - .is_subtype_of(db, target) - } - (Type::Callable(CallableType::WrapperDescriptorDunderGet), _) => { - KnownClass::WrapperDescriptorType - .to_instance(db) - .is_subtype_of(db, target) + (Type::MethodWrapperDunderGet(_), _) => KnownClass::WrapperDescriptorType + .to_instance(db) + .is_subtype_of(db, target), + (Type::WrapperDescriptorDunderGet, _) => KnownClass::WrapperDescriptorType + .to_instance(db) + .is_subtype_of(db, target), + + (Type::Callable(self_callable), Type::Callable(other_callable)) => { + self_callable.is_subtype_of(db, other_callable) } - ( - Type::Callable(CallableType::General(self_callable)), - Type::Callable(CallableType::General(other_callable)), - ) => self_callable.is_subtype_of(db, other_callable), - - (Type::Callable(CallableType::General(_)), _) => { - // TODO: Implement subtyping between general callable types and other types like + (Type::Callable(_), _) => { + // TODO: Implement subtyping between callable types and other types like // function literals, bound methods, class literals, `type[]`, etc.) false } @@ -961,10 +987,9 @@ impl<'db> Type<'db> { ) } - ( - Type::Callable(CallableType::General(self_callable)), - Type::Callable(CallableType::General(target_callable)), - ) => self_callable.is_assignable_to(db, target_callable), + (Type::Callable(self_callable), Type::Callable(target_callable)) => { + self_callable.is_assignable_to(db, target_callable) + } (Type::FunctionLiteral(self_function_literal), Type::Callable(_)) => { self_function_literal @@ -991,10 +1016,7 @@ impl<'db> Type<'db> { left.is_equivalent_to(db, right) } (Type::Tuple(left), Type::Tuple(right)) => left.is_equivalent_to(db, right), - ( - Type::Callable(CallableType::General(left)), - Type::Callable(CallableType::General(right)), - ) => left.is_equivalent_to(db, right), + (Type::Callable(left), Type::Callable(right)) => left.is_equivalent_to(db, right), _ => self == other && self.is_fully_static(db) && other.is_fully_static(db), } } @@ -1053,10 +1075,9 @@ impl<'db> Type<'db> { first.is_gradual_equivalent_to(db, second) } - ( - Type::Callable(CallableType::General(first)), - Type::Callable(CallableType::General(second)), - ) => first.is_gradual_equivalent_to(db, second), + (Type::Callable(first), Type::Callable(second)) => { + first.is_gradual_equivalent_to(db, second) + } _ => false, } @@ -1113,11 +1134,9 @@ impl<'db> Type<'db> { | Type::BytesLiteral(..) | Type::SliceLiteral(..) | Type::FunctionLiteral(..) - | Type::Callable( - CallableType::BoundMethod(..) - | CallableType::MethodWrapperDunderGet(..) - | CallableType::WrapperDescriptorDunderGet, - ) + | Type::BoundMethod(..) + | Type::MethodWrapperDunderGet(..) + | Type::WrapperDescriptorDunderGet | Type::ModuleLiteral(..) | Type::ClassLiteral(..) | Type::KnownInstance(..)), @@ -1127,11 +1146,9 @@ impl<'db> Type<'db> { | Type::BytesLiteral(..) | Type::SliceLiteral(..) | Type::FunctionLiteral(..) - | Type::Callable( - CallableType::BoundMethod(..) - | CallableType::MethodWrapperDunderGet(..) - | CallableType::WrapperDescriptorDunderGet, - ) + | Type::BoundMethod(..) + | Type::MethodWrapperDunderGet(..) + | Type::WrapperDescriptorDunderGet | Type::ModuleLiteral(..) | Type::ClassLiteral(..) | Type::KnownInstance(..)), @@ -1146,7 +1163,9 @@ impl<'db> Type<'db> { | Type::BooleanLiteral(..) | Type::BytesLiteral(..) | Type::FunctionLiteral(..) - | Type::Callable(..) + | Type::BoundMethod(..) + | Type::MethodWrapperDunderGet(..) + | Type::WrapperDescriptorDunderGet | Type::IntLiteral(..) | Type::SliceLiteral(..) | Type::StringLiteral(..) @@ -1158,7 +1177,9 @@ impl<'db> Type<'db> { | Type::BooleanLiteral(..) | Type::BytesLiteral(..) | Type::FunctionLiteral(..) - | Type::Callable(..) + | Type::BoundMethod(..) + | Type::MethodWrapperDunderGet(..) + | Type::WrapperDescriptorDunderGet | Type::IntLiteral(..) | Type::SliceLiteral(..) | Type::StringLiteral(..) @@ -1187,7 +1208,9 @@ impl<'db> Type<'db> { | Type::BytesLiteral(..) | Type::SliceLiteral(..) | Type::FunctionLiteral(..) - | Type::Callable(..) + | Type::BoundMethod(..) + | Type::MethodWrapperDunderGet(..) + | Type::WrapperDescriptorDunderGet | Type::ModuleLiteral(..), ) | ( @@ -1198,7 +1221,9 @@ impl<'db> Type<'db> { | Type::BytesLiteral(..) | Type::SliceLiteral(..) | Type::FunctionLiteral(..) - | Type::Callable(..) + | Type::BoundMethod(..) + | Type::MethodWrapperDunderGet(..) + | Type::WrapperDescriptorDunderGet | Type::ModuleLiteral(..), Type::SubclassOf(_), ) => true, @@ -1303,32 +1328,20 @@ impl<'db> Type<'db> { !KnownClass::FunctionType.is_subclass_of(db, class) } - ( - Type::Callable(CallableType::BoundMethod(_)), - Type::Instance(InstanceType { class }), - ) - | ( - Type::Instance(InstanceType { class }), - Type::Callable(CallableType::BoundMethod(_)), - ) => !KnownClass::MethodType.is_subclass_of(db, class), + (Type::BoundMethod(_), other) | (other, Type::BoundMethod(_)) => KnownClass::MethodType + .to_instance(db) + .is_disjoint_from(db, other), - ( - Type::Callable(CallableType::MethodWrapperDunderGet(_)), - Type::Instance(InstanceType { class }), - ) - | ( - Type::Instance(InstanceType { class }), - Type::Callable(CallableType::MethodWrapperDunderGet(_)), - ) => !KnownClass::MethodWrapperType.is_subclass_of(db, class), + (Type::MethodWrapperDunderGet(_), other) | (other, Type::MethodWrapperDunderGet(_)) => { + KnownClass::MethodWrapperType + .to_instance(db) + .is_disjoint_from(db, other) + } - ( - Type::Callable(CallableType::WrapperDescriptorDunderGet), - Type::Instance(InstanceType { class }), - ) - | ( - Type::Instance(InstanceType { class }), - Type::Callable(CallableType::WrapperDescriptorDunderGet), - ) => !KnownClass::WrapperDescriptorType.is_subclass_of(db, class), + (Type::WrapperDescriptorDunderGet, other) + | (other, Type::WrapperDescriptorDunderGet) => KnownClass::WrapperDescriptorType + .to_instance(db) + .is_disjoint_from(db, other), (Type::ModuleLiteral(..), other @ Type::Instance(..)) | (other @ Type::Instance(..), Type::ModuleLiteral(..)) => { @@ -1367,9 +1380,8 @@ impl<'db> Type<'db> { instance.is_disjoint_from(db, KnownClass::Tuple.to_instance(db)) } - (Type::Callable(CallableType::General(_)), _) - | (_, Type::Callable(CallableType::General(_))) => { - // TODO: Implement disjointedness for general callable types + (Type::Callable(_), _) | (_, Type::Callable(_)) => { + // TODO: Implement disjointedness for callable types false } } @@ -1381,11 +1393,9 @@ impl<'db> Type<'db> { Type::Dynamic(_) => false, Type::Never | Type::FunctionLiteral(..) - | Type::Callable( - CallableType::BoundMethod(_) - | CallableType::MethodWrapperDunderGet(_) - | CallableType::WrapperDescriptorDunderGet, - ) + | Type::BoundMethod(_) + | Type::WrapperDescriptorDunderGet + | Type::MethodWrapperDunderGet(_) | Type::ModuleLiteral(..) | Type::IntLiteral(_) | Type::BooleanLiteral(_) @@ -1425,7 +1435,7 @@ impl<'db> Type<'db> { .elements(db) .iter() .all(|elem| elem.is_fully_static(db)), - Type::Callable(CallableType::General(callable)) => callable.is_fully_static(db), + Type::Callable(callable) => callable.is_fully_static(db), } } @@ -1451,20 +1461,32 @@ impl<'db> Type<'db> { Type::SubclassOf(..) => false, Type::BooleanLiteral(_) | Type::FunctionLiteral(..) - | Type::Callable( - CallableType::BoundMethod(_) - | CallableType::MethodWrapperDunderGet(_) - | CallableType::WrapperDescriptorDunderGet, - ) + | Type::WrapperDescriptorDunderGet | Type::ClassLiteral(..) | Type::ModuleLiteral(..) | Type::KnownInstance(..) => true, - Type::Callable(CallableType::General(_)) => { - // A general callable type is never a singleton because for any given signature, + Type::Callable(_) => { + // A callable type is never a singleton because for any given signature, // there could be any number of distinct objects that are all callable with that // signature. false } + Type::BoundMethod(..) => { + // `BoundMethod` types are single-valued types, but not singleton types: + // ```pycon + // >>> class Foo: + // ... def bar(self): pass + // >>> f = Foo() + // >>> f.bar is f.bar + // False + // ``` + false + } + Type::MethodWrapperDunderGet(_) => { + // Just a special case of `BoundMethod` really + // (this variant represents `f.__get__`, where `f` is any function) + false + } Type::Instance(InstanceType { class }) => { class.known(db).is_some_and(KnownClass::is_singleton) } @@ -1500,11 +1522,9 @@ impl<'db> Type<'db> { pub(crate) fn is_single_valued(self, db: &'db dyn Db) -> bool { match self { Type::FunctionLiteral(..) - | Type::Callable( - CallableType::BoundMethod(..) - | CallableType::MethodWrapperDunderGet(..) - | CallableType::WrapperDescriptorDunderGet, - ) + | Type::BoundMethod(_) + | Type::WrapperDescriptorDunderGet + | Type::MethodWrapperDunderGet(_) | Type::ModuleLiteral(..) | Type::ClassLiteral(..) | Type::IntLiteral(..) @@ -1535,7 +1555,7 @@ impl<'db> Type<'db> { | Type::LiteralString | Type::AlwaysTruthy | Type::AlwaysFalsy - | Type::Callable(CallableType::General(_)) => false, + | Type::Callable(_) => false, } } @@ -1568,10 +1588,9 @@ impl<'db> Type<'db> { Type::ClassLiteral(class_literal @ ClassLiteralType { class }) => { match (class.known(db), name) { - (Some(KnownClass::FunctionType), "__get__") => Some( - Symbol::bound(Type::Callable(CallableType::WrapperDescriptorDunderGet)) - .into(), - ), + (Some(KnownClass::FunctionType), "__get__") => { + Some(Symbol::bound(Type::WrapperDescriptorDunderGet).into()) + } (Some(KnownClass::FunctionType), "__set__" | "__delete__") => { // Hard code this knowledge, as we look up `__set__` and `__delete__` on `FunctionType` often. Some(Symbol::Unbound.into()) @@ -1632,6 +1651,9 @@ impl<'db> Type<'db> { Type::FunctionLiteral(_) | Type::Callable(_) + | Type::BoundMethod(_) + | Type::WrapperDescriptorDunderGet + | Type::MethodWrapperDunderGet(_) | Type::ModuleLiteral(_) | Type::KnownInstance(_) | Type::AlwaysTruthy @@ -1702,22 +1724,16 @@ impl<'db> Type<'db> { .to_instance(db) .instance_member(db, name), - Type::Callable(CallableType::BoundMethod(_)) => KnownClass::MethodType + Type::BoundMethod(_) => KnownClass::MethodType .to_instance(db) .instance_member(db, name), - Type::Callable(CallableType::MethodWrapperDunderGet(_)) => { - KnownClass::MethodWrapperType - .to_instance(db) - .instance_member(db, name) - } - Type::Callable(CallableType::WrapperDescriptorDunderGet) => { - KnownClass::WrapperDescriptorType - .to_instance(db) - .instance_member(db, name) - } - Type::Callable(CallableType::General(_)) => { - KnownClass::Object.to_instance(db).instance_member(db, name) - } + Type::MethodWrapperDunderGet(_) => KnownClass::MethodWrapperType + .to_instance(db) + .instance_member(db, name), + Type::WrapperDescriptorDunderGet => KnownClass::WrapperDescriptorType + .to_instance(db) + .instance_member(db, name), + Type::Callable(_) => KnownClass::Object.to_instance(db).instance_member(db, name), Type::IntLiteral(_) => KnownClass::Int.to_instance(db).instance_member(db, name), Type::BooleanLiteral(_) => KnownClass::Bool.to_instance(db).instance_member(db, name), @@ -2038,18 +2054,17 @@ impl<'db> Type<'db> { Type::Dynamic(..) | Type::Never => Symbol::bound(self).into(), - Type::FunctionLiteral(function) if name == "__get__" => Symbol::bound(Type::Callable( - CallableType::MethodWrapperDunderGet(function), - )) - .into(), + Type::FunctionLiteral(function) if name == "__get__" => { + Symbol::bound(Type::MethodWrapperDunderGet(function)).into() + } Type::ClassLiteral(ClassLiteralType { class }) if name == "__get__" && class.is_known(db, KnownClass::FunctionType) => { - Symbol::bound(Type::Callable(CallableType::WrapperDescriptorDunderGet)).into() + Symbol::bound(Type::WrapperDescriptorDunderGet).into() } - Type::Callable(CallableType::BoundMethod(bound_method)) => match name_str { + Type::BoundMethod(bound_method) => match name_str { "__self__" => Symbol::bound(bound_method.self_instance(db)).into(), "__func__" => { Symbol::bound(Type::FunctionLiteral(bound_method.function(db))).into() @@ -2065,19 +2080,13 @@ impl<'db> Type<'db> { }) } }, - Type::Callable(CallableType::MethodWrapperDunderGet(_)) => { - KnownClass::MethodWrapperType - .to_instance(db) - .member(db, &name) - } - Type::Callable(CallableType::WrapperDescriptorDunderGet) => { - KnownClass::WrapperDescriptorType - .to_instance(db) - .member(db, &name) - } - Type::Callable(CallableType::General(_)) => { - KnownClass::Object.to_instance(db).member(db, &name) - } + Type::MethodWrapperDunderGet(_) => KnownClass::MethodWrapperType + .to_instance(db) + .member(db, &name), + Type::WrapperDescriptorDunderGet => KnownClass::WrapperDescriptorType + .to_instance(db) + .member(db, &name), + Type::Callable(_) => KnownClass::Object.to_instance(db).member(db, &name), Type::Instance(InstanceType { class }) if matches!(name.as_str(), "major" | "minor") @@ -2243,21 +2252,31 @@ impl<'db> Type<'db> { allow_short_circuit: bool, ) -> Result> { let truthiness = match self { - Type::Dynamic(_) | Type::Never => Truthiness::Ambiguous, - Type::FunctionLiteral(_) => Truthiness::AlwaysTrue, - Type::Callable(_) => Truthiness::AlwaysTrue, - Type::ModuleLiteral(_) => Truthiness::AlwaysTrue, + Type::Dynamic(_) | Type::Never | Type::Callable(_) | Type::LiteralString => { + Truthiness::Ambiguous + } + + Type::FunctionLiteral(_) + | Type::BoundMethod(_) + | Type::WrapperDescriptorDunderGet + | Type::MethodWrapperDunderGet(_) + | Type::ModuleLiteral(_) + | Type::SliceLiteral(_) + | Type::AlwaysTruthy => Truthiness::AlwaysTrue, + + Type::AlwaysFalsy => Truthiness::AlwaysFalse, + Type::ClassLiteral(ClassLiteralType { class }) => class .metaclass_instance_type(db) .try_bool_impl(db, allow_short_circuit)?, + Type::SubclassOf(subclass_of_ty) => match subclass_of_ty.subclass_of() { ClassBase::Dynamic(_) => Truthiness::Ambiguous, ClassBase::Class(class) => { Type::class_literal(class).try_bool_impl(db, allow_short_circuit)? } }, - Type::AlwaysTruthy => Truthiness::AlwaysTrue, - Type::AlwaysFalsy => Truthiness::AlwaysFalse, + instance_ty @ Type::Instance(InstanceType { class }) => match class.known(db) { Some(known_class) => known_class.bool(), None => { @@ -2322,7 +2341,9 @@ impl<'db> Type<'db> { } } }, + Type::KnownInstance(known_instance) => known_instance.bool(), + Type::Union(union) => { let mut truthiness = None; let mut all_not_callable = true; @@ -2362,16 +2383,16 @@ impl<'db> Type<'db> { } truthiness.unwrap_or(Truthiness::Ambiguous) } + Type::Intersection(_) => { // TODO Truthiness::Ambiguous } + Type::IntLiteral(num) => Truthiness::from(*num != 0), Type::BooleanLiteral(bool) => Truthiness::from(*bool), Type::StringLiteral(str) => Truthiness::from(!str.value(db).is_empty()), - Type::LiteralString => Truthiness::Ambiguous, Type::BytesLiteral(bytes) => Truthiness::from(!bytes.value(db).is_empty()), - Type::SliceLiteral(_) => Truthiness::AlwaysTrue, Type::Tuple(items) => Truthiness::from(!items.elements(db).is_empty()), }; @@ -2434,18 +2455,19 @@ impl<'db> Type<'db> { /// [`CallErrorKind::NotCallable`]. fn signatures(self, db: &'db dyn Db) -> Signatures<'db> { match self { - Type::Callable(CallableType::General(callable)) => Signatures::single( - CallableSignature::single(self, callable.signature(db).clone()), - ), + Type::Callable(callable) => Signatures::single(CallableSignature::single( + self, + callable.signature(db).clone(), + )), - Type::Callable(CallableType::BoundMethod(bound_method)) => { + Type::BoundMethod(bound_method) => { let signature = bound_method.function(db).signature(db); let signature = CallableSignature::single(self, signature.clone()) .with_bound_type(bound_method.self_instance(db)); Signatures::single(signature) } - Type::Callable(CallableType::MethodWrapperDunderGet(_)) => { + Type::MethodWrapperDunderGet(_) => { // Here, we dynamically model the overloaded function signature of `types.FunctionType.__get__`. // This is required because we need to return more precise types than what the signature in // typeshed provides: @@ -2490,7 +2512,7 @@ impl<'db> Type<'db> { Signatures::single(signature) } - Type::Callable(CallableType::WrapperDescriptorDunderGet) => { + Type::WrapperDescriptorDunderGet => { // Here, we also model `types.FunctionType.__get__`, but now we consider a call to // this as a function, i.e. we also expect the `self` argument to be passed in. @@ -2974,6 +2996,9 @@ impl<'db> Type<'db> { | Type::BytesLiteral(_) | Type::FunctionLiteral(_) | Type::Callable(..) + | Type::MethodWrapperDunderGet(_) + | Type::BoundMethod(_) + | Type::WrapperDescriptorDunderGet | Type::Instance(_) | Type::KnownInstance(_) | Type::ModuleLiteral(_) @@ -3035,6 +3060,9 @@ impl<'db> Type<'db> { | Type::StringLiteral(_) | Type::Tuple(_) | Type::Callable(_) + | Type::BoundMethod(_) + | Type::WrapperDescriptorDunderGet + | Type::MethodWrapperDunderGet(_) | Type::Never | Type::FunctionLiteral(_) => Err(InvalidTypeExpressionError { invalid_expressions: smallvec::smallvec![InvalidTypeExpression::InvalidType(*self)], @@ -3069,9 +3097,7 @@ impl<'db> Type<'db> { KnownInstanceType::TypeVar(_) => Ok(*self), // TODO: Use an opt-in rule for a bare `Callable` - KnownInstanceType::Callable => Ok(Type::Callable(CallableType::General( - GeneralCallableType::unknown(db), - ))), + KnownInstanceType::Callable => Ok(Type::Callable(CallableType::unknown(db))), KnownInstanceType::TypingSelf => Ok(todo_type!("Support for `typing.Self`")), KnownInstanceType::TypeAlias => Ok(todo_type!("Support for `typing.TypeAlias`")), @@ -3239,16 +3265,12 @@ impl<'db> Type<'db> { Type::SliceLiteral(_) => KnownClass::Slice.to_class_literal(db), Type::IntLiteral(_) => KnownClass::Int.to_class_literal(db), Type::FunctionLiteral(_) => KnownClass::FunctionType.to_class_literal(db), - Type::Callable(CallableType::BoundMethod(_)) => { - KnownClass::MethodType.to_class_literal(db) - } - Type::Callable(CallableType::MethodWrapperDunderGet(_)) => { - KnownClass::MethodWrapperType.to_class_literal(db) - } - Type::Callable(CallableType::WrapperDescriptorDunderGet) => { + Type::BoundMethod(_) => KnownClass::MethodType.to_class_literal(db), + Type::MethodWrapperDunderGet(_) => KnownClass::MethodWrapperType.to_class_literal(db), + Type::WrapperDescriptorDunderGet => { KnownClass::WrapperDescriptorType.to_class_literal(db) } - Type::Callable(CallableType::General(_)) => KnownClass::Type.to_instance(db), + Type::Callable(_) => KnownClass::Type.to_instance(db), Type::ModuleLiteral(_) => KnownClass::ModuleType.to_class_literal(db), Type::Tuple(_) => KnownClass::Tuple.to_class_literal(db), Type::ClassLiteral(ClassLiteralType { class }) => class.metaclass(db), @@ -4204,10 +4226,7 @@ impl<'db> FunctionType<'db> { /// /// This powers the `CallableTypeOf` special form from the `knot_extensions` module. pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> Type<'db> { - Type::Callable(CallableType::General(GeneralCallableType::new( - db, - self.signature(db).clone(), - ))) + Type::Callable(CallableType::new(db, self.signature(db).clone())) } /// Typed externally-visible signature for this function. @@ -4387,23 +4406,38 @@ pub struct BoundMethodType<'db> { self_instance: Type<'db>, } -/// This type represents a general callable type that are used to represent `typing.Callable` -/// and `lambda` expressions. +/// This type represents the set of all callable objects with a certain signature. +/// It can be written in type expressions using `typing.Callable`. +/// `lambda` expressions are inferred directly as `CallableType`s; all function-literal types +/// are subtypes of a `CallableType`. #[salsa::interned(debug)] -pub struct GeneralCallableType<'db> { +pub struct CallableType<'db> { #[return_ref] signature: Signature<'db>, } -impl<'db> GeneralCallableType<'db> { - /// Create a general callable type which accepts any parameters and returns an `Unknown` type. +impl<'db> CallableType<'db> { + /// Create a callable type which accepts any parameters and returns an `Unknown` type. pub(crate) fn unknown(db: &'db dyn Db) -> Self { - GeneralCallableType::new( + CallableType::new( db, Signature::new(Parameters::unknown(), Some(Type::unknown())), ) } + fn with_sorted_unions_and_intersections(self, db: &'db dyn Db) -> Self { + let signature = self.signature(db); + let parameters = signature + .parameters() + .iter() + .map(|param| param.clone().with_sorted_unions_and_intersections(db)) + .collect(); + let return_ty = signature + .return_ty + .map(|return_ty| return_ty.with_sorted_unions_and_intersections(db)); + CallableType::new(db, Signature::new(parameters, return_ty)) + } + /// Returns `true` if this is a fully static callable type. /// /// A callable type is fully static if all of its parameters and return type are fully static @@ -4928,45 +4962,6 @@ impl<'db> GeneralCallableType<'db> { } } -/// A type that represents callable objects. -#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, salsa::Update)] -pub enum CallableType<'db> { - /// Represents a general callable type. - General(GeneralCallableType<'db>), - - /// Represents a callable `instance.method` where `instance` is an instance of a class - /// and `method` is a method (of that class). - /// - /// See [`BoundMethodType`] for more information. - /// - /// TODO: This could eventually be replaced by a more general `Callable` type, if we - /// decide to bind the first argument of method calls early, i.e. if we have a method - /// `def f(self, x: int) -> str`, and see it being called as `instance.f`, we could - /// partially apply (and check) the `instance` argument against the `self` parameter, - /// and return a `Callable[[int], str]`. One drawback would be that we could not show - /// the bound instance when that type is displayed. - BoundMethod(BoundMethodType<'db>), - - /// Represents the callable `f.__get__` where `f` is a function. - /// - /// TODO: This could eventually be replaced by a more general `Callable` type that is - /// also able to represent overloads. It would need to represent the two overloads of - /// `types.FunctionType.__get__`: - /// - /// ```txt - /// * (None, type) -> Literal[function_on_which_it_was_called] - /// * (object, type | None) -> BoundMethod[instance, function_on_which_it_was_called] - /// ``` - MethodWrapperDunderGet(FunctionType<'db>), - - /// Represents the callable `FunctionType.__get__`. - /// - /// TODO: Similar to above, this could eventually be replaced by a generic `Callable` - /// type. We currently add this as a separate variant because `FunctionType.__get__` - /// is an overloaded method and we do not support `@overload` yet. - WrapperDescriptorDunderGet, -} - #[salsa::interned(debug)] pub struct ModuleLiteralType<'db> { /// The file in which this module was imported. diff --git a/crates/red_knot_python_semantic/src/types/call/bind.rs b/crates/red_knot_python_semantic/src/types/call/bind.rs index f38851ac68..ee1a829348 100644 --- a/crates/red_knot_python_semantic/src/types/call/bind.rs +++ b/crates/red_knot_python_semantic/src/types/call/bind.rs @@ -18,8 +18,8 @@ use crate::types::diagnostic::{ }; use crate::types::signatures::{Parameter, ParameterForm}; use crate::types::{ - todo_type, BoundMethodType, CallableType, ClassLiteralType, KnownClass, KnownFunction, - KnownInstanceType, UnionType, + todo_type, BoundMethodType, ClassLiteralType, KnownClass, KnownFunction, KnownInstanceType, + UnionType, }; use ruff_db::diagnostic::{OldSecondaryDiagnosticMessage, Span}; use ruff_python_ast as ast; @@ -210,26 +210,22 @@ impl<'db> Bindings<'db> { }; match binding_type { - Type::Callable(CallableType::MethodWrapperDunderGet(function)) => { + Type::MethodWrapperDunderGet(function) => { if function.has_known_class_decorator(db, KnownClass::Classmethod) && function.decorators(db).len() == 1 { match overload.parameter_types() { [_, Some(owner)] => { - overload.set_return_type(Type::Callable( - CallableType::BoundMethod(BoundMethodType::new( - db, function, *owner, - )), - )); + overload.set_return_type(Type::BoundMethod(BoundMethodType::new( + db, function, *owner, + ))); } [Some(instance), None] => { - overload.set_return_type(Type::Callable( - CallableType::BoundMethod(BoundMethodType::new( - db, - function, - instance.to_meta_type(db), - )), - )); + overload.set_return_type(Type::BoundMethod(BoundMethodType::new( + db, + function, + instance.to_meta_type(db), + ))); } _ => {} } @@ -237,14 +233,14 @@ impl<'db> Bindings<'db> { if first.is_none(db) { overload.set_return_type(Type::FunctionLiteral(function)); } else { - overload.set_return_type(Type::Callable(CallableType::BoundMethod( - BoundMethodType::new(db, function, *first), + overload.set_return_type(Type::BoundMethod(BoundMethodType::new( + db, function, *first, ))); } } } - Type::Callable(CallableType::WrapperDescriptorDunderGet) => { + Type::WrapperDescriptorDunderGet => { if let [Some(function_ty @ Type::FunctionLiteral(function)), ..] = overload.parameter_types() { @@ -253,20 +249,18 @@ impl<'db> Bindings<'db> { { match overload.parameter_types() { [_, _, Some(owner)] => { - overload.set_return_type(Type::Callable( - CallableType::BoundMethod(BoundMethodType::new( - db, *function, *owner, - )), + overload.set_return_type(Type::BoundMethod( + BoundMethodType::new(db, *function, *owner), )); } [_, Some(instance), None] => { - overload.set_return_type(Type::Callable( - CallableType::BoundMethod(BoundMethodType::new( + overload.set_return_type(Type::BoundMethod( + BoundMethodType::new( db, *function, instance.to_meta_type(db), - )), + ), )); } @@ -308,10 +302,8 @@ impl<'db> Bindings<'db> { } [_, Some(instance), _] => { - overload.set_return_type(Type::Callable( - CallableType::BoundMethod(BoundMethodType::new( - db, *function, *instance, - )), + overload.set_return_type(Type::BoundMethod( + BoundMethodType::new(db, *function, *instance), )); } @@ -935,17 +927,15 @@ impl<'db> CallableDescription<'db> { kind: "class", name: class_type.class().name(db), }), - Type::Callable(CallableType::BoundMethod(bound_method)) => Some(CallableDescription { + Type::BoundMethod(bound_method) => Some(CallableDescription { kind: "bound method", name: bound_method.function(db).name(db), }), - Type::Callable(CallableType::MethodWrapperDunderGet(function)) => { - Some(CallableDescription { - kind: "method wrapper `__get__` of function", - name: function.name(db), - }) - } - Type::Callable(CallableType::WrapperDescriptorDunderGet) => Some(CallableDescription { + Type::MethodWrapperDunderGet(function) => Some(CallableDescription { + kind: "method wrapper `__get__` of function", + name: function.name(db), + }), + Type::WrapperDescriptorDunderGet => Some(CallableDescription { kind: "wrapper descriptor", name: "FunctionType.__get__", }), @@ -1061,13 +1051,11 @@ impl<'db> BindingError<'db> { None } } - Type::Callable(CallableType::BoundMethod(bound_method)) => { - Self::parameter_span_from_index( - db, - Type::FunctionLiteral(bound_method.function(db)), - parameter_index, - ) - } + Type::BoundMethod(bound_method) => Self::parameter_span_from_index( + db, + Type::FunctionLiteral(bound_method.function(db)), + parameter_index, + ), _ => None, } } diff --git a/crates/red_knot_python_semantic/src/types/class_base.rs b/crates/red_knot_python_semantic/src/types/class_base.rs index e7606468ae..177ccd1dd8 100644 --- a/crates/red_knot_python_semantic/src/types/class_base.rs +++ b/crates/red_knot_python_semantic/src/types/class_base.rs @@ -73,6 +73,9 @@ impl<'db> ClassBase<'db> { | Type::BooleanLiteral(_) | Type::FunctionLiteral(_) | Type::Callable(..) + | Type::BoundMethod(_) + | Type::MethodWrapperDunderGet(_) + | Type::WrapperDescriptorDunderGet | Type::BytesLiteral(_) | Type::IntLiteral(_) | Type::StringLiteral(_) diff --git a/crates/red_knot_python_semantic/src/types/display.rs b/crates/red_knot_python_semantic/src/types/display.rs index 0b31eb69e3..963da7871d 100644 --- a/crates/red_knot_python_semantic/src/types/display.rs +++ b/crates/red_knot_python_semantic/src/types/display.rs @@ -9,8 +9,8 @@ use ruff_python_literal::escape::AsciiEscape; use crate::types::class_base::ClassBase; use crate::types::signatures::{Parameter, Parameters, Signature}; use crate::types::{ - CallableType, ClassLiteralType, InstanceType, IntersectionType, KnownClass, StringLiteralType, - Type, UnionType, + ClassLiteralType, InstanceType, IntersectionType, KnownClass, StringLiteralType, Type, + UnionType, }; use crate::Db; use rustc_hash::FxHashMap; @@ -90,10 +90,8 @@ impl Display for DisplayRepresentation<'_> { }, Type::KnownInstance(known_instance) => f.write_str(known_instance.repr(self.db)), Type::FunctionLiteral(function) => f.write_str(function.name(self.db)), - Type::Callable(CallableType::General(callable)) => { - callable.signature(self.db).display(self.db).fmt(f) - } - Type::Callable(CallableType::BoundMethod(bound_method)) => { + Type::Callable(callable) => callable.signature(self.db).display(self.db).fmt(f), + Type::BoundMethod(bound_method) => { write!( f, "", @@ -101,14 +99,14 @@ impl Display for DisplayRepresentation<'_> { instance = bound_method.self_instance(self.db).display(self.db) ) } - Type::Callable(CallableType::MethodWrapperDunderGet(function)) => { + Type::MethodWrapperDunderGet(function) => { write!( f, "", function = function.name(self.db) ) } - Type::Callable(CallableType::WrapperDescriptorDunderGet) => { + Type::WrapperDescriptorDunderGet => { f.write_str("") } Type::Union(union) => union.display(self.db).fmt(f), @@ -423,9 +421,7 @@ struct DisplayMaybeParenthesizedType<'db> { impl Display for DisplayMaybeParenthesizedType<'_> { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - if let Type::Callable(CallableType::General(_) | CallableType::MethodWrapperDunderGet(_)) = - self.ty - { + if let Type::Callable(_) | Type::MethodWrapperDunderGet(_) = self.ty { write!(f, "({})", self.ty.display(self.db)) } else { self.ty.display(self.db).fmt(f) diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index c2b323ea89..b06b1fe96a 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -82,7 +82,7 @@ use crate::types::{ Truthiness, TupleType, Type, TypeAliasType, TypeAndQualifiers, TypeArrayDisplay, TypeQualifiers, TypeVarBoundOrConstraints, TypeVarInstance, UnionBuilder, UnionType, }; -use crate::types::{CallableType, GeneralCallableType, Signature}; +use crate::types::{CallableType, Signature}; use crate::unpack::{Unpack, UnpackPosition}; use crate::util::subscript::{PyIndex, PySlice}; use crate::Db; @@ -2313,6 +2313,9 @@ impl<'db> TypeInferenceBuilder<'db> { | Type::KnownInstance(..) | Type::FunctionLiteral(..) | Type::Callable(..) + | Type::BoundMethod(_) + | Type::MethodWrapperDunderGet(_) + | Type::WrapperDescriptorDunderGet | Type::AlwaysTruthy | Type::AlwaysFalsy => match object_ty.class_member(db, attribute.into()) { meta_attr @ SymbolAndQualifiers { .. } if meta_attr.is_class_var() => { @@ -3887,10 +3890,10 @@ impl<'db> TypeInferenceBuilder<'db> { // TODO: Useful inference of a lambda's return type will require a different approach, // which does the inference of the body expression based on arguments at each call site, // rather than eagerly computing a return type without knowing the argument types. - Type::Callable(CallableType::General(GeneralCallableType::new( + Type::Callable(CallableType::new( self.db(), Signature::new(parameters, Some(Type::unknown())), - ))) + )) } fn infer_call_expression(&mut self, call_expression: &ast::ExprCall) -> Type<'db> { @@ -4410,6 +4413,9 @@ impl<'db> TypeInferenceBuilder<'db> { op @ (ast::UnaryOp::UAdd | ast::UnaryOp::USub | ast::UnaryOp::Invert), Type::FunctionLiteral(_) | Type::Callable(..) + | Type::WrapperDescriptorDunderGet + | Type::MethodWrapperDunderGet(_) + | Type::BoundMethod(_) | Type::ModuleLiteral(_) | Type::ClassLiteral(_) | Type::SubclassOf(_) @@ -4658,6 +4664,9 @@ impl<'db> TypeInferenceBuilder<'db> { ( Type::FunctionLiteral(_) | Type::Callable(..) + | Type::BoundMethod(_) + | Type::WrapperDescriptorDunderGet + | Type::MethodWrapperDunderGet(_) | Type::ModuleLiteral(_) | Type::ClassLiteral(_) | Type::SubclassOf(_) @@ -4674,6 +4683,9 @@ impl<'db> TypeInferenceBuilder<'db> { | Type::Tuple(_), Type::FunctionLiteral(_) | Type::Callable(..) + | Type::BoundMethod(_) + | Type::WrapperDescriptorDunderGet + | Type::MethodWrapperDunderGet(_) | Type::ModuleLiteral(_) | Type::ClassLiteral(_) | Type::SubclassOf(_) @@ -6744,12 +6756,12 @@ impl<'db> TypeInferenceBuilder<'db> { let callable_type = if let (Some(parameters), Some(return_type), true) = (parameters, return_type, correct_argument_number) { - GeneralCallableType::new(db, Signature::new(parameters, Some(return_type))) + CallableType::new(db, Signature::new(parameters, Some(return_type))) } else { - GeneralCallableType::unknown(db) + CallableType::unknown(db) }; - let callable_type = Type::Callable(CallableType::General(callable_type)); + let callable_type = Type::Callable(callable_type); // `Signature` / `Parameters` are not a `Type` variant, so we're storing // the outer callable type on the these expressions instead. @@ -6839,10 +6851,7 @@ impl<'db> TypeInferenceBuilder<'db> { return Type::unknown(); }; - Type::Callable(CallableType::General(GeneralCallableType::new( - db, - signature.clone(), - ))) + Type::Callable(CallableType::new(db, signature.clone())) } }, diff --git a/crates/red_knot_python_semantic/src/types/property_tests/type_generation.rs b/crates/red_knot_python_semantic/src/types/property_tests/type_generation.rs index aa5b65967f..a7d3c0a3df 100644 --- a/crates/red_knot_python_semantic/src/types/property_tests/type_generation.rs +++ b/crates/red_knot_python_semantic/src/types/property_tests/type_generation.rs @@ -1,8 +1,8 @@ use crate::db::tests::TestDb; use crate::symbol::{builtins_symbol, known_module_symbol}; use crate::types::{ - BoundMethodType, CallableType, IntersectionBuilder, KnownClass, KnownInstanceType, - SubclassOfType, TupleType, Type, UnionType, + BoundMethodType, IntersectionBuilder, KnownClass, KnownInstanceType, SubclassOfType, TupleType, + Type, UnionType, }; use crate::{Db, KnownModule}; use quickcheck::{Arbitrary, Gen}; @@ -53,11 +53,11 @@ fn create_bound_method<'db>( function: Type<'db>, builtins_class: Type<'db>, ) -> Type<'db> { - Type::Callable(CallableType::BoundMethod(BoundMethodType::new( + Type::BoundMethod(BoundMethodType::new( db, function.expect_function_literal(), builtins_class.to_instance(db).unwrap(), - ))) + )) } impl Ty { diff --git a/crates/red_knot_python_semantic/src/types/signatures.rs b/crates/red_knot_python_semantic/src/types/signatures.rs index a59a4a8165..0606372bec 100644 --- a/crates/red_knot_python_semantic/src/types/signatures.rs +++ b/crates/red_knot_python_semantic/src/types/signatures.rs @@ -504,6 +504,12 @@ impl<'db, 'a> IntoIterator for &'a Parameters<'db> { } } +impl<'db> FromIterator> for Parameters<'db> { + fn from_iter>>(iter: T) -> Self { + Self::new(iter) + } +} + impl<'db> std::ops::Index for Parameters<'db> { type Output = Parameter<'db>; @@ -593,6 +599,33 @@ impl<'db> Parameter<'db> { self } + pub(crate) fn with_sorted_unions_and_intersections(mut self, db: &'db dyn Db) -> Self { + self.annotated_type = self + .annotated_type + .map(|ty| ty.with_sorted_unions_and_intersections(db)); + + self.kind = match self.kind { + ParameterKind::PositionalOnly { name, default_type } => ParameterKind::PositionalOnly { + name, + default_type: default_type.map(|ty| ty.with_sorted_unions_and_intersections(db)), + }, + ParameterKind::PositionalOrKeyword { name, default_type } => { + ParameterKind::PositionalOrKeyword { + name, + default_type: default_type + .map(|ty| ty.with_sorted_unions_and_intersections(db)), + } + } + ParameterKind::KeywordOnly { name, default_type } => ParameterKind::KeywordOnly { + name, + default_type: default_type.map(|ty| ty.with_sorted_unions_and_intersections(db)), + }, + ParameterKind::Variadic { .. } | ParameterKind::KeywordVariadic { .. } => self.kind, + }; + + self + } + fn from_node_and_kind( db: &'db dyn Db, definition: Definition<'db>, diff --git a/crates/red_knot_python_semantic/src/types/type_ordering.rs b/crates/red_knot_python_semantic/src/types/type_ordering.rs index 610c535663..596d132fd3 100644 --- a/crates/red_knot_python_semantic/src/types/type_ordering.rs +++ b/crates/red_knot_python_semantic/src/types/type_ordering.rs @@ -1,7 +1,6 @@ use std::cmp::Ordering; use crate::db::Db; -use crate::types::CallableType; use super::{ class_base::ClassBase, ClassLiteralType, DynamicType, InstanceType, KnownInstanceType, @@ -62,32 +61,22 @@ pub(super) fn union_or_intersection_elements_ordering<'db>( (Type::FunctionLiteral(_), _) => Ordering::Less, (_, Type::FunctionLiteral(_)) => Ordering::Greater, - ( - Type::Callable(CallableType::BoundMethod(left)), - Type::Callable(CallableType::BoundMethod(right)), - ) => left.cmp(right), - (Type::Callable(CallableType::BoundMethod(_)), _) => Ordering::Less, - (_, Type::Callable(CallableType::BoundMethod(_))) => Ordering::Greater, + (Type::BoundMethod(left), Type::BoundMethod(right)) => left.cmp(right), + (Type::BoundMethod(_), _) => Ordering::Less, + (_, Type::BoundMethod(_)) => Ordering::Greater, - ( - Type::Callable(CallableType::MethodWrapperDunderGet(left)), - Type::Callable(CallableType::MethodWrapperDunderGet(right)), - ) => left.cmp(right), - (Type::Callable(CallableType::MethodWrapperDunderGet(_)), _) => Ordering::Less, - (_, Type::Callable(CallableType::MethodWrapperDunderGet(_))) => Ordering::Greater, - - ( - Type::Callable(CallableType::WrapperDescriptorDunderGet), - Type::Callable(CallableType::WrapperDescriptorDunderGet), - ) => Ordering::Equal, - (Type::Callable(CallableType::WrapperDescriptorDunderGet), _) => Ordering::Less, - (_, Type::Callable(CallableType::WrapperDescriptorDunderGet)) => Ordering::Greater, - - (Type::Callable(CallableType::General(_)), Type::Callable(CallableType::General(_))) => { - Ordering::Equal + (Type::MethodWrapperDunderGet(left), Type::MethodWrapperDunderGet(right)) => { + left.cmp(right) } - (Type::Callable(CallableType::General(_)), _) => Ordering::Less, - (_, Type::Callable(CallableType::General(_))) => Ordering::Greater, + (Type::MethodWrapperDunderGet(_), _) => Ordering::Less, + (_, Type::MethodWrapperDunderGet(_)) => Ordering::Greater, + + (Type::WrapperDescriptorDunderGet, _) => Ordering::Less, + (_, Type::WrapperDescriptorDunderGet) => Ordering::Greater, + + (Type::Callable(left), Type::Callable(right)) => left.cmp(right), + (Type::Callable(_), _) => Ordering::Less, + (_, Type::Callable(_)) => Ordering::Greater, (Type::Tuple(left), Type::Tuple(right)) => { debug_assert_eq!(*left, left.with_sorted_unions_and_intersections(db));