prefer declared type of generic classes

This commit is contained in:
Ibraheem Ahmed 2025-10-24 22:38:17 -04:00
parent 1d6ae8596a
commit 2f4dcbf651
7 changed files with 160 additions and 41 deletions

View File

@ -427,14 +427,13 @@ a = f("a")
reveal_type(a) # revealed: list[Literal["a"]]
b: list[int | Literal["a"]] = f("a")
reveal_type(b) # revealed: list[Literal["a"] | int]
reveal_type(b) # revealed: list[int | Literal["a"]]
c: list[int | str] = f("a")
reveal_type(c) # revealed: list[str | int]
reveal_type(c) # revealed: list[int | str]
d: list[int | tuple[int, int]] = f((1, 2))
# TODO: We could avoid reordering the union elements here.
reveal_type(d) # revealed: list[tuple[int, int] | int]
reveal_type(d) # revealed: list[int | tuple[int, int]]
e: list[int] = f(True)
reveal_type(e) # revealed: list[int]
@ -455,7 +454,54 @@ j: int | str = f2(True)
reveal_type(j) # revealed: Literal[True]
```
Types are not widened unnecessarily:
## Prefer the declared type of generic classes
```toml
[environment]
python-version = "3.12"
```
```py
from typing import Any
def f[T](x: T) -> list[T]:
return [x]
def f2[T](x: T) -> list[T] | None:
return [x]
def f3[T](x: T) -> list[T] | dict[T, T]:
return [x]
a = f(1)
reveal_type(a) # revealed: list[Literal[1]]
b: list[Any] = f(1)
reveal_type(b) # revealed: list[Any]
c: list[Any] = [1]
reveal_type(c) # revealed: list[Any]
d: list[Any] | None = f(1)
reveal_type(d) # revealed: list[Any]
e: list[Any] | None = [1]
reveal_type(e) # revealed: list[Any]
f: list[Any] | None = f2(1)
reveal_type(f) # revealed: list[Any] | None
g: list[Any] | dict[Any, Any] = f3(1)
# TODO: Better constraint solver.
reveal_type(g) # revealed: list[Literal[1]] | dict[Literal[1], Literal[1]]
```
## Prefer the inferred type of non-generic classes
```toml
[environment]
python-version = "3.12"
```
```py
def id[T](x: T) -> T:

View File

@ -50,8 +50,6 @@ def _(l: list[int] | None = None):
def f[T](x: T, cond: bool) -> T | list[T]:
return x if cond else [x]
# TODO: no error
# error: [invalid-assignment] "Object of type `Literal[1] | list[Literal[1]]` is not assignable to `int | list[int]`"
l5: int | list[int] = f(1, True)
```

View File

@ -891,13 +891,24 @@ impl<'db> Type<'db> {
known_class: KnownClass,
) -> Option<Specialization<'db>> {
let class_literal = known_class.try_to_class_literal(db)?;
self.specialization_of(db, Some(class_literal))
self.specialization_of(db, class_literal)
}
// If this type is a class instance, returns its specialization.
pub(crate) fn class_specialization(self, db: &'db dyn Db) -> Option<Specialization<'db>> {
self.specialization_of_optional(db, None)
}
// If the type is a specialized instance of the given class, returns the specialization.
//
// If no class is provided, returns the specialization of any class instance.
pub(crate) fn specialization_of(
self,
db: &'db dyn Db,
expected_class: ClassLiteral<'_>,
) -> Option<Specialization<'db>> {
self.specialization_of_optional(db, Some(expected_class))
}
fn specialization_of_optional(
self,
db: &'db dyn Db,
expected_class: Option<ClassLiteral<'_>>,
@ -1214,22 +1225,28 @@ impl<'db> Type<'db> {
/// If the type is a union, filters union elements based on the provided predicate.
///
/// Otherwise, returns the type unchanged.
/// Otherwise, considers the type to be the sole inhabitant of a single-valued union,
/// and filters it, returning `Never` if the predicate returns `false`, or the type
/// unchanged if `true`.
pub(crate) fn filter_union(
self,
db: &'db dyn Db,
f: impl FnMut(&Type<'db>) -> bool,
mut f: impl FnMut(&Type<'db>) -> bool,
) -> Type<'db> {
if let Type::Union(union) = self {
union.filter(db, f)
} else {
} else if f(&self) {
self
} else {
Type::Never
}
}
/// If the type is a union, removes union elements that are disjoint from `target`.
///
/// Otherwise, returns the type unchanged.
/// Otherwise, considers the type to be the sole inhabitant of a single-valued union,
/// and filters it, returning `Never` if it is disjoint from `target`, or the type
/// unchanged if `true`.
pub(crate) fn filter_disjoint_elements(
self,
db: &'db dyn Db,

View File

@ -35,11 +35,11 @@ use crate::types::generics::{
use crate::types::signatures::{Parameter, ParameterForm, ParameterKind, Parameters};
use crate::types::tuple::{TupleLength, TupleType};
use crate::types::{
BoundMethodType, ClassLiteral, DataclassFlags, DataclassParams, FieldInstance,
KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy, NominalInstanceType,
PropertyInstanceType, SpecialFormType, TrackedConstraintSet, TypeAliasType, TypeContext,
UnionBuilder, UnionType, WrapperDescriptorKind, enums, ide_support, infer_isolated_expression,
todo_type,
BoundMethodType, BoundTypeVarIdentity, ClassLiteral, DataclassFlags, DataclassParams,
FieldInstance, KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy,
NominalInstanceType, PropertyInstanceType, SpecialFormType, TrackedConstraintSet,
TypeAliasType, TypeContext, UnionBuilder, UnionType, WrapperDescriptorKind, enums, ide_support,
infer_isolated_expression, todo_type,
};
use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity};
use ruff_python_ast::{self as ast, ArgOrKeyword, PythonVersion};
@ -2718,9 +2718,25 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
return;
};
let return_with_tcx = self
.signature
.return_ty
.zip(self.call_expression_tcx.annotation);
self.inferable_typevars = generic_context.inferable_typevars(self.db);
let mut builder = SpecializationBuilder::new(self.db, self.inferable_typevars);
// Prefer the declared type of generic classes.
let preferred_type_mappings = return_with_tcx.and_then(|(return_ty, tcx)| {
let preferred_return_ty =
tcx.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some());
let return_ty =
return_ty.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some());
builder.infer(return_ty, preferred_return_ty).ok()?;
Some(builder.type_mappings().clone())
});
let parameters = self.signature.parameters();
for (argument_index, adjusted_argument_index, _, argument_type) in
self.enumerate_argument_types()
@ -2733,9 +2749,21 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
continue;
};
if let Err(error) = builder.infer(
let filter = |declared_ty: BoundTypeVarIdentity<'_>, inferred_ty: Type<'_>| {
// Avoid widening the inferred type if it is already assignable to the
// preferred declared type.
preferred_type_mappings
.as_ref()
.and_then(|types| types.get(&declared_ty))
.is_none_or(|preferred_ty| {
!inferred_ty.is_assignable_to(self.db, *preferred_ty)
})
};
if let Err(error) = builder.infer_filter(
expected_type,
variadic_argument_type.unwrap_or(argument_type),
filter,
) {
self.errors.push(BindingError::SpecializationError {
error,
@ -2745,15 +2773,14 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
}
}
// Build the specialization first without inferring the type context.
// Build the specialization first without inferring the complete type context.
let isolated_specialization = builder.build(generic_context, *self.call_expression_tcx);
let isolated_return_ty = self
.return_ty
.apply_specialization(self.db, isolated_specialization);
let mut try_infer_tcx = || {
let return_ty = self.signature.return_ty?;
let call_expression_tcx = self.call_expression_tcx.annotation?;
let (return_ty, call_expression_tcx) = return_with_tcx?;
// A type variable is not a useful type-context for expression inference, and applying it
// to the return type can lead to confusing unions in nested generic calls.
@ -2762,7 +2789,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
}
// If the return type is already assignable to the annotated type, we can ignore the
// type context and prefer the narrower inferred type.
// rest of the type context and prefer the narrower inferred type.
if isolated_return_ty.is_assignable_to(self.db, call_expression_tcx) {
return None;
}
@ -2771,7 +2798,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
// annotated assignment, to closer match the order of any unions written in the type annotation.
builder.infer(return_ty, call_expression_tcx).ok()?;
// Otherwise, build the specialization again after inferring the type context.
// Otherwise, build the specialization again after inferring the complete type context.
let specialization = builder.build(generic_context, *self.call_expression_tcx);
let return_ty = return_ty.apply_specialization(self.db, specialization);

View File

@ -258,7 +258,7 @@ impl<'db> GenericAlias<'db> {
) -> Self {
let tcx = tcx
.annotation
.and_then(|ty| ty.specialization_of(db, Some(self.origin(db))))
.and_then(|ty| ty.specialization_of(db, self.origin(db)))
.map(|specialization| specialization.types(db))
.unwrap_or(&[]);

View File

@ -1,4 +1,5 @@
use std::cell::RefCell;
use std::collections::hash_map::Entry;
use std::fmt::Display;
use itertools::Itertools;
@ -1319,6 +1320,11 @@ impl<'db> SpecializationBuilder<'db> {
}
}
/// Returns the current set of type mappings for this specialization.
pub(crate) fn type_mappings(&self) -> &FxHashMap<BoundTypeVarIdentity<'db>, Type<'db>> {
&self.types
}
pub(crate) fn build(
&mut self,
generic_context: GenericContext<'db>,
@ -1326,7 +1332,7 @@ impl<'db> SpecializationBuilder<'db> {
) -> Specialization<'db> {
let tcx_specialization = tcx
.annotation
.and_then(|annotation| annotation.specialization_of(self.db, None));
.and_then(|annotation| annotation.class_specialization(self.db));
let types =
(generic_context.variables_inner(self.db).iter()).map(|(identity, variable)| {
@ -1349,19 +1355,43 @@ impl<'db> SpecializationBuilder<'db> {
generic_context.specialize_partial(self.db, types)
}
fn add_type_mapping(&mut self, bound_typevar: BoundTypeVarInstance<'db>, ty: Type<'db>) {
self.types
.entry(bound_typevar.identity(self.db))
.and_modify(|existing| {
*existing = UnionType::from_elements(self.db, [*existing, ty]);
})
.or_insert(ty);
fn add_type_mapping(
&mut self,
bound_typevar: BoundTypeVarInstance<'db>,
ty: Type<'db>,
filter: impl Fn(BoundTypeVarIdentity<'db>, Type<'db>) -> bool,
) {
let identity = bound_typevar.identity(self.db);
match self.types.entry(identity) {
Entry::Occupied(mut entry) => {
if filter(identity, ty) {
*entry.get_mut() = UnionType::from_elements(self.db, [*entry.get(), ty]);
}
}
Entry::Vacant(entry) => {
entry.insert(ty);
}
}
}
/// Infer type mappings for the specialization based on a given type and its declared type.
pub(crate) fn infer(
&mut self,
formal: Type<'db>,
actual: Type<'db>,
) -> Result<(), SpecializationError<'db>> {
self.infer_filter(formal, actual, |_, _| true)
}
/// Infer type mappings for the specialization based on a given type and its declared type.
///
/// The filter predicate is provided with a type variable and the type being mapped to it. Type
/// mappings to which the predicate returns `false` will be ignored.
pub(crate) fn infer_filter(
&mut self,
formal: Type<'db>,
actual: Type<'db>,
filter: impl Fn(BoundTypeVarIdentity<'db>, Type<'db>) -> bool,
) -> Result<(), SpecializationError<'db>> {
if formal == actual {
return Ok(());
@ -1442,7 +1472,7 @@ impl<'db> SpecializationBuilder<'db> {
if remaining_actual.is_never() {
return Ok(());
}
self.add_type_mapping(*formal_bound_typevar, remaining_actual);
self.add_type_mapping(*formal_bound_typevar, remaining_actual, filter);
}
(Type::Union(formal), _) => {
// Second, if the formal is a union, and precisely one union element _is_ a typevar (not
@ -1452,7 +1482,7 @@ impl<'db> SpecializationBuilder<'db> {
let bound_typevars =
(formal.elements(self.db).iter()).filter_map(|ty| ty.as_typevar());
if let Ok(bound_typevar) = bound_typevars.exactly_one() {
self.add_type_mapping(bound_typevar, actual);
self.add_type_mapping(bound_typevar, actual, filter);
}
}
@ -1480,13 +1510,13 @@ impl<'db> SpecializationBuilder<'db> {
argument: ty,
});
}
self.add_type_mapping(bound_typevar, ty);
self.add_type_mapping(bound_typevar, ty, filter);
}
Some(TypeVarBoundOrConstraints::Constraints(constraints)) => {
// Prefer an exact match first.
for constraint in constraints.elements(self.db) {
if ty == *constraint {
self.add_type_mapping(bound_typevar, ty);
self.add_type_mapping(bound_typevar, ty, filter);
return Ok(());
}
}
@ -1496,7 +1526,7 @@ impl<'db> SpecializationBuilder<'db> {
.when_assignable_to(self.db, *constraint, self.inferable)
.is_always_satisfied(self.db)
{
self.add_type_mapping(bound_typevar, *constraint);
self.add_type_mapping(bound_typevar, *constraint, filter);
return Ok(());
}
}
@ -1506,7 +1536,7 @@ impl<'db> SpecializationBuilder<'db> {
});
}
_ => {
self.add_type_mapping(bound_typevar, ty);
self.add_type_mapping(bound_typevar, ty, filter);
}
}
}

View File

@ -6260,7 +6260,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let inferred_elt_ty = self.get_or_infer_expression(elt, elt_tcx);
// Simplify the inference based on the declared type of the element.
// Avoid widening the inferred type if it is already assignable to the preferred
// declared type.
if let Some(elt_tcx) = elt_tcx.annotation {
if inferred_elt_ty.is_assignable_to(self.db(), elt_tcx) {
continue;