diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 652da03559..040a25cf7f 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -179,7 +179,7 @@ pub(crate) enum CodeGeneratorKind { } impl CodeGeneratorKind { - pub(crate) fn from_class(db: &dyn Db, class: ClassLiteral<'_>) -> Option { + pub(crate) fn from_class(db: &dyn Db, class: ClassSingletonType<'_>) -> Option { #[salsa::tracked( cycle_fn=code_generator_of_class_recover, cycle_initial=code_generator_of_class_initial, @@ -187,7 +187,7 @@ impl CodeGeneratorKind { )] fn code_generator_of_class<'db>( db: &'db dyn Db, - class: ClassLiteral<'db>, + class: ClassSingletonType<'db>, ) -> Option { if class.dataclass_params(db).is_some() || class @@ -209,7 +209,7 @@ impl CodeGeneratorKind { fn code_generator_of_class_initial( _db: &dyn Db, - _class: ClassLiteral<'_>, + _class: ClassSingletonType<'_>, ) -> Option { None } @@ -219,7 +219,7 @@ impl CodeGeneratorKind { _db: &dyn Db, _value: &Option, _count: u32, - _class: ClassLiteral<'_>, + _class: ClassSingletonType<'_>, ) -> salsa::CycleRecoveryAction> { salsa::CycleRecoveryAction::Iterate } @@ -227,7 +227,7 @@ impl CodeGeneratorKind { code_generator_of_class(db, class) } - pub(super) fn matches(self, db: &dyn Db, class: ClassLiteral<'_>) -> bool { + pub(super) fn matches(self, db: &dyn Db, class: ClassSingletonType<'_>) -> bool { CodeGeneratorKind::from_class(db, class) == Some(self) } } @@ -1173,6 +1173,11 @@ impl<'db> ClassType<'db> { let (singleton, _) = self.class_singleton(db); singleton.explicit_bases(db) } + + pub(super) fn own_fields(self, db: &'db dyn Db) -> FxOrderMap> { + let (singleton, specialization) = self.class_singleton(db); + singleton.own_fields(db, specialization) + } } impl<'db> From> for ClassType<'db> { @@ -1246,7 +1251,7 @@ pub(crate) struct Field<'db> { salsa::Update, get_size2::GetSize, )] -enum ClassSingletonType<'db> { +pub(crate) enum ClassSingletonType<'db> { Literal(ClassLiteral<'db>), NewType(NewTypeClass<'db>), } @@ -1452,6 +1457,18 @@ impl<'db> ClassSingletonType<'db> { Self::NewType(new_type) => new_type.explicit_bases(db), } } + + pub(super) fn own_fields( + self, + db: &'db dyn Db, + specialization: Option>, + ) -> FxOrderMap> { + match self { + Self::Literal(literal) => literal.own_fields(db, specialization), + // A NewType can't be specialized. + Self::NewType(new_type) => new_type.own_fields(db), + } + } } impl<'db> From> for Type<'db> { @@ -2150,7 +2167,7 @@ impl<'db> ClassLiteral<'db> { .with_qualifiers(TypeQualifiers::CLASS_VAR); } - if CodeGeneratorKind::NamedTuple.matches(db, self) { + if CodeGeneratorKind::NamedTuple.matches(db, self.into()) { if let Some(field) = self.own_fields(db, specialization).get(name) { let property_getter_signature = Signature::new( Parameters::new([Parameter::positional_only(Some(Name::new_static("self")))]), @@ -2211,7 +2228,7 @@ impl<'db> ClassLiteral<'db> { let has_dataclass_param = |param| dataclass_params.is_some_and(|params| params.contains(param)); - let field_policy = CodeGeneratorKind::from_class(db, self)?; + let field_policy = CodeGeneratorKind::from_class(db, self.into())?; let instance_ty = Type::instance(db, self.apply_optional_specialization(db, specialization)); @@ -2283,7 +2300,11 @@ impl<'db> ClassLiteral<'db> { if let Some(ref mut default_ty) = default_ty { *default_ty = default_ty - .try_call_dunder_get(db, Type::none(db), Type::ClassSingleton(self)) + .try_call_dunder_get( + db, + Type::none(db), + Type::ClassSingleton(self.into()), + ) .map(|(return_ty, _)| return_ty) .unwrap_or_else(Type::unknown); } @@ -3201,9 +3222,9 @@ impl<'db> ClassLiteral<'db> { /// Also, populates `visited_classes` with all base classes of `self`. fn is_cyclically_defined_recursive<'db>( db: &'db dyn Db, - class: ClassLiteral<'db>, - classes_on_stack: &mut IndexSet>, - visited_classes: &mut IndexSet>, + class: ClassSingletonType<'db>, + classes_on_stack: &mut IndexSet>, + visited_classes: &mut IndexSet>, ) -> bool { let mut result = false; for explicit_base in class.explicit_bases(db) { @@ -3233,9 +3254,10 @@ impl<'db> ClassLiteral<'db> { tracing::trace!("Class::inheritance_cycle: {}", self.name(db)); let visited_classes = &mut IndexSet::new(); - if !is_cyclically_defined_recursive(db, self, &mut IndexSet::new(), visited_classes) { + if !is_cyclically_defined_recursive(db, self.into(), &mut IndexSet::new(), visited_classes) + { None - } else if visited_classes.contains(&self) { + } else if visited_classes.contains(&ClassSingletonType::from(self)) { Some(InheritanceCycle::Participant) } else { Some(InheritanceCycle::Inherited) @@ -3274,7 +3296,7 @@ impl<'db> ClassLiteral<'db> { impl<'db> From> for Type<'db> { fn from(class: ClassLiteral<'db>) -> Type<'db> { - Type::ClassSingleton(class) + Type::ClassSingleton(class.into()) } } @@ -3399,6 +3421,10 @@ impl<'db> NewTypeClass<'db> { pub(super) fn explicit_bases(self, db: &'db dyn Db) -> &[Type<'db>] { self.parent(db).explicit_bases(db) } + + pub(super) fn own_fields(self, db: &'db dyn Db) -> FxOrderMap> { + self.parent(db).own_fields(db) + } } impl<'db> get_size2::GetSize for NewTypeClass<'_> {} @@ -4228,10 +4254,12 @@ impl KnownClass { db: &'db dyn Db, specialization: impl IntoIterator>, ) -> Option> { - let Type::ClassSingleton(singleton) = self.to_class_singleton(db) else { + let Type::ClassSingleton(ClassSingletonType::Literal(class_literal)) = + self.to_class_singleton(db) + else { return None; }; - let generic_context = singleton.generic_context(db)?; + let generic_context = class_literal.generic_context(db)?; let types = specialization.into_iter().collect::>(); if types.len() != generic_context.len(db) { @@ -4245,10 +4273,10 @@ impl KnownClass { self.display(db) ); } - return Some(singleton.default_specialization(db)); + return Some(class_literal.default_specialization(db)); } - Some(singleton.apply_specialization(db, |_| generic_context.specialize(db, types))) + Some(class_literal.apply_specialization(db, |_| generic_context.specialize(db, types))) } /// Lookup a [`KnownClass`] in typeshed and return a [`Type`]