From 2f4dcbf651080105a6085bdd58de3b0f5aa561b5 Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Fri, 24 Oct 2025 22:38:17 -0400 Subject: [PATCH] prefer declared type of generic classes --- .../mdtest/assignment/annotations.md | 56 ++++++++++++++++-- .../resources/mdtest/bidirectional.md | 2 - crates/ty_python_semantic/src/types.rs | 31 +++++++--- .../ty_python_semantic/src/types/call/bind.rs | 49 ++++++++++++---- crates/ty_python_semantic/src/types/class.rs | 2 +- .../ty_python_semantic/src/types/generics.rs | 58 ++++++++++++++----- .../src/types/infer/builder.rs | 3 +- 7 files changed, 160 insertions(+), 41 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md index adf0de358d..d0cb30f94d 100644 --- a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md +++ b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md @@ -427,14 +427,13 @@ a = f("a") reveal_type(a) # revealed: list[Literal["a"]] b: list[int | Literal["a"]] = f("a") -reveal_type(b) # revealed: list[Literal["a"] | int] +reveal_type(b) # revealed: list[int | Literal["a"]] c: list[int | str] = f("a") -reveal_type(c) # revealed: list[str | int] +reveal_type(c) # revealed: list[int | str] d: list[int | tuple[int, int]] = f((1, 2)) -# TODO: We could avoid reordering the union elements here. -reveal_type(d) # revealed: list[tuple[int, int] | int] +reveal_type(d) # revealed: list[int | tuple[int, int]] e: list[int] = f(True) reveal_type(e) # revealed: list[int] @@ -455,7 +454,54 @@ j: int | str = f2(True) reveal_type(j) # revealed: Literal[True] ``` -Types are not widened unnecessarily: +## Prefer the declared type of generic classes + +```toml +[environment] +python-version = "3.12" +``` + +```py +from typing import Any + +def f[T](x: T) -> list[T]: + return [x] + +def f2[T](x: T) -> list[T] | None: + return [x] + +def f3[T](x: T) -> list[T] | dict[T, T]: + return [x] + +a = f(1) +reveal_type(a) # revealed: list[Literal[1]] + +b: list[Any] = f(1) +reveal_type(b) # revealed: list[Any] + +c: list[Any] = [1] +reveal_type(c) # revealed: list[Any] + +d: list[Any] | None = f(1) +reveal_type(d) # revealed: list[Any] + +e: list[Any] | None = [1] +reveal_type(e) # revealed: list[Any] + +f: list[Any] | None = f2(1) +reveal_type(f) # revealed: list[Any] | None + +g: list[Any] | dict[Any, Any] = f3(1) +# TODO: Better constraint solver. +reveal_type(g) # revealed: list[Literal[1]] | dict[Literal[1], Literal[1]] +``` + +## Prefer the inferred type of non-generic classes + +```toml +[environment] +python-version = "3.12" +``` ```py def id[T](x: T) -> T: diff --git a/crates/ty_python_semantic/resources/mdtest/bidirectional.md b/crates/ty_python_semantic/resources/mdtest/bidirectional.md index 627492855f..ba4638ba66 100644 --- a/crates/ty_python_semantic/resources/mdtest/bidirectional.md +++ b/crates/ty_python_semantic/resources/mdtest/bidirectional.md @@ -50,8 +50,6 @@ def _(l: list[int] | None = None): def f[T](x: T, cond: bool) -> T | list[T]: return x if cond else [x] -# TODO: no error -# error: [invalid-assignment] "Object of type `Literal[1] | list[Literal[1]]` is not assignable to `int | list[int]`" l5: int | list[int] = f(1, True) ``` diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index a4eb563e6a..5c3899f791 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -891,13 +891,24 @@ impl<'db> Type<'db> { known_class: KnownClass, ) -> Option> { let class_literal = known_class.try_to_class_literal(db)?; - self.specialization_of(db, Some(class_literal)) + self.specialization_of(db, class_literal) + } + + // If this type is a class instance, returns its specialization. + pub(crate) fn class_specialization(self, db: &'db dyn Db) -> Option> { + self.specialization_of_optional(db, None) } // If the type is a specialized instance of the given class, returns the specialization. - // - // If no class is provided, returns the specialization of any class instance. pub(crate) fn specialization_of( + self, + db: &'db dyn Db, + expected_class: ClassLiteral<'_>, + ) -> Option> { + self.specialization_of_optional(db, Some(expected_class)) + } + + fn specialization_of_optional( self, db: &'db dyn Db, expected_class: Option>, @@ -1214,22 +1225,28 @@ impl<'db> Type<'db> { /// If the type is a union, filters union elements based on the provided predicate. /// - /// Otherwise, returns the type unchanged. + /// Otherwise, considers the type to be the sole inhabitant of a single-valued union, + /// and filters it, returning `Never` if the predicate returns `false`, or the type + /// unchanged if `true`. pub(crate) fn filter_union( self, db: &'db dyn Db, - f: impl FnMut(&Type<'db>) -> bool, + mut f: impl FnMut(&Type<'db>) -> bool, ) -> Type<'db> { if let Type::Union(union) = self { union.filter(db, f) - } else { + } else if f(&self) { self + } else { + Type::Never } } /// If the type is a union, removes union elements that are disjoint from `target`. /// - /// Otherwise, returns the type unchanged. + /// Otherwise, considers the type to be the sole inhabitant of a single-valued union, + /// and filters it, returning `Never` if it is disjoint from `target`, or the type + /// unchanged if `true`. pub(crate) fn filter_disjoint_elements( self, db: &'db dyn Db, diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index b0a5cc1b91..04f8c89b2f 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -35,11 +35,11 @@ use crate::types::generics::{ use crate::types::signatures::{Parameter, ParameterForm, ParameterKind, Parameters}; use crate::types::tuple::{TupleLength, TupleType}; use crate::types::{ - BoundMethodType, ClassLiteral, DataclassFlags, DataclassParams, FieldInstance, - KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy, NominalInstanceType, - PropertyInstanceType, SpecialFormType, TrackedConstraintSet, TypeAliasType, TypeContext, - UnionBuilder, UnionType, WrapperDescriptorKind, enums, ide_support, infer_isolated_expression, - todo_type, + BoundMethodType, BoundTypeVarIdentity, ClassLiteral, DataclassFlags, DataclassParams, + FieldInstance, KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy, + NominalInstanceType, PropertyInstanceType, SpecialFormType, TrackedConstraintSet, + TypeAliasType, TypeContext, UnionBuilder, UnionType, WrapperDescriptorKind, enums, ide_support, + infer_isolated_expression, todo_type, }; use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity}; use ruff_python_ast::{self as ast, ArgOrKeyword, PythonVersion}; @@ -2718,9 +2718,25 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { return; }; + let return_with_tcx = self + .signature + .return_ty + .zip(self.call_expression_tcx.annotation); + self.inferable_typevars = generic_context.inferable_typevars(self.db); let mut builder = SpecializationBuilder::new(self.db, self.inferable_typevars); + // Prefer the declared type of generic classes. + let preferred_type_mappings = return_with_tcx.and_then(|(return_ty, tcx)| { + let preferred_return_ty = + tcx.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some()); + let return_ty = + return_ty.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some()); + + builder.infer(return_ty, preferred_return_ty).ok()?; + Some(builder.type_mappings().clone()) + }); + let parameters = self.signature.parameters(); for (argument_index, adjusted_argument_index, _, argument_type) in self.enumerate_argument_types() @@ -2733,9 +2749,21 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { continue; }; - if let Err(error) = builder.infer( + let filter = |declared_ty: BoundTypeVarIdentity<'_>, inferred_ty: Type<'_>| { + // Avoid widening the inferred type if it is already assignable to the + // preferred declared type. + preferred_type_mappings + .as_ref() + .and_then(|types| types.get(&declared_ty)) + .is_none_or(|preferred_ty| { + !inferred_ty.is_assignable_to(self.db, *preferred_ty) + }) + }; + + if let Err(error) = builder.infer_filter( expected_type, variadic_argument_type.unwrap_or(argument_type), + filter, ) { self.errors.push(BindingError::SpecializationError { error, @@ -2745,15 +2773,14 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { } } - // Build the specialization first without inferring the type context. + // Build the specialization first without inferring the complete type context. let isolated_specialization = builder.build(generic_context, *self.call_expression_tcx); let isolated_return_ty = self .return_ty .apply_specialization(self.db, isolated_specialization); let mut try_infer_tcx = || { - let return_ty = self.signature.return_ty?; - let call_expression_tcx = self.call_expression_tcx.annotation?; + let (return_ty, call_expression_tcx) = return_with_tcx?; // A type variable is not a useful type-context for expression inference, and applying it // to the return type can lead to confusing unions in nested generic calls. @@ -2762,7 +2789,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { } // If the return type is already assignable to the annotated type, we can ignore the - // type context and prefer the narrower inferred type. + // rest of the type context and prefer the narrower inferred type. if isolated_return_ty.is_assignable_to(self.db, call_expression_tcx) { return None; } @@ -2771,7 +2798,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { // annotated assignment, to closer match the order of any unions written in the type annotation. builder.infer(return_ty, call_expression_tcx).ok()?; - // Otherwise, build the specialization again after inferring the type context. + // Otherwise, build the specialization again after inferring the complete type context. let specialization = builder.build(generic_context, *self.call_expression_tcx); let return_ty = return_ty.apply_specialization(self.db, specialization); diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index c3ff51e47f..271d8e2c49 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -258,7 +258,7 @@ impl<'db> GenericAlias<'db> { ) -> Self { let tcx = tcx .annotation - .and_then(|ty| ty.specialization_of(db, Some(self.origin(db)))) + .and_then(|ty| ty.specialization_of(db, self.origin(db))) .map(|specialization| specialization.types(db)) .unwrap_or(&[]); diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 8485931ff2..069065030f 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -1,4 +1,5 @@ use std::cell::RefCell; +use std::collections::hash_map::Entry; use std::fmt::Display; use itertools::Itertools; @@ -1319,6 +1320,11 @@ impl<'db> SpecializationBuilder<'db> { } } + /// Returns the current set of type mappings for this specialization. + pub(crate) fn type_mappings(&self) -> &FxHashMap, Type<'db>> { + &self.types + } + pub(crate) fn build( &mut self, generic_context: GenericContext<'db>, @@ -1326,7 +1332,7 @@ impl<'db> SpecializationBuilder<'db> { ) -> Specialization<'db> { let tcx_specialization = tcx .annotation - .and_then(|annotation| annotation.specialization_of(self.db, None)); + .and_then(|annotation| annotation.class_specialization(self.db)); let types = (generic_context.variables_inner(self.db).iter()).map(|(identity, variable)| { @@ -1349,19 +1355,43 @@ impl<'db> SpecializationBuilder<'db> { generic_context.specialize_partial(self.db, types) } - fn add_type_mapping(&mut self, bound_typevar: BoundTypeVarInstance<'db>, ty: Type<'db>) { - self.types - .entry(bound_typevar.identity(self.db)) - .and_modify(|existing| { - *existing = UnionType::from_elements(self.db, [*existing, ty]); - }) - .or_insert(ty); + fn add_type_mapping( + &mut self, + bound_typevar: BoundTypeVarInstance<'db>, + ty: Type<'db>, + filter: impl Fn(BoundTypeVarIdentity<'db>, Type<'db>) -> bool, + ) { + let identity = bound_typevar.identity(self.db); + match self.types.entry(identity) { + Entry::Occupied(mut entry) => { + if filter(identity, ty) { + *entry.get_mut() = UnionType::from_elements(self.db, [*entry.get(), ty]); + } + } + Entry::Vacant(entry) => { + entry.insert(ty); + } + } } + /// Infer type mappings for the specialization based on a given type and its declared type. pub(crate) fn infer( &mut self, formal: Type<'db>, actual: Type<'db>, + ) -> Result<(), SpecializationError<'db>> { + self.infer_filter(formal, actual, |_, _| true) + } + + /// Infer type mappings for the specialization based on a given type and its declared type. + /// + /// The filter predicate is provided with a type variable and the type being mapped to it. Type + /// mappings to which the predicate returns `false` will be ignored. + pub(crate) fn infer_filter( + &mut self, + formal: Type<'db>, + actual: Type<'db>, + filter: impl Fn(BoundTypeVarIdentity<'db>, Type<'db>) -> bool, ) -> Result<(), SpecializationError<'db>> { if formal == actual { return Ok(()); @@ -1442,7 +1472,7 @@ impl<'db> SpecializationBuilder<'db> { if remaining_actual.is_never() { return Ok(()); } - self.add_type_mapping(*formal_bound_typevar, remaining_actual); + self.add_type_mapping(*formal_bound_typevar, remaining_actual, filter); } (Type::Union(formal), _) => { // Second, if the formal is a union, and precisely one union element _is_ a typevar (not @@ -1452,7 +1482,7 @@ impl<'db> SpecializationBuilder<'db> { let bound_typevars = (formal.elements(self.db).iter()).filter_map(|ty| ty.as_typevar()); if let Ok(bound_typevar) = bound_typevars.exactly_one() { - self.add_type_mapping(bound_typevar, actual); + self.add_type_mapping(bound_typevar, actual, filter); } } @@ -1480,13 +1510,13 @@ impl<'db> SpecializationBuilder<'db> { argument: ty, }); } - self.add_type_mapping(bound_typevar, ty); + self.add_type_mapping(bound_typevar, ty, filter); } Some(TypeVarBoundOrConstraints::Constraints(constraints)) => { // Prefer an exact match first. for constraint in constraints.elements(self.db) { if ty == *constraint { - self.add_type_mapping(bound_typevar, ty); + self.add_type_mapping(bound_typevar, ty, filter); return Ok(()); } } @@ -1496,7 +1526,7 @@ impl<'db> SpecializationBuilder<'db> { .when_assignable_to(self.db, *constraint, self.inferable) .is_always_satisfied(self.db) { - self.add_type_mapping(bound_typevar, *constraint); + self.add_type_mapping(bound_typevar, *constraint, filter); return Ok(()); } } @@ -1506,7 +1536,7 @@ impl<'db> SpecializationBuilder<'db> { }); } _ => { - self.add_type_mapping(bound_typevar, ty); + self.add_type_mapping(bound_typevar, ty, filter); } } } diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index edf8581bcd..d28d757c01 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -6260,7 +6260,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let inferred_elt_ty = self.get_or_infer_expression(elt, elt_tcx); - // Simplify the inference based on the declared type of the element. + // Avoid widening the inferred type if it is already assignable to the preferred + // declared type. if let Some(elt_tcx) = elt_tcx.annotation { if inferred_elt_ty.is_assignable_to(self.db(), elt_tcx) { continue;