diff --git a/crates/ty_python_semantic/src/types/display.rs b/crates/ty_python_semantic/src/types/display.rs index 063c4af0b7..3a2c6423dd 100644 --- a/crates/ty_python_semantic/src/types/display.rs +++ b/crates/ty_python_semantic/src/types/display.rs @@ -392,12 +392,14 @@ impl Display for DisplayRepresentation<'_> { } } Type::ProtocolInstance(protocol) => match protocol.inner { - Protocol::FromClass(ClassType::NonGeneric(class)) => { - class.display_with(self.db, self.settings.clone()).fmt(f) - } - Protocol::FromClass(ClassType::Generic(alias)) => { - alias.display_with(self.db, self.settings.clone()).fmt(f) - } + Protocol::FromClass(class) => match *class { + ClassType::NonGeneric(class) => { + class.display_with(self.db, self.settings.clone()).fmt(f) + } + ClassType::Generic(alias) => { + alias.display_with(self.db, self.settings.clone()).fmt(f) + } + }, Protocol::Synthesized(synthetic) => { f.write_str(" SpecializationBuilder<'db> { // generic protocol, we will need to check the types of the protocol members to be // able to infer the specialization of the protocol that the class implements. Type::ProtocolInstance(ProtocolInstanceType { - inner: Protocol::FromClass(ClassType::Generic(alias)), + inner: Protocol::FromClass(class), .. - }) => Some(alias), + }) => class.into_generic_alias(), _ => None, }; diff --git a/crates/ty_python_semantic/src/types/instance.rs b/crates/ty_python_semantic/src/types/instance.rs index 523f7fc796..04a560be83 100644 --- a/crates/ty_python_semantic/src/types/instance.rs +++ b/crates/ty_python_semantic/src/types/instance.rs @@ -10,7 +10,7 @@ use crate::semantic_index::definition::Definition; use crate::types::constraints::{ConstraintSet, IteratorConstraintsExtension}; use crate::types::enums::is_single_member_enum; use crate::types::generics::{InferableTypeVars, walk_specialization}; -use crate::types::protocol_class::walk_protocol_interface; +use crate::types::protocol_class::{ProtocolClass, walk_protocol_interface}; use crate::types::tuple::{TupleSpec, TupleType}; use crate::types::{ ApplyTypeMappingVisitor, ClassBase, ClassLiteral, FindLegacyTypeVarsVisitor, @@ -44,13 +44,17 @@ impl<'db> Type<'db> { .as_ref(), )), Some(KnownClass::Object) => Type::object(), - _ if class_literal.is_protocol(db) => { - Self::ProtocolInstance(ProtocolInstanceType::from_class(class)) - } - _ if class_literal.is_typed_dict(db) => Type::typed_dict(class), - // We don't call non_tuple_instance here because we've already checked that the class - // is not `object` - _ => Type::NominalInstance(NominalInstanceType(NominalInstanceInner::NonTuple(class))), + _ => class_literal + .is_typed_dict(db) + .then(|| Type::typed_dict(class)) + .or_else(|| { + class.into_protocol_class(db).map(|protocol_class| { + Self::ProtocolInstance(ProtocolInstanceType::from_class(protocol_class)) + }) + }) + .unwrap_or(Type::NominalInstance(NominalInstanceType( + NominalInstanceInner::NonTuple(class), + ))), } } @@ -601,7 +605,7 @@ pub(super) fn walk_protocol_instance_type<'db, V: super::visitor::TypeVisitor<'d impl<'db> ProtocolInstanceType<'db> { // Keep this method private, so that the only way of constructing `ProtocolInstanceType` // instances is through the `Type::instance` constructor function. - fn from_class(class: ClassType<'db>) -> Self { + fn from_class(class: ProtocolClass<'db>) -> Self { Self { inner: Protocol::FromClass(class), _phantom: PhantomData, @@ -625,7 +629,7 @@ impl<'db> ProtocolInstanceType<'db> { pub(super) fn as_nominal_type(self) -> Option> { match self.inner { Protocol::FromClass(class) => { - Some(NominalInstanceType(NominalInstanceInner::NonTuple(class))) + Some(NominalInstanceType(NominalInstanceInner::NonTuple(*class))) } Protocol::Synthesized(_) => None, } @@ -805,11 +809,17 @@ impl<'db> VarianceInferable<'db> for ProtocolInstanceType<'db> { /// An enumeration of the two kinds of protocol types: those that originate from a class /// definition in source code, and those that are synthesized from a set of members. +/// +/// # Ordering +/// +/// Ordering between variants is stable and should be the same between runs. +/// Ordering within variants is based on the wrapped data's salsa-assigned id and not on its values. +/// The id may change between runs, or when e.g. a `Protocol` was garbage-collected and recreated. #[derive( Copy, Clone, Debug, Eq, PartialEq, Hash, salsa::Update, PartialOrd, Ord, get_size2::GetSize, )] pub(super) enum Protocol<'db> { - FromClass(ClassType<'db>), + FromClass(ProtocolClass<'db>), Synthesized(SynthesizedProtocolType<'db>), } @@ -817,10 +827,7 @@ impl<'db> Protocol<'db> { /// Return the members of this protocol type fn interface(self, db: &'db dyn Db) -> ProtocolInterface<'db> { match self { - Self::FromClass(class) => class - .into_protocol_class(db) - .expect("Class wrapped by `Protocol` should be a protocol class") - .interface(db), + Self::FromClass(class) => class.interface(db), Self::Synthesized(synthesized) => synthesized.interface(), } } diff --git a/crates/ty_python_semantic/src/types/protocol_class.rs b/crates/ty_python_semantic/src/types/protocol_class.rs index 6cb204231f..8e3835b386 100644 --- a/crates/ty_python_semantic/src/types/protocol_class.rs +++ b/crates/ty_python_semantic/src/types/protocol_class.rs @@ -42,7 +42,14 @@ impl<'db> ClassType<'db> { } /// Representation of a single `Protocol` class definition. -#[derive(Debug, Copy, Clone, PartialEq, Eq)] +/// +/// # Ordering +/// +/// Ordering is based on the wrapped data's salsa-assigned id and not on its values. +/// The id may change between runs, or when e.g. a `ProtocolClass` was garbage-collected and recreated. +#[derive( + Debug, Copy, Clone, PartialEq, Eq, Hash, salsa::Update, get_size2::GetSize, PartialOrd, Ord, +)] pub(super) struct ProtocolClass<'db>(ClassType<'db>); impl<'db> ProtocolClass<'db> { @@ -124,6 +131,19 @@ impl<'db> ProtocolClass<'db> { report_undeclared_protocol_member(context, first_definition, self, class_place_table); } } + + pub(super) fn apply_type_mapping_impl<'a>( + self, + db: &'db dyn Db, + type_mapping: &TypeMapping<'a, 'db>, + tcx: TypeContext<'db>, + visitor: &ApplyTypeMappingVisitor<'db>, + ) -> Self { + Self( + self.0 + .apply_type_mapping_impl(db, type_mapping, tcx, visitor), + ) + } } impl<'db> Deref for ProtocolClass<'db> { @@ -134,6 +154,12 @@ impl<'db> Deref for ProtocolClass<'db> { } } +impl<'db> From> for Type<'db> { + fn from(value: ProtocolClass<'db>) -> Self { + Self::from(value.0) + } +} + /// The interface of a protocol: the members of that protocol, and the types of those members. /// /// # Ordering diff --git a/crates/ty_python_semantic/src/types/subclass_of.rs b/crates/ty_python_semantic/src/types/subclass_of.rs index 20bf9f322b..687d5aaad0 100644 --- a/crates/ty_python_semantic/src/types/subclass_of.rs +++ b/crates/ty_python_semantic/src/types/subclass_of.rs @@ -2,6 +2,7 @@ use crate::place::PlaceAndQualifiers; use crate::semantic_index::definition::Definition; use crate::types::constraints::ConstraintSet; use crate::types::generics::InferableTypeVars; +use crate::types::protocol_class::ProtocolClass; use crate::types::variance::VarianceInferable; use crate::types::{ ApplyTypeMappingVisitor, BoundTypeVarInstance, ClassType, DynamicType, @@ -289,6 +290,12 @@ impl<'db> From> for SubclassOfInner<'db> { } } +impl<'db> From> for SubclassOfInner<'db> { + fn from(value: ProtocolClass<'db>) -> Self { + SubclassOfInner::Class(*value) + } +} + impl<'db> From> for Type<'db> { fn from(value: SubclassOfInner<'db>) -> Self { match value {