diff --git a/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md b/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md index d188f269a8..ba4b45f9df 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md @@ -98,12 +98,10 @@ def deeper_list(x: list[set[str]]) -> None: reveal_type(takes_in_protocol(x)) # revealed: Unknown def deep_explicit(x: ExplicitlyImplements[str]) -> None: - # TODO: revealed: str - reveal_type(takes_in_protocol(x)) # revealed: Unknown + reveal_type(takes_in_protocol(x)) # revealed: str def deeper_explicit(x: ExplicitlyImplements[set[str]]) -> None: - # TODO: revealed: set[str] - reveal_type(takes_in_protocol(x)) # revealed: Unknown + reveal_type(takes_in_protocol(x)) # revealed: set[str] def takes_in_type(x: type[T]) -> type[T]: return x @@ -128,10 +126,8 @@ reveal_type(takes_in_protocol(GenericSub[str]())) # revealed: Unknown class ExplicitSub(ExplicitlyImplements[int]): ... class ExplicitGenericSub(ExplicitlyImplements[T]): ... -# TODO: revealed: int -reveal_type(takes_in_protocol(ExplicitSub())) # revealed: Unknown -# TODO: revealed: str -reveal_type(takes_in_protocol(ExplicitGenericSub[str]())) # revealed: Unknown +reveal_type(takes_in_protocol(ExplicitSub())) # revealed: int +reveal_type(takes_in_protocol(ExplicitGenericSub[str]())) # revealed: str ``` ## Inferring a bound typevar diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md index 11701f17e7..ffcd04a78e 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md @@ -93,12 +93,10 @@ def deeper_list(x: list[set[str]]) -> None: reveal_type(takes_in_protocol(x)) # revealed: Unknown def deep_explicit(x: ExplicitlyImplements[str]) -> None: - # TODO: revealed: str - reveal_type(takes_in_protocol(x)) # revealed: Unknown + reveal_type(takes_in_protocol(x)) # revealed: str def deeper_explicit(x: ExplicitlyImplements[set[str]]) -> None: - # TODO: revealed: set[str] - reveal_type(takes_in_protocol(x)) # revealed: Unknown + reveal_type(takes_in_protocol(x)) # revealed: set[str] def takes_in_type[T](x: type[T]) -> type[T]: return x @@ -123,10 +121,8 @@ reveal_type(takes_in_protocol(GenericSub[str]())) # revealed: Unknown class ExplicitSub(ExplicitlyImplements[int]): ... class ExplicitGenericSub[T](ExplicitlyImplements[T]): ... -# TODO: revealed: int -reveal_type(takes_in_protocol(ExplicitSub())) # revealed: Unknown -# TODO: revealed: str -reveal_type(takes_in_protocol(ExplicitGenericSub[str]())) # revealed: Unknown +reveal_type(takes_in_protocol(ExplicitSub())) # revealed: int +reveal_type(takes_in_protocol(ExplicitGenericSub[str]())) # revealed: str ``` ## Inferring a bound typevar diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 3eb740871c..94c3bfd9d7 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -5333,7 +5333,7 @@ impl<'db> Type<'db> { Self::TypeVar(var) => Some(TypeDefinition::TypeVar(var.definition(db))), - Self::ProtocolInstance(protocol) => match protocol.inner() { + Self::ProtocolInstance(protocol) => match protocol.inner { Protocol::FromClass(class) => Some(TypeDefinition::Class(class.definition(db))), Protocol::Synthesized(_) => None, }, diff --git a/crates/ty_python_semantic/src/types/display.rs b/crates/ty_python_semantic/src/types/display.rs index 467a577905..b2a150fe85 100644 --- a/crates/ty_python_semantic/src/types/display.rs +++ b/crates/ty_python_semantic/src/types/display.rs @@ -76,7 +76,7 @@ impl Display for DisplayRepresentation<'_> { (ClassType::Generic(alias), _) => alias.display(self.db).fmt(f), } } - Type::ProtocolInstance(protocol) => match protocol.inner() { + Type::ProtocolInstance(protocol) => match protocol.inner { Protocol::FromClass(ClassType::NonGeneric(class)) => { f.write_str(class.name(self.db)) } diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index b5f2b3ef83..a6360e5732 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -4,7 +4,7 @@ use rustc_hash::FxHashMap; use crate::semantic_index::SemanticIndex; use crate::types::class::ClassType; use crate::types::class_base::ClassBase; -use crate::types::instance::NominalInstanceType; +use crate::types::instance::{NominalInstanceType, Protocol, ProtocolInstanceType}; use crate::types::signatures::{Parameter, Parameters, Signature}; use crate::types::{ declaration_type, todo_type, KnownInstanceType, Type, TypeVarBoundOrConstraints, @@ -630,7 +630,10 @@ impl<'db> SpecializationBuilder<'db> { // ``` // // without specializing `T` to `None`. - if !actual.is_never() && actual.is_subtype_of(self.db, formal) { + if !matches!(formal, Type::ProtocolInstance(_)) + && !actual.is_never() + && actual.is_subtype_of(self.db, formal) + { return Ok(()); } @@ -678,6 +681,14 @@ impl<'db> SpecializationBuilder<'db> { Type::NominalInstance(NominalInstanceType { class: ClassType::Generic(formal_alias), .. + }) + // TODO: This will only handle classes that explicit implement a generic protocol + // by listing it as a base class. To handle classes that implicitly implement a + // generic protocol, we will need to check the types of the protocol members to be + // able to infer the specialization of the protocol that the class implements. + | Type::ProtocolInstance(ProtocolInstanceType { + inner: Protocol::FromClass(ClassType::Generic(formal_alias)), + .. }), Type::NominalInstance(NominalInstanceType { class: actual_class, diff --git a/crates/ty_python_semantic/src/types/instance.rs b/crates/ty_python_semantic/src/types/instance.rs index b33d430609..99312849da 100644 --- a/crates/ty_python_semantic/src/types/instance.rs +++ b/crates/ty_python_semantic/src/types/instance.rs @@ -14,12 +14,9 @@ pub(super) use synthesized_protocol::SynthesizedProtocolType; impl<'db> Type<'db> { pub(crate) fn instance(db: &'db dyn Db, class: ClassType<'db>) -> Self { if class.class_literal(db).0.is_protocol(db) { - Self::ProtocolInstance(ProtocolInstanceType(Protocol::FromClass(class))) + Self::ProtocolInstance(ProtocolInstanceType::from_class(class)) } else { - Self::NominalInstance(NominalInstanceType { - class, - _phantom: PhantomData, - }) + Self::NominalInstance(NominalInstanceType::from_class(class)) } } @@ -34,9 +31,9 @@ impl<'db> Type<'db> { where M: IntoIterator)>, { - Self::ProtocolInstance(ProtocolInstanceType(Protocol::Synthesized( + Self::ProtocolInstance(ProtocolInstanceType::synthesized( SynthesizedProtocolType::new(db, ProtocolInterface::with_members(db, members)), - ))) + )) } /// Return `true` if `self` conforms to the interface described by `protocol`. @@ -51,7 +48,7 @@ impl<'db> Type<'db> { // TODO: this should consider the types of the protocol members // as well as whether each member *exists* on `self`. protocol - .0 + .inner .interface(db) .members(db) .all(|member| !self.member(db, member.name()).symbol.is_unbound()) @@ -69,6 +66,15 @@ pub struct NominalInstanceType<'db> { } impl<'db> NominalInstanceType<'db> { + // Keep this method private, so that the only way of constructing `NominalInstanceType` + // instances is through the `Type::instance` constructor function. + fn from_class(class: ClassType<'db>) -> Self { + Self { + class, + _phantom: PhantomData, + } + } + pub(super) fn is_subtype_of(self, db: &'db dyn Db, other: Self) -> bool { // N.B. The subclass relation is fully static self.class.is_subclass_of(db, other.class) @@ -131,10 +137,7 @@ impl<'db> NominalInstanceType<'db> { db: &'db dyn Db, type_mapping: TypeMapping<'a, 'db>, ) -> Self { - Self { - class: self.class.apply_type_mapping(db, type_mapping), - _phantom: PhantomData, - } + Self::from_class(self.class.apply_type_mapping(db, type_mapping)) } pub(super) fn find_legacy_typevars( @@ -155,21 +158,37 @@ impl<'db> From> for Type<'db> { /// A `ProtocolInstanceType` represents the set of all possible runtime objects /// that conform to the interface described by a certain protocol. #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, PartialOrd, Ord, salsa::Update)] -pub struct ProtocolInstanceType<'db>( +pub struct ProtocolInstanceType<'db> { + pub(super) inner: Protocol<'db>, + // Keep the inner field here private, // so that the only way of constructing `ProtocolInstanceType` instances // is through the `Type::instance` constructor function. - Protocol<'db>, -); + _phantom: PhantomData<()>, +} impl<'db> ProtocolInstanceType<'db> { - pub(super) fn inner(self) -> Protocol<'db> { - self.0 + // Keep this method private, so that the only way of constructing `ProtocolInstanceType` + // instances is through the `Type::instance` constructor function. + fn from_class(class: ClassType<'db>) -> Self { + Self { + inner: Protocol::FromClass(class), + _phantom: PhantomData, + } + } + + // Keep this method private, so that the only way of constructing `ProtocolInstanceType` + // instances is through the `Type::instance` constructor function. + fn synthesized(synthesized: SynthesizedProtocolType<'db>) -> Self { + Self { + inner: Protocol::Synthesized(synthesized), + _phantom: PhantomData, + } } /// Return the meta-type of this protocol-instance type. pub(super) fn to_meta_type(self, db: &'db dyn Db) -> Type<'db> { - match self.0 { + match self.inner { Protocol::FromClass(class) => SubclassOfType::from(db, class), // TODO: we can and should do better here. @@ -197,22 +216,22 @@ impl<'db> ProtocolInstanceType<'db> { if object.satisfies_protocol(db, self) { return object; } - match self.0 { - Protocol::FromClass(_) => Type::ProtocolInstance(Self(Protocol::Synthesized( - SynthesizedProtocolType::new(db, self.0.interface(db)), - ))), + match self.inner { + Protocol::FromClass(_) => Type::ProtocolInstance(Self::synthesized( + SynthesizedProtocolType::new(db, self.inner.interface(db)), + )), Protocol::Synthesized(_) => Type::ProtocolInstance(self), } } /// Replace references to `class` with a self-reference marker pub(super) fn replace_self_reference(self, db: &'db dyn Db, class: ClassLiteral<'db>) -> Self { - match self.0 { + match self.inner { Protocol::FromClass(class_type) if class_type.class_literal(db).0 == class => { - ProtocolInstanceType(Protocol::Synthesized(SynthesizedProtocolType::new( + ProtocolInstanceType::synthesized(SynthesizedProtocolType::new( db, ProtocolInterface::SelfReference, - ))) + )) } _ => self, } @@ -220,12 +239,12 @@ impl<'db> ProtocolInstanceType<'db> { /// Return `true` if any of the members of this protocol type contain any `Todo` types. pub(super) fn contains_todo(self, db: &'db dyn Db) -> bool { - self.0.interface(db).contains_todo(db) + self.inner.interface(db).contains_todo(db) } /// Return `true` if this protocol type is fully static. pub(super) fn is_fully_static(self, db: &'db dyn Db) -> bool { - self.0.interface(db).is_fully_static(db) + self.inner.interface(db).is_fully_static(db) } /// Return `true` if this protocol type is a subtype of the protocol `other`. @@ -238,9 +257,9 @@ impl<'db> ProtocolInstanceType<'db> { /// TODO: consider the types of the members as well as their existence pub(super) fn is_assignable_to(self, db: &'db dyn Db, other: Self) -> bool { other - .0 + .inner .interface(db) - .is_sub_interface_of(db, self.0.interface(db)) + .is_sub_interface_of(db, self.inner.interface(db)) } /// Return `true` if this protocol type is equivalent to the protocol `other`. @@ -269,7 +288,7 @@ impl<'db> ProtocolInstanceType<'db> { } pub(crate) fn instance_member(self, db: &'db dyn Db, name: &str) -> SymbolAndQualifiers<'db> { - match self.inner() { + match self.inner { Protocol::FromClass(class) => class.instance_member(db, name), Protocol::Synthesized(synthesized) => synthesized .interface() @@ -287,13 +306,13 @@ impl<'db> ProtocolInstanceType<'db> { db: &'db dyn Db, type_mapping: TypeMapping<'a, 'db>, ) -> Self { - match self.0 { - Protocol::FromClass(class) => Self(Protocol::FromClass( - class.apply_type_mapping(db, type_mapping), - )), - Protocol::Synthesized(synthesized) => Self(Protocol::Synthesized( - synthesized.apply_type_mapping(db, type_mapping), - )), + match self.inner { + Protocol::FromClass(class) => { + Self::from_class(class.apply_type_mapping(db, type_mapping)) + } + Protocol::Synthesized(synthesized) => { + Self::synthesized(synthesized.apply_type_mapping(db, type_mapping)) + } } } @@ -302,7 +321,7 @@ impl<'db> ProtocolInstanceType<'db> { db: &'db dyn Db, typevars: &mut FxOrderSet>, ) { - match self.0 { + match self.inner { Protocol::FromClass(class) => { class.find_legacy_typevars(db, typevars); }