diff --git a/crates/red_knot_python_semantic/resources/mdtest/function/return_type.md b/crates/red_knot_python_semantic/resources/mdtest/function/return_type.md index f3ce3f4b20..1365807e95 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/function/return_type.md +++ b/crates/red_knot_python_semantic/resources/mdtest/function/return_type.md @@ -74,8 +74,6 @@ class Baz(Bar): T = TypeVar("T") class Qux(Protocol[T]): - # TODO: no error - # error: [invalid-return-type] def f(self) -> int: ... class Foo(Protocol): diff --git a/crates/red_knot_python_semantic/resources/mdtest/protocols.md b/crates/red_knot_python_semantic/resources/mdtest/protocols.md index ee375cd0cb..53f4d14b4f 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/protocols.md +++ b/crates/red_knot_python_semantic/resources/mdtest/protocols.md @@ -40,27 +40,63 @@ class Foo(Protocol, Protocol): ... # error: [inconsistent-mro] reveal_type(Foo.__mro__) # revealed: tuple[Literal[Foo], Unknown, Literal[object]] ``` +Protocols can also be generic, either by including `Generic[]` in the bases list, subscripting +`Protocol` directly in the bases list, using PEP-695 type parameters, or some combination of the +above: + +```py +from typing import TypeVar, Generic + +T = TypeVar("T") + +class Bar0(Protocol[T]): + x: T + +class Bar1(Protocol[T], Generic[T]): + x: T + +class Bar2[T](Protocol): + x: T + +class Bar3[T](Protocol[T]): + x: T +``` + +It's an error to include both bare `Protocol` and subscripted `Protocol[]` in the bases list +simultaneously: + +```py +# TODO: should emit a `[duplicate-bases]` error here: +class DuplicateBases(Protocol, Protocol[T]): + x: T + +# TODO: should not have `Generic` multiple times and `Protocol` multiple times +# revealed: tuple[Literal[DuplicateBases], typing.Protocol, typing.Generic, @Todo(`Protocol[]` subscript), @Todo(`Generic[]` subscript), Literal[object]] +reveal_type(DuplicateBases.__mro__) +``` + The introspection helper `typing(_extensions).is_protocol` can be used to verify whether a class is a protocol class or not: ```py from typing_extensions import is_protocol -# TODO: should be `Literal[True]` -reveal_type(is_protocol(MyProtocol)) # revealed: bool +reveal_type(is_protocol(MyProtocol)) # revealed: Literal[True] +reveal_type(is_protocol(Bar0)) # revealed: Literal[True] +reveal_type(is_protocol(Bar1)) # revealed: Literal[True] +reveal_type(is_protocol(Bar2)) # revealed: Literal[True] +reveal_type(is_protocol(Bar3)) # revealed: Literal[True] class NotAProtocol: ... -# TODO: should be `Literal[False]` -reveal_type(is_protocol(NotAProtocol)) # revealed: bool +reveal_type(is_protocol(NotAProtocol)) # revealed: Literal[False] ``` A type checker should follow the typeshed stubs if a non-class is passed in, and typeshed's stubs -indicate that the argument passed in must be an instance of `type`. `Literal[False]` should be -inferred as the return type, however. +indicate that the argument passed in must be an instance of `type`. ```py -# TODO: the diagnostic is correct, but should infer `Literal[False]` +# We could also reasonably infer `Literal[False]` here, but it probably doesn't matter that much: # error: [invalid-argument-type] reveal_type(is_protocol("not a class")) # revealed: bool ``` @@ -74,8 +110,7 @@ class SubclassOfMyProtocol(MyProtocol): ... # revealed: tuple[Literal[SubclassOfMyProtocol], Literal[MyProtocol], typing.Protocol, typing.Generic, Literal[object]] reveal_type(SubclassOfMyProtocol.__mro__) -# TODO: should be `Literal[False]` -reveal_type(is_protocol(SubclassOfMyProtocol)) # revealed: bool +reveal_type(is_protocol(SubclassOfMyProtocol)) # revealed: Literal[False] ``` A protocol class may inherit from other protocols, however, as long as it re-inherits from @@ -84,8 +119,7 @@ A protocol class may inherit from other protocols, however, as long as it re-inh ```py class SubProtocol(MyProtocol, Protocol): ... -# TODO: should be `Literal[True]` -reveal_type(is_protocol(SubProtocol)) # revealed: bool +reveal_type(is_protocol(SubProtocol)) # revealed: Literal[True] class OtherProtocol(Protocol): some_attribute: str @@ -95,8 +129,7 @@ class ComplexInheritance(SubProtocol, OtherProtocol, Protocol): ... # revealed: tuple[Literal[ComplexInheritance], Literal[SubProtocol], Literal[MyProtocol], Literal[OtherProtocol], typing.Protocol, typing.Generic, Literal[object]] reveal_type(ComplexInheritance.__mro__) -# TODO: should be `Literal[True]` -reveal_type(is_protocol(ComplexInheritance)) # revealed: bool +reveal_type(is_protocol(ComplexInheritance)) # revealed: Literal[True] ``` If `Protocol` is present in the bases tuple, all other bases in the tuple must be protocol classes, @@ -134,6 +167,8 @@ reveal_type(Fine.__mro__) # revealed: tuple[Literal[Fine], typing.Protocol, typ class StillFine(Protocol, Generic[T], object): ... class EvenThis[T](Protocol, object): ... +class OrThis(Protocol[T], Generic[T]): ... +class AndThis(Protocol[T], Generic[T], object): ... ``` And multiple inheritance from a mix of protocol and non-protocol classes is fine as long as @@ -150,8 +185,7 @@ But if `Protocol` is not present in the bases list, the resulting class doesn't class anymore: ```py -# TODO: should reveal `Literal[False]` -reveal_type(is_protocol(FineAndDandy)) # revealed: bool +reveal_type(is_protocol(FineAndDandy)) # revealed: Literal[False] ``` A class does not *have* to inherit from a protocol class in order for it to be considered a subtype @@ -230,9 +264,10 @@ class Foo(typing.Protocol): class Bar(typing_extensions.Protocol): x: int -# TODO: these should pass -static_assert(typing_extensions.is_protocol(Foo)) # error: [static-assert-error] -static_assert(typing_extensions.is_protocol(Bar)) # error: [static-assert-error] +static_assert(typing_extensions.is_protocol(Foo)) +static_assert(typing_extensions.is_protocol(Bar)) + +# TODO: should pass static_assert(is_equivalent_to(Foo, Bar)) # error: [static-assert-error] ``` @@ -247,9 +282,10 @@ class RuntimeCheckableFoo(typing.Protocol): class RuntimeCheckableBar(typing_extensions.Protocol): x: int -# TODO: these should pass -static_assert(typing_extensions.is_protocol(RuntimeCheckableFoo)) # error: [static-assert-error] -static_assert(typing_extensions.is_protocol(RuntimeCheckableBar)) # error: [static-assert-error] +static_assert(typing_extensions.is_protocol(RuntimeCheckableFoo)) +static_assert(typing_extensions.is_protocol(RuntimeCheckableBar)) + +# TODO: should pass static_assert(is_equivalent_to(RuntimeCheckableFoo, RuntimeCheckableBar)) # error: [static-assert-error] # These should not error because the protocols are decorated with `@runtime_checkable` diff --git a/crates/red_knot_python_semantic/src/types/call/bind.rs b/crates/red_knot_python_semantic/src/types/call/bind.rs index 7232a08f8d..a577dba34b 100644 --- a/crates/red_knot_python_semantic/src/types/call/bind.rs +++ b/crates/red_knot_python_semantic/src/types/call/bind.rs @@ -535,6 +535,15 @@ impl<'db> Bindings<'db> { } } + Some(KnownFunction::IsProtocol) => { + if let [Some(ty)] = overload.parameter_types() { + overload.set_return_type(Type::BooleanLiteral( + ty.into_class_literal() + .is_some_and(|class| class.is_protocol(db)), + )); + } + } + Some(KnownFunction::Overload) => { // TODO: This can be removed once we understand legacy generics because the // typeshed definition for `typing.overload` is an identity function. diff --git a/crates/red_knot_python_semantic/src/types/class.rs b/crates/red_knot_python_semantic/src/types/class.rs index 8865a60950..23a4a1fde6 100644 --- a/crates/red_knot_python_semantic/src/types/class.rs +++ b/crates/red_knot_python_semantic/src/types/class.rs @@ -582,6 +582,17 @@ impl<'db> ClassLiteralType<'db> { .collect() } + /// Determine if this class is a protocol. + pub(super) fn is_protocol(self, db: &'db dyn Db) -> bool { + self.explicit_bases(db).iter().any(|base| { + matches!( + base, + Type::KnownInstance(KnownInstanceType::Protocol) + | Type::Dynamic(DynamicType::SubscriptedProtocol) + ) + }) + } + /// Return the types of the decorators on this class #[salsa::tracked(return_ref)] fn decorators(self, db: &'db dyn Db) -> Box<[Type<'db>]> { diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 125655fff9..1733de42e4 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -81,9 +81,9 @@ use crate::types::generics::GenericContext; use crate::types::mro::MroErrorKind; use crate::types::unpacker::{UnpackResult, Unpacker}; use crate::types::{ - todo_type, CallDunderError, CallableSignature, CallableType, Class, ClassLiteralType, - ClassType, DataclassMetadata, DynamicType, FunctionDecorators, FunctionType, GenericAlias, - GenericClass, IntersectionBuilder, IntersectionType, KnownClass, KnownFunction, + binding_type, todo_type, CallDunderError, CallableSignature, CallableType, Class, + ClassLiteralType, ClassType, DataclassMetadata, DynamicType, FunctionDecorators, FunctionType, + GenericAlias, GenericClass, IntersectionBuilder, IntersectionType, KnownClass, KnownFunction, KnownInstanceType, MemberLookupPolicy, MetaclassCandidate, NonGenericClass, Parameter, ParameterForm, Parameters, Signature, Signatures, SliceLiteralType, StringLiteralType, SubclassOfType, Symbol, SymbolAndQualifiers, Truthiness, TupleType, Type, TypeAliasType, @@ -1224,7 +1224,7 @@ impl<'db> TypeInferenceBuilder<'db> { /// Returns `true` if the current scope is the function body scope of a method of a protocol /// (that is, a class which directly inherits `typing.Protocol`.) - fn in_class_that_inherits_protocol_directly(&self) -> bool { + fn in_protocol_class(&self) -> bool { let current_scope_id = self.scope().file_scope_id(self.db()); let current_scope = self.index.scope(current_scope_id); let Some(parent_scope_id) = current_scope.parent() else { @@ -1252,13 +1252,13 @@ impl<'db> TypeInferenceBuilder<'db> { return false; }; - // TODO move this to `Class` once we add proper `Protocol` support - node_ref.bases().iter().any(|base| { - matches!( - self.file_expression_type(base), - Type::KnownInstance(KnownInstanceType::Protocol) - ) - }) + let class_definition = self.index.expect_single_definition(node_ref.node()); + + let Type::ClassLiteral(class) = binding_type(self.db(), class_definition) else { + return false; + }; + + class.is_protocol(self.db()) } /// Returns `true` if the current scope is the function body scope of a function overload (that @@ -1322,7 +1322,7 @@ impl<'db> TypeInferenceBuilder<'db> { if (self.in_stub() || self.in_function_overload_or_abstractmethod() - || self.in_class_that_inherits_protocol_directly()) + || self.in_protocol_class()) && self.return_types_and_ranges.is_empty() && is_stub_suite(&function.body) { @@ -1625,7 +1625,7 @@ impl<'db> TypeInferenceBuilder<'db> { } } else if (self.in_stub() || self.in_function_overload_or_abstractmethod() - || self.in_class_that_inherits_protocol_directly()) + || self.in_protocol_class()) && default .as_ref() .is_some_and(|d| d.is_ellipsis_literal_expr())