prefer declared type of generic classes

This commit is contained in:
Ibraheem Ahmed 2025-10-24 22:38:17 -04:00
parent 1d6ae8596a
commit 2f4dcbf651
7 changed files with 160 additions and 41 deletions

View File

@ -427,14 +427,13 @@ a = f("a")
reveal_type(a) # revealed: list[Literal["a"]] reveal_type(a) # revealed: list[Literal["a"]]
b: list[int | Literal["a"]] = f("a") b: list[int | Literal["a"]] = f("a")
reveal_type(b) # revealed: list[Literal["a"] | int] reveal_type(b) # revealed: list[int | Literal["a"]]
c: list[int | str] = f("a") c: list[int | str] = f("a")
reveal_type(c) # revealed: list[str | int] reveal_type(c) # revealed: list[int | str]
d: list[int | tuple[int, int]] = f((1, 2)) d: list[int | tuple[int, int]] = f((1, 2))
# TODO: We could avoid reordering the union elements here. reveal_type(d) # revealed: list[int | tuple[int, int]]
reveal_type(d) # revealed: list[tuple[int, int] | int]
e: list[int] = f(True) e: list[int] = f(True)
reveal_type(e) # revealed: list[int] reveal_type(e) # revealed: list[int]
@ -455,7 +454,54 @@ j: int | str = f2(True)
reveal_type(j) # revealed: Literal[True] reveal_type(j) # revealed: Literal[True]
``` ```
Types are not widened unnecessarily: ## Prefer the declared type of generic classes
```toml
[environment]
python-version = "3.12"
```
```py
from typing import Any
def f[T](x: T) -> list[T]:
return [x]
def f2[T](x: T) -> list[T] | None:
return [x]
def f3[T](x: T) -> list[T] | dict[T, T]:
return [x]
a = f(1)
reveal_type(a) # revealed: list[Literal[1]]
b: list[Any] = f(1)
reveal_type(b) # revealed: list[Any]
c: list[Any] = [1]
reveal_type(c) # revealed: list[Any]
d: list[Any] | None = f(1)
reveal_type(d) # revealed: list[Any]
e: list[Any] | None = [1]
reveal_type(e) # revealed: list[Any]
f: list[Any] | None = f2(1)
reveal_type(f) # revealed: list[Any] | None
g: list[Any] | dict[Any, Any] = f3(1)
# TODO: Better constraint solver.
reveal_type(g) # revealed: list[Literal[1]] | dict[Literal[1], Literal[1]]
```
## Prefer the inferred type of non-generic classes
```toml
[environment]
python-version = "3.12"
```
```py ```py
def id[T](x: T) -> T: def id[T](x: T) -> T:

View File

@ -50,8 +50,6 @@ def _(l: list[int] | None = None):
def f[T](x: T, cond: bool) -> T | list[T]: def f[T](x: T, cond: bool) -> T | list[T]:
return x if cond else [x] return x if cond else [x]
# TODO: no error
# error: [invalid-assignment] "Object of type `Literal[1] | list[Literal[1]]` is not assignable to `int | list[int]`"
l5: int | list[int] = f(1, True) l5: int | list[int] = f(1, True)
``` ```

View File

@ -891,13 +891,24 @@ impl<'db> Type<'db> {
known_class: KnownClass, known_class: KnownClass,
) -> Option<Specialization<'db>> { ) -> Option<Specialization<'db>> {
let class_literal = known_class.try_to_class_literal(db)?; let class_literal = known_class.try_to_class_literal(db)?;
self.specialization_of(db, Some(class_literal)) self.specialization_of(db, class_literal)
}
// If this type is a class instance, returns its specialization.
pub(crate) fn class_specialization(self, db: &'db dyn Db) -> Option<Specialization<'db>> {
self.specialization_of_optional(db, None)
} }
// If the type is a specialized instance of the given class, returns the specialization. // If the type is a specialized instance of the given class, returns the specialization.
//
// If no class is provided, returns the specialization of any class instance.
pub(crate) fn specialization_of( pub(crate) fn specialization_of(
self,
db: &'db dyn Db,
expected_class: ClassLiteral<'_>,
) -> Option<Specialization<'db>> {
self.specialization_of_optional(db, Some(expected_class))
}
fn specialization_of_optional(
self, self,
db: &'db dyn Db, db: &'db dyn Db,
expected_class: Option<ClassLiteral<'_>>, expected_class: Option<ClassLiteral<'_>>,
@ -1214,22 +1225,28 @@ impl<'db> Type<'db> {
/// If the type is a union, filters union elements based on the provided predicate. /// If the type is a union, filters union elements based on the provided predicate.
/// ///
/// Otherwise, returns the type unchanged. /// Otherwise, considers the type to be the sole inhabitant of a single-valued union,
/// and filters it, returning `Never` if the predicate returns `false`, or the type
/// unchanged if `true`.
pub(crate) fn filter_union( pub(crate) fn filter_union(
self, self,
db: &'db dyn Db, db: &'db dyn Db,
f: impl FnMut(&Type<'db>) -> bool, mut f: impl FnMut(&Type<'db>) -> bool,
) -> Type<'db> { ) -> Type<'db> {
if let Type::Union(union) = self { if let Type::Union(union) = self {
union.filter(db, f) union.filter(db, f)
} else { } else if f(&self) {
self self
} else {
Type::Never
} }
} }
/// If the type is a union, removes union elements that are disjoint from `target`. /// If the type is a union, removes union elements that are disjoint from `target`.
/// ///
/// Otherwise, returns the type unchanged. /// Otherwise, considers the type to be the sole inhabitant of a single-valued union,
/// and filters it, returning `Never` if it is disjoint from `target`, or the type
/// unchanged if `true`.
pub(crate) fn filter_disjoint_elements( pub(crate) fn filter_disjoint_elements(
self, self,
db: &'db dyn Db, db: &'db dyn Db,

View File

@ -35,11 +35,11 @@ use crate::types::generics::{
use crate::types::signatures::{Parameter, ParameterForm, ParameterKind, Parameters}; use crate::types::signatures::{Parameter, ParameterForm, ParameterKind, Parameters};
use crate::types::tuple::{TupleLength, TupleType}; use crate::types::tuple::{TupleLength, TupleType};
use crate::types::{ use crate::types::{
BoundMethodType, ClassLiteral, DataclassFlags, DataclassParams, FieldInstance, BoundMethodType, BoundTypeVarIdentity, ClassLiteral, DataclassFlags, DataclassParams,
KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy, NominalInstanceType, FieldInstance, KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy,
PropertyInstanceType, SpecialFormType, TrackedConstraintSet, TypeAliasType, TypeContext, NominalInstanceType, PropertyInstanceType, SpecialFormType, TrackedConstraintSet,
UnionBuilder, UnionType, WrapperDescriptorKind, enums, ide_support, infer_isolated_expression, TypeAliasType, TypeContext, UnionBuilder, UnionType, WrapperDescriptorKind, enums, ide_support,
todo_type, infer_isolated_expression, todo_type,
}; };
use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity}; use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity};
use ruff_python_ast::{self as ast, ArgOrKeyword, PythonVersion}; use ruff_python_ast::{self as ast, ArgOrKeyword, PythonVersion};
@ -2718,9 +2718,25 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
return; return;
}; };
let return_with_tcx = self
.signature
.return_ty
.zip(self.call_expression_tcx.annotation);
self.inferable_typevars = generic_context.inferable_typevars(self.db); self.inferable_typevars = generic_context.inferable_typevars(self.db);
let mut builder = SpecializationBuilder::new(self.db, self.inferable_typevars); let mut builder = SpecializationBuilder::new(self.db, self.inferable_typevars);
// Prefer the declared type of generic classes.
let preferred_type_mappings = return_with_tcx.and_then(|(return_ty, tcx)| {
let preferred_return_ty =
tcx.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some());
let return_ty =
return_ty.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some());
builder.infer(return_ty, preferred_return_ty).ok()?;
Some(builder.type_mappings().clone())
});
let parameters = self.signature.parameters(); let parameters = self.signature.parameters();
for (argument_index, adjusted_argument_index, _, argument_type) in for (argument_index, adjusted_argument_index, _, argument_type) in
self.enumerate_argument_types() self.enumerate_argument_types()
@ -2733,9 +2749,21 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
continue; continue;
}; };
if let Err(error) = builder.infer( let filter = |declared_ty: BoundTypeVarIdentity<'_>, inferred_ty: Type<'_>| {
// Avoid widening the inferred type if it is already assignable to the
// preferred declared type.
preferred_type_mappings
.as_ref()
.and_then(|types| types.get(&declared_ty))
.is_none_or(|preferred_ty| {
!inferred_ty.is_assignable_to(self.db, *preferred_ty)
})
};
if let Err(error) = builder.infer_filter(
expected_type, expected_type,
variadic_argument_type.unwrap_or(argument_type), variadic_argument_type.unwrap_or(argument_type),
filter,
) { ) {
self.errors.push(BindingError::SpecializationError { self.errors.push(BindingError::SpecializationError {
error, error,
@ -2745,15 +2773,14 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
} }
} }
// Build the specialization first without inferring the type context. // Build the specialization first without inferring the complete type context.
let isolated_specialization = builder.build(generic_context, *self.call_expression_tcx); let isolated_specialization = builder.build(generic_context, *self.call_expression_tcx);
let isolated_return_ty = self let isolated_return_ty = self
.return_ty .return_ty
.apply_specialization(self.db, isolated_specialization); .apply_specialization(self.db, isolated_specialization);
let mut try_infer_tcx = || { let mut try_infer_tcx = || {
let return_ty = self.signature.return_ty?; let (return_ty, call_expression_tcx) = return_with_tcx?;
let call_expression_tcx = self.call_expression_tcx.annotation?;
// A type variable is not a useful type-context for expression inference, and applying it // A type variable is not a useful type-context for expression inference, and applying it
// to the return type can lead to confusing unions in nested generic calls. // to the return type can lead to confusing unions in nested generic calls.
@ -2762,7 +2789,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
} }
// If the return type is already assignable to the annotated type, we can ignore the // If the return type is already assignable to the annotated type, we can ignore the
// type context and prefer the narrower inferred type. // rest of the type context and prefer the narrower inferred type.
if isolated_return_ty.is_assignable_to(self.db, call_expression_tcx) { if isolated_return_ty.is_assignable_to(self.db, call_expression_tcx) {
return None; return None;
} }
@ -2771,7 +2798,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
// annotated assignment, to closer match the order of any unions written in the type annotation. // annotated assignment, to closer match the order of any unions written in the type annotation.
builder.infer(return_ty, call_expression_tcx).ok()?; builder.infer(return_ty, call_expression_tcx).ok()?;
// Otherwise, build the specialization again after inferring the type context. // Otherwise, build the specialization again after inferring the complete type context.
let specialization = builder.build(generic_context, *self.call_expression_tcx); let specialization = builder.build(generic_context, *self.call_expression_tcx);
let return_ty = return_ty.apply_specialization(self.db, specialization); let return_ty = return_ty.apply_specialization(self.db, specialization);

View File

@ -258,7 +258,7 @@ impl<'db> GenericAlias<'db> {
) -> Self { ) -> Self {
let tcx = tcx let tcx = tcx
.annotation .annotation
.and_then(|ty| ty.specialization_of(db, Some(self.origin(db)))) .and_then(|ty| ty.specialization_of(db, self.origin(db)))
.map(|specialization| specialization.types(db)) .map(|specialization| specialization.types(db))
.unwrap_or(&[]); .unwrap_or(&[]);

View File

@ -1,4 +1,5 @@
use std::cell::RefCell; use std::cell::RefCell;
use std::collections::hash_map::Entry;
use std::fmt::Display; use std::fmt::Display;
use itertools::Itertools; use itertools::Itertools;
@ -1319,6 +1320,11 @@ impl<'db> SpecializationBuilder<'db> {
} }
} }
/// Returns the current set of type mappings for this specialization.
pub(crate) fn type_mappings(&self) -> &FxHashMap<BoundTypeVarIdentity<'db>, Type<'db>> {
&self.types
}
pub(crate) fn build( pub(crate) fn build(
&mut self, &mut self,
generic_context: GenericContext<'db>, generic_context: GenericContext<'db>,
@ -1326,7 +1332,7 @@ impl<'db> SpecializationBuilder<'db> {
) -> Specialization<'db> { ) -> Specialization<'db> {
let tcx_specialization = tcx let tcx_specialization = tcx
.annotation .annotation
.and_then(|annotation| annotation.specialization_of(self.db, None)); .and_then(|annotation| annotation.class_specialization(self.db));
let types = let types =
(generic_context.variables_inner(self.db).iter()).map(|(identity, variable)| { (generic_context.variables_inner(self.db).iter()).map(|(identity, variable)| {
@ -1349,19 +1355,43 @@ impl<'db> SpecializationBuilder<'db> {
generic_context.specialize_partial(self.db, types) generic_context.specialize_partial(self.db, types)
} }
fn add_type_mapping(&mut self, bound_typevar: BoundTypeVarInstance<'db>, ty: Type<'db>) { fn add_type_mapping(
self.types &mut self,
.entry(bound_typevar.identity(self.db)) bound_typevar: BoundTypeVarInstance<'db>,
.and_modify(|existing| { ty: Type<'db>,
*existing = UnionType::from_elements(self.db, [*existing, ty]); filter: impl Fn(BoundTypeVarIdentity<'db>, Type<'db>) -> bool,
}) ) {
.or_insert(ty); let identity = bound_typevar.identity(self.db);
match self.types.entry(identity) {
Entry::Occupied(mut entry) => {
if filter(identity, ty) {
*entry.get_mut() = UnionType::from_elements(self.db, [*entry.get(), ty]);
}
}
Entry::Vacant(entry) => {
entry.insert(ty);
}
}
} }
/// Infer type mappings for the specialization based on a given type and its declared type.
pub(crate) fn infer( pub(crate) fn infer(
&mut self, &mut self,
formal: Type<'db>, formal: Type<'db>,
actual: Type<'db>, actual: Type<'db>,
) -> Result<(), SpecializationError<'db>> {
self.infer_filter(formal, actual, |_, _| true)
}
/// Infer type mappings for the specialization based on a given type and its declared type.
///
/// The filter predicate is provided with a type variable and the type being mapped to it. Type
/// mappings to which the predicate returns `false` will be ignored.
pub(crate) fn infer_filter(
&mut self,
formal: Type<'db>,
actual: Type<'db>,
filter: impl Fn(BoundTypeVarIdentity<'db>, Type<'db>) -> bool,
) -> Result<(), SpecializationError<'db>> { ) -> Result<(), SpecializationError<'db>> {
if formal == actual { if formal == actual {
return Ok(()); return Ok(());
@ -1442,7 +1472,7 @@ impl<'db> SpecializationBuilder<'db> {
if remaining_actual.is_never() { if remaining_actual.is_never() {
return Ok(()); return Ok(());
} }
self.add_type_mapping(*formal_bound_typevar, remaining_actual); self.add_type_mapping(*formal_bound_typevar, remaining_actual, filter);
} }
(Type::Union(formal), _) => { (Type::Union(formal), _) => {
// Second, if the formal is a union, and precisely one union element _is_ a typevar (not // Second, if the formal is a union, and precisely one union element _is_ a typevar (not
@ -1452,7 +1482,7 @@ impl<'db> SpecializationBuilder<'db> {
let bound_typevars = let bound_typevars =
(formal.elements(self.db).iter()).filter_map(|ty| ty.as_typevar()); (formal.elements(self.db).iter()).filter_map(|ty| ty.as_typevar());
if let Ok(bound_typevar) = bound_typevars.exactly_one() { if let Ok(bound_typevar) = bound_typevars.exactly_one() {
self.add_type_mapping(bound_typevar, actual); self.add_type_mapping(bound_typevar, actual, filter);
} }
} }
@ -1480,13 +1510,13 @@ impl<'db> SpecializationBuilder<'db> {
argument: ty, argument: ty,
}); });
} }
self.add_type_mapping(bound_typevar, ty); self.add_type_mapping(bound_typevar, ty, filter);
} }
Some(TypeVarBoundOrConstraints::Constraints(constraints)) => { Some(TypeVarBoundOrConstraints::Constraints(constraints)) => {
// Prefer an exact match first. // Prefer an exact match first.
for constraint in constraints.elements(self.db) { for constraint in constraints.elements(self.db) {
if ty == *constraint { if ty == *constraint {
self.add_type_mapping(bound_typevar, ty); self.add_type_mapping(bound_typevar, ty, filter);
return Ok(()); return Ok(());
} }
} }
@ -1496,7 +1526,7 @@ impl<'db> SpecializationBuilder<'db> {
.when_assignable_to(self.db, *constraint, self.inferable) .when_assignable_to(self.db, *constraint, self.inferable)
.is_always_satisfied(self.db) .is_always_satisfied(self.db)
{ {
self.add_type_mapping(bound_typevar, *constraint); self.add_type_mapping(bound_typevar, *constraint, filter);
return Ok(()); return Ok(());
} }
} }
@ -1506,7 +1536,7 @@ impl<'db> SpecializationBuilder<'db> {
}); });
} }
_ => { _ => {
self.add_type_mapping(bound_typevar, ty); self.add_type_mapping(bound_typevar, ty, filter);
} }
} }
} }

View File

@ -6260,7 +6260,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let inferred_elt_ty = self.get_or_infer_expression(elt, elt_tcx); let inferred_elt_ty = self.get_or_infer_expression(elt, elt_tcx);
// Simplify the inference based on the declared type of the element. // Avoid widening the inferred type if it is already assignable to the preferred
// declared type.
if let Some(elt_tcx) = elt_tcx.annotation { if let Some(elt_tcx) = elt_tcx.annotation {
if inferred_elt_ty.is_assignable_to(self.db(), elt_tcx) { if inferred_elt_ty.is_assignable_to(self.db(), elt_tcx) {
continue; continue;