improve generic call inference performance

This commit is contained in:
Ibraheem Ahmed 2025-12-11 10:15:03 -05:00
parent caed80df5e
commit 531ca7e47a
2 changed files with 37 additions and 31 deletions

View File

@ -1003,7 +1003,7 @@ impl<'db> Type<'db> {
pub(crate) fn class_specialization(
self,
db: &'db dyn Db,
) -> Option<(ClassType<'db>, Specialization<'db>)> {
) -> Option<(ClassLiteral<'db>, Specialization<'db>)> {
self.specialization_of_optional(db, None)
}
@ -1021,7 +1021,7 @@ impl<'db> Type<'db> {
self,
db: &'db dyn Db,
expected_class: Option<ClassLiteral<'_>>,
) -> Option<(ClassType<'db>, Specialization<'db>)> {
) -> Option<(ClassLiteral<'db>, Specialization<'db>)> {
let class_type = match self {
Type::NominalInstance(instance) => instance,
Type::ProtocolInstance(instance) => instance.to_nominal_instance()?,
@ -1035,7 +1035,7 @@ impl<'db> Type<'db> {
return None;
}
Some((class_type, specialization?))
Some((class_literal, specialization?))
}
/// Given a type variable `T` from the generic context of a class `C`:
@ -3905,17 +3905,20 @@ impl<'db> Type<'db> {
.zip(specialization.types(db))
{
let variance = type_var.variance_with_polarity(db, polarity);
let tcx = tcx.and_then(|tcx| {
tcx.filter_union(db, |ty| {
ty.find_type_var_from(db, type_var, class_literal).is_some()
})
.find_type_var_from(db, type_var, class_literal)
let narrowed_tcx = tcx.and_then(|annotation| match annotation {
Type::Union(union) => union
.elements(db)
.iter()
.filter_map(|ty| ty.find_type_var_from(db, type_var, class_literal))
.exactly_one()
.ok(),
_ => annotation.find_type_var_from(db, type_var, class_literal),
});
f(type_var, *ty, variance, tcx);
f(type_var, *ty, variance, narrowed_tcx);
visitor.visit(*ty, || {
ty.visit_specialization_impl(db, tcx, variance, f, visitor);
ty.visit_specialization_impl(db, narrowed_tcx, variance, f, visitor);
});
}
}

View File

@ -2818,35 +2818,38 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
// Prefer the declared type of generic classes.
let preferred_type_mappings = return_with_tcx.and_then(|(return_ty, tcx)| {
let tcx = tcx.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some());
tcx.class_specialization(self.db)?;
let return_specialization = return_ty
let (tcx_class, tcx_specialization) = tcx
.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some())
.class_specialization(self.db);
.class_specialization(self.db)?;
let Some((return_class, return_specialization)) = return_ty
.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some())
.class_specialization(self.db)
else {
builder.infer(return_ty, tcx).ok()?;
return Some(builder.type_mappings().clone());
};
// TODO: We should use the constraint solver here to determine the type mappings for more
// complex subtyping relationships, e.g., callables, protocols, or unions containing multiple
// generic elements.
if let Some((class_literal, _)) = return_specialization
&& let Some(generic_alias) = class_literal.into_generic_alias()
{
let specialization = generic_alias.specialization(self.db);
for (class_type_var, return_ty) in specialization
.generic_context(self.db)
.variables(self.db)
.zip(specialization.types(self.db))
{
if let Some(ty) = tcx.find_type_var_from(
self.db,
class_type_var,
generic_alias.origin(self.db),
for base in return_class.iter_mro(self.db, Some(return_specialization)) {
let Some((base_class, Some(base_specialization))) =
base.into_class().map(|class| class.class_literal(self.db))
else {
continue;
};
if base_class == tcx_class {
for (base_ty, tcx_ty) in std::iter::zip(
base_specialization.types(self.db),
tcx_specialization.types(self.db),
) {
builder.infer(*return_ty, ty).ok()?;
builder.infer(*base_ty, *tcx_ty).ok()?;
}
break;
}
} else {
builder.infer(return_ty, tcx).ok()?;
}
Some(builder.type_mappings().clone())