mirror of https://github.com/astral-sh/ruff
prefer declared type of generic classes
This commit is contained in:
parent
1d6ae8596a
commit
2f4dcbf651
|
|
@ -427,14 +427,13 @@ a = f("a")
|
|||
reveal_type(a) # revealed: list[Literal["a"]]
|
||||
|
||||
b: list[int | Literal["a"]] = f("a")
|
||||
reveal_type(b) # revealed: list[Literal["a"] | int]
|
||||
reveal_type(b) # revealed: list[int | Literal["a"]]
|
||||
|
||||
c: list[int | str] = f("a")
|
||||
reveal_type(c) # revealed: list[str | int]
|
||||
reveal_type(c) # revealed: list[int | str]
|
||||
|
||||
d: list[int | tuple[int, int]] = f((1, 2))
|
||||
# TODO: We could avoid reordering the union elements here.
|
||||
reveal_type(d) # revealed: list[tuple[int, int] | int]
|
||||
reveal_type(d) # revealed: list[int | tuple[int, int]]
|
||||
|
||||
e: list[int] = f(True)
|
||||
reveal_type(e) # revealed: list[int]
|
||||
|
|
@ -455,7 +454,54 @@ j: int | str = f2(True)
|
|||
reveal_type(j) # revealed: Literal[True]
|
||||
```
|
||||
|
||||
Types are not widened unnecessarily:
|
||||
## Prefer the declared type of generic classes
|
||||
|
||||
```toml
|
||||
[environment]
|
||||
python-version = "3.12"
|
||||
```
|
||||
|
||||
```py
|
||||
from typing import Any
|
||||
|
||||
def f[T](x: T) -> list[T]:
|
||||
return [x]
|
||||
|
||||
def f2[T](x: T) -> list[T] | None:
|
||||
return [x]
|
||||
|
||||
def f3[T](x: T) -> list[T] | dict[T, T]:
|
||||
return [x]
|
||||
|
||||
a = f(1)
|
||||
reveal_type(a) # revealed: list[Literal[1]]
|
||||
|
||||
b: list[Any] = f(1)
|
||||
reveal_type(b) # revealed: list[Any]
|
||||
|
||||
c: list[Any] = [1]
|
||||
reveal_type(c) # revealed: list[Any]
|
||||
|
||||
d: list[Any] | None = f(1)
|
||||
reveal_type(d) # revealed: list[Any]
|
||||
|
||||
e: list[Any] | None = [1]
|
||||
reveal_type(e) # revealed: list[Any]
|
||||
|
||||
f: list[Any] | None = f2(1)
|
||||
reveal_type(f) # revealed: list[Any] | None
|
||||
|
||||
g: list[Any] | dict[Any, Any] = f3(1)
|
||||
# TODO: Better constraint solver.
|
||||
reveal_type(g) # revealed: list[Literal[1]] | dict[Literal[1], Literal[1]]
|
||||
```
|
||||
|
||||
## Prefer the inferred type of non-generic classes
|
||||
|
||||
```toml
|
||||
[environment]
|
||||
python-version = "3.12"
|
||||
```
|
||||
|
||||
```py
|
||||
def id[T](x: T) -> T:
|
||||
|
|
|
|||
|
|
@ -50,8 +50,6 @@ def _(l: list[int] | None = None):
|
|||
def f[T](x: T, cond: bool) -> T | list[T]:
|
||||
return x if cond else [x]
|
||||
|
||||
# TODO: no error
|
||||
# error: [invalid-assignment] "Object of type `Literal[1] | list[Literal[1]]` is not assignable to `int | list[int]`"
|
||||
l5: int | list[int] = f(1, True)
|
||||
```
|
||||
|
||||
|
|
|
|||
|
|
@ -891,13 +891,24 @@ impl<'db> Type<'db> {
|
|||
known_class: KnownClass,
|
||||
) -> Option<Specialization<'db>> {
|
||||
let class_literal = known_class.try_to_class_literal(db)?;
|
||||
self.specialization_of(db, Some(class_literal))
|
||||
self.specialization_of(db, class_literal)
|
||||
}
|
||||
|
||||
// If this type is a class instance, returns its specialization.
|
||||
pub(crate) fn class_specialization(self, db: &'db dyn Db) -> Option<Specialization<'db>> {
|
||||
self.specialization_of_optional(db, None)
|
||||
}
|
||||
|
||||
// If the type is a specialized instance of the given class, returns the specialization.
|
||||
//
|
||||
// If no class is provided, returns the specialization of any class instance.
|
||||
pub(crate) fn specialization_of(
|
||||
self,
|
||||
db: &'db dyn Db,
|
||||
expected_class: ClassLiteral<'_>,
|
||||
) -> Option<Specialization<'db>> {
|
||||
self.specialization_of_optional(db, Some(expected_class))
|
||||
}
|
||||
|
||||
fn specialization_of_optional(
|
||||
self,
|
||||
db: &'db dyn Db,
|
||||
expected_class: Option<ClassLiteral<'_>>,
|
||||
|
|
@ -1214,22 +1225,28 @@ impl<'db> Type<'db> {
|
|||
|
||||
/// If the type is a union, filters union elements based on the provided predicate.
|
||||
///
|
||||
/// Otherwise, returns the type unchanged.
|
||||
/// Otherwise, considers the type to be the sole inhabitant of a single-valued union,
|
||||
/// and filters it, returning `Never` if the predicate returns `false`, or the type
|
||||
/// unchanged if `true`.
|
||||
pub(crate) fn filter_union(
|
||||
self,
|
||||
db: &'db dyn Db,
|
||||
f: impl FnMut(&Type<'db>) -> bool,
|
||||
mut f: impl FnMut(&Type<'db>) -> bool,
|
||||
) -> Type<'db> {
|
||||
if let Type::Union(union) = self {
|
||||
union.filter(db, f)
|
||||
} else {
|
||||
} else if f(&self) {
|
||||
self
|
||||
} else {
|
||||
Type::Never
|
||||
}
|
||||
}
|
||||
|
||||
/// If the type is a union, removes union elements that are disjoint from `target`.
|
||||
///
|
||||
/// Otherwise, returns the type unchanged.
|
||||
/// Otherwise, considers the type to be the sole inhabitant of a single-valued union,
|
||||
/// and filters it, returning `Never` if it is disjoint from `target`, or the type
|
||||
/// unchanged if `true`.
|
||||
pub(crate) fn filter_disjoint_elements(
|
||||
self,
|
||||
db: &'db dyn Db,
|
||||
|
|
|
|||
|
|
@ -35,11 +35,11 @@ use crate::types::generics::{
|
|||
use crate::types::signatures::{Parameter, ParameterForm, ParameterKind, Parameters};
|
||||
use crate::types::tuple::{TupleLength, TupleType};
|
||||
use crate::types::{
|
||||
BoundMethodType, ClassLiteral, DataclassFlags, DataclassParams, FieldInstance,
|
||||
KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy, NominalInstanceType,
|
||||
PropertyInstanceType, SpecialFormType, TrackedConstraintSet, TypeAliasType, TypeContext,
|
||||
UnionBuilder, UnionType, WrapperDescriptorKind, enums, ide_support, infer_isolated_expression,
|
||||
todo_type,
|
||||
BoundMethodType, BoundTypeVarIdentity, ClassLiteral, DataclassFlags, DataclassParams,
|
||||
FieldInstance, KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy,
|
||||
NominalInstanceType, PropertyInstanceType, SpecialFormType, TrackedConstraintSet,
|
||||
TypeAliasType, TypeContext, UnionBuilder, UnionType, WrapperDescriptorKind, enums, ide_support,
|
||||
infer_isolated_expression, todo_type,
|
||||
};
|
||||
use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity};
|
||||
use ruff_python_ast::{self as ast, ArgOrKeyword, PythonVersion};
|
||||
|
|
@ -2718,9 +2718,25 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
|||
return;
|
||||
};
|
||||
|
||||
let return_with_tcx = self
|
||||
.signature
|
||||
.return_ty
|
||||
.zip(self.call_expression_tcx.annotation);
|
||||
|
||||
self.inferable_typevars = generic_context.inferable_typevars(self.db);
|
||||
let mut builder = SpecializationBuilder::new(self.db, self.inferable_typevars);
|
||||
|
||||
// Prefer the declared type of generic classes.
|
||||
let preferred_type_mappings = return_with_tcx.and_then(|(return_ty, tcx)| {
|
||||
let preferred_return_ty =
|
||||
tcx.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some());
|
||||
let return_ty =
|
||||
return_ty.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some());
|
||||
|
||||
builder.infer(return_ty, preferred_return_ty).ok()?;
|
||||
Some(builder.type_mappings().clone())
|
||||
});
|
||||
|
||||
let parameters = self.signature.parameters();
|
||||
for (argument_index, adjusted_argument_index, _, argument_type) in
|
||||
self.enumerate_argument_types()
|
||||
|
|
@ -2733,9 +2749,21 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
|||
continue;
|
||||
};
|
||||
|
||||
if let Err(error) = builder.infer(
|
||||
let filter = |declared_ty: BoundTypeVarIdentity<'_>, inferred_ty: Type<'_>| {
|
||||
// Avoid widening the inferred type if it is already assignable to the
|
||||
// preferred declared type.
|
||||
preferred_type_mappings
|
||||
.as_ref()
|
||||
.and_then(|types| types.get(&declared_ty))
|
||||
.is_none_or(|preferred_ty| {
|
||||
!inferred_ty.is_assignable_to(self.db, *preferred_ty)
|
||||
})
|
||||
};
|
||||
|
||||
if let Err(error) = builder.infer_filter(
|
||||
expected_type,
|
||||
variadic_argument_type.unwrap_or(argument_type),
|
||||
filter,
|
||||
) {
|
||||
self.errors.push(BindingError::SpecializationError {
|
||||
error,
|
||||
|
|
@ -2745,15 +2773,14 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
|||
}
|
||||
}
|
||||
|
||||
// Build the specialization first without inferring the type context.
|
||||
// Build the specialization first without inferring the complete type context.
|
||||
let isolated_specialization = builder.build(generic_context, *self.call_expression_tcx);
|
||||
let isolated_return_ty = self
|
||||
.return_ty
|
||||
.apply_specialization(self.db, isolated_specialization);
|
||||
|
||||
let mut try_infer_tcx = || {
|
||||
let return_ty = self.signature.return_ty?;
|
||||
let call_expression_tcx = self.call_expression_tcx.annotation?;
|
||||
let (return_ty, call_expression_tcx) = return_with_tcx?;
|
||||
|
||||
// A type variable is not a useful type-context for expression inference, and applying it
|
||||
// to the return type can lead to confusing unions in nested generic calls.
|
||||
|
|
@ -2762,7 +2789,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
|||
}
|
||||
|
||||
// If the return type is already assignable to the annotated type, we can ignore the
|
||||
// type context and prefer the narrower inferred type.
|
||||
// rest of the type context and prefer the narrower inferred type.
|
||||
if isolated_return_ty.is_assignable_to(self.db, call_expression_tcx) {
|
||||
return None;
|
||||
}
|
||||
|
|
@ -2771,7 +2798,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
|||
// annotated assignment, to closer match the order of any unions written in the type annotation.
|
||||
builder.infer(return_ty, call_expression_tcx).ok()?;
|
||||
|
||||
// Otherwise, build the specialization again after inferring the type context.
|
||||
// Otherwise, build the specialization again after inferring the complete type context.
|
||||
let specialization = builder.build(generic_context, *self.call_expression_tcx);
|
||||
let return_ty = return_ty.apply_specialization(self.db, specialization);
|
||||
|
||||
|
|
|
|||
|
|
@ -258,7 +258,7 @@ impl<'db> GenericAlias<'db> {
|
|||
) -> Self {
|
||||
let tcx = tcx
|
||||
.annotation
|
||||
.and_then(|ty| ty.specialization_of(db, Some(self.origin(db))))
|
||||
.and_then(|ty| ty.specialization_of(db, self.origin(db)))
|
||||
.map(|specialization| specialization.types(db))
|
||||
.unwrap_or(&[]);
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
use std::cell::RefCell;
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::fmt::Display;
|
||||
|
||||
use itertools::Itertools;
|
||||
|
|
@ -1319,6 +1320,11 @@ impl<'db> SpecializationBuilder<'db> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Returns the current set of type mappings for this specialization.
|
||||
pub(crate) fn type_mappings(&self) -> &FxHashMap<BoundTypeVarIdentity<'db>, Type<'db>> {
|
||||
&self.types
|
||||
}
|
||||
|
||||
pub(crate) fn build(
|
||||
&mut self,
|
||||
generic_context: GenericContext<'db>,
|
||||
|
|
@ -1326,7 +1332,7 @@ impl<'db> SpecializationBuilder<'db> {
|
|||
) -> Specialization<'db> {
|
||||
let tcx_specialization = tcx
|
||||
.annotation
|
||||
.and_then(|annotation| annotation.specialization_of(self.db, None));
|
||||
.and_then(|annotation| annotation.class_specialization(self.db));
|
||||
|
||||
let types =
|
||||
(generic_context.variables_inner(self.db).iter()).map(|(identity, variable)| {
|
||||
|
|
@ -1349,19 +1355,43 @@ impl<'db> SpecializationBuilder<'db> {
|
|||
generic_context.specialize_partial(self.db, types)
|
||||
}
|
||||
|
||||
fn add_type_mapping(&mut self, bound_typevar: BoundTypeVarInstance<'db>, ty: Type<'db>) {
|
||||
self.types
|
||||
.entry(bound_typevar.identity(self.db))
|
||||
.and_modify(|existing| {
|
||||
*existing = UnionType::from_elements(self.db, [*existing, ty]);
|
||||
})
|
||||
.or_insert(ty);
|
||||
fn add_type_mapping(
|
||||
&mut self,
|
||||
bound_typevar: BoundTypeVarInstance<'db>,
|
||||
ty: Type<'db>,
|
||||
filter: impl Fn(BoundTypeVarIdentity<'db>, Type<'db>) -> bool,
|
||||
) {
|
||||
let identity = bound_typevar.identity(self.db);
|
||||
match self.types.entry(identity) {
|
||||
Entry::Occupied(mut entry) => {
|
||||
if filter(identity, ty) {
|
||||
*entry.get_mut() = UnionType::from_elements(self.db, [*entry.get(), ty]);
|
||||
}
|
||||
}
|
||||
Entry::Vacant(entry) => {
|
||||
entry.insert(ty);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Infer type mappings for the specialization based on a given type and its declared type.
|
||||
pub(crate) fn infer(
|
||||
&mut self,
|
||||
formal: Type<'db>,
|
||||
actual: Type<'db>,
|
||||
) -> Result<(), SpecializationError<'db>> {
|
||||
self.infer_filter(formal, actual, |_, _| true)
|
||||
}
|
||||
|
||||
/// Infer type mappings for the specialization based on a given type and its declared type.
|
||||
///
|
||||
/// The filter predicate is provided with a type variable and the type being mapped to it. Type
|
||||
/// mappings to which the predicate returns `false` will be ignored.
|
||||
pub(crate) fn infer_filter(
|
||||
&mut self,
|
||||
formal: Type<'db>,
|
||||
actual: Type<'db>,
|
||||
filter: impl Fn(BoundTypeVarIdentity<'db>, Type<'db>) -> bool,
|
||||
) -> Result<(), SpecializationError<'db>> {
|
||||
if formal == actual {
|
||||
return Ok(());
|
||||
|
|
@ -1442,7 +1472,7 @@ impl<'db> SpecializationBuilder<'db> {
|
|||
if remaining_actual.is_never() {
|
||||
return Ok(());
|
||||
}
|
||||
self.add_type_mapping(*formal_bound_typevar, remaining_actual);
|
||||
self.add_type_mapping(*formal_bound_typevar, remaining_actual, filter);
|
||||
}
|
||||
(Type::Union(formal), _) => {
|
||||
// Second, if the formal is a union, and precisely one union element _is_ a typevar (not
|
||||
|
|
@ -1452,7 +1482,7 @@ impl<'db> SpecializationBuilder<'db> {
|
|||
let bound_typevars =
|
||||
(formal.elements(self.db).iter()).filter_map(|ty| ty.as_typevar());
|
||||
if let Ok(bound_typevar) = bound_typevars.exactly_one() {
|
||||
self.add_type_mapping(bound_typevar, actual);
|
||||
self.add_type_mapping(bound_typevar, actual, filter);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1480,13 +1510,13 @@ impl<'db> SpecializationBuilder<'db> {
|
|||
argument: ty,
|
||||
});
|
||||
}
|
||||
self.add_type_mapping(bound_typevar, ty);
|
||||
self.add_type_mapping(bound_typevar, ty, filter);
|
||||
}
|
||||
Some(TypeVarBoundOrConstraints::Constraints(constraints)) => {
|
||||
// Prefer an exact match first.
|
||||
for constraint in constraints.elements(self.db) {
|
||||
if ty == *constraint {
|
||||
self.add_type_mapping(bound_typevar, ty);
|
||||
self.add_type_mapping(bound_typevar, ty, filter);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
|
@ -1496,7 +1526,7 @@ impl<'db> SpecializationBuilder<'db> {
|
|||
.when_assignable_to(self.db, *constraint, self.inferable)
|
||||
.is_always_satisfied(self.db)
|
||||
{
|
||||
self.add_type_mapping(bound_typevar, *constraint);
|
||||
self.add_type_mapping(bound_typevar, *constraint, filter);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
|
@ -1506,7 +1536,7 @@ impl<'db> SpecializationBuilder<'db> {
|
|||
});
|
||||
}
|
||||
_ => {
|
||||
self.add_type_mapping(bound_typevar, ty);
|
||||
self.add_type_mapping(bound_typevar, ty, filter);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6260,7 +6260,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
|
||||
let inferred_elt_ty = self.get_or_infer_expression(elt, elt_tcx);
|
||||
|
||||
// Simplify the inference based on the declared type of the element.
|
||||
// Avoid widening the inferred type if it is already assignable to the preferred
|
||||
// declared type.
|
||||
if let Some(elt_tcx) = elt_tcx.annotation {
|
||||
if inferred_elt_ty.is_assignable_to(self.db(), elt_tcx) {
|
||||
continue;
|
||||
|
|
|
|||
Loading…
Reference in New Issue