diff --git a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md index d8595041fc..c0f42e204b 100644 --- a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md +++ b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md @@ -402,40 +402,48 @@ python-version = "3.12" `generic_list.py`: ```py -from typing import Literal +from typing import Literal, Sequence def f[T](x: T) -> list[T]: return [x] -a = f("a") -reveal_type(a) # revealed: list[str] +x1 = f("a") +reveal_type(x1) # revealed: list[str] -b: list[int | Literal["a"]] = f("a") -reveal_type(b) # revealed: list[int | Literal["a"]] +x2: list[int | Literal["a"]] = f("a") +reveal_type(x2) # revealed: list[int | Literal["a"]] -c: list[int | str] = f("a") -reveal_type(c) # revealed: list[int | str] +x3: list[int | str] = f("a") +reveal_type(x3) # revealed: list[int | str] -d: list[int | tuple[int, int]] = f((1, 2)) -reveal_type(d) # revealed: list[int | tuple[int, int]] +x4: list[int | tuple[int, int]] = f((1, 2)) +reveal_type(x4) # revealed: list[int | tuple[int, int]] -e: list[int] = f(True) -reveal_type(e) # revealed: list[int] +x5: list[int] = f(True) +reveal_type(x5) # revealed: list[int] # error: [invalid-assignment] "Object of type `list[int | str]` is not assignable to `list[int]`" -g: list[int] = f("a") +x6: list[int] = f("a") # error: [invalid-assignment] "Object of type `list[str]` is not assignable to `tuple[int]`" -h: tuple[int] = f("a") +x7: tuple[int] = f("a") def f2[T: int](x: T) -> T: return x -i: int = f2(True) -reveal_type(i) # revealed: Literal[True] +x8: int = f2(True) +reveal_type(x8) # revealed: Literal[True] -j: int | str = f2(True) -reveal_type(j) # revealed: Literal[True] +x9: int | str = f2(True) +reveal_type(x9) # revealed: Literal[True] + +# TODO: We could choose a concrete type here. +x10: list[int | str] | list[int | None] = [1, 2, 3] +reveal_type(x10) # revealed: list[Unknown | int] + +# TODO: And here similarly. +x11: Sequence[int | str] | Sequence[int | None] = [1, 2, 3] +reveal_type(x11) # revealed: list[Unknown | int] ``` A function's arguments are also inferred using the type context: @@ -610,6 +618,73 @@ x1: X[int | None] = X() reveal_type(x1) # revealed: X[None] ``` +## Declared type preference sees through subtyping + +```toml +[environment] +python-version = "3.12" +``` + +Similarly, if the inferred type is a subtype of the declared type, we prefer declared type +assignments that are in non-covariant position. + +```py +from collections import defaultdict +from typing import Any, Iterable, Literal, MutableSequence, Sequence + +x1: Sequence[Any] = [1, 2, 3] +reveal_type(x1) # revealed: list[int] + +x2: MutableSequence[Any] = [1, 2, 3] +reveal_type(x2) # revealed: list[Any] + +x3: Iterable[Any] = [1, 2, 3] +reveal_type(x3) # revealed: list[int] + +x4: Iterable[Iterable[Any]] = [[1, 2, 3]] +reveal_type(x4) # revealed: list[list[int]] + +x5: list[Iterable[Any]] = [[1, 2, 3]] +reveal_type(x5) # revealed: list[Iterable[Any]] + +x6: Iterable[list[Any]] = [[1, 2, 3]] +reveal_type(x6) # revealed: list[list[Any]] + +class X[T]: + value: T + + def __init__(self, value: T): ... + +class A[T](X[T]): ... + +def a[T](value: T) -> A[T]: + return A(value) + +x7: A[object] = A(1) +reveal_type(x7) # revealed: A[object] + +x8: X[object] = A(1) +reveal_type(x8) # revealed: A[object] + +x9: X[object] | None = A(1) +reveal_type(x9) # revealed: A[object] + +x10: X[object] | None = a(1) +reveal_type(x10) # revealed: A[object] + +def f[T](x: T) -> list[list[T]]: + return [[x]] + +x11: Sequence[Sequence[Any]] = f(1) +reveal_type(x11) # revealed: list[list[int]] + +x12: Sequence[list[Any]] = f(1) +reveal_type(x12) # revealed: list[list[Any]] + +x13: dict[int, dict[str, int]] = defaultdict(dict) +reveal_type(x13) # revealed: defaultdict[int, dict[str, int]] +``` + ## Narrow generic unions ```toml diff --git a/crates/ty_python_semantic/resources/mdtest/literal_promotion.md b/crates/ty_python_semantic/resources/mdtest/literal_promotion.md index 65fc1c1602..a7a8504d01 100644 --- a/crates/ty_python_semantic/resources/mdtest/literal_promotion.md +++ b/crates/ty_python_semantic/resources/mdtest/literal_promotion.md @@ -341,3 +341,58 @@ reveal_type(x21) # revealed: X[Literal[1]] x22: X[Literal[1]] | None = x(1) reveal_type(x22) # revealed: X[Literal[1]] ``` + +## Literal annotations see through subtyping + +```py +from typing import Any, Iterable, Literal, MutableSequence, Sequence + +x1: Sequence[Literal[1, 2, 3]] = [1, 2, 3] +reveal_type(x1) # revealed: list[Literal[1, 2, 3]] + +x2: MutableSequence[Literal[1, 2, 3]] = [1, 2, 3] +reveal_type(x2) # revealed: list[Literal[1, 2, 3]] + +x3: Iterable[Literal[1, 2, 3]] = [1, 2, 3] +reveal_type(x3) # revealed: list[Literal[1, 2, 3]] + +class Sup1[T]: + value: T + +class Sub1[T](Sup1[T]): ... + +def sub1[T](value: T) -> Sub1[T]: + return Sub1() + +x4: Sub1[Literal[1]] = sub1(1) +reveal_type(x4) # revealed: Sub1[Literal[1]] + +x5: Sup1[Literal[1]] = sub1(1) +reveal_type(x5) # revealed: Sub1[Literal[1]] + +x6: Sup1[Literal[1]] | None = sub1(1) +reveal_type(x6) # revealed: Sub1[Literal[1]] + +x7: Sup1[Literal[1]] | None = sub1(1) +reveal_type(x7) # revealed: Sub1[Literal[1]] + +class Sup2A[T, U]: + value: tuple[T, U] + +class Sup2B[T, U]: + value: tuple[T, U] + +class Sub2[T, U](Sup2A[T, Any], Sup2B[Any, U]): ... + +def sub2[T, U](x: T, y: U) -> Sub2[T, U]: + return Sub2() + +x8 = sub2(1, 2) +reveal_type(x8) # revealed: Sub2[int, int] + +x9: Sup2A[Literal[1], Literal[2]] = sub2(1, 2) +reveal_type(x9) # revealed: Sub2[Literal[1], int] + +x10: Sup2B[Literal[1], Literal[2]] = sub2(1, 2) +reveal_type(x10) # revealed: Sub2[int, Literal[2]] +``` diff --git a/crates/ty_python_semantic/resources/mdtest/type_compendium/tuple.md b/crates/ty_python_semantic/resources/mdtest/type_compendium/tuple.md index e323d25a17..e2c45cc7f1 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_compendium/tuple.md +++ b/crates/ty_python_semantic/resources/mdtest/type_compendium/tuple.md @@ -57,6 +57,9 @@ reveal_type(tuple((1, 2))) # revealed: tuple[Literal[1], Literal[2]] reveal_type(tuple([1])) # revealed: tuple[Unknown | int, ...] +x1: tuple[int, ...] = tuple([1]) +reveal_type(x1) # revealed: tuple[int, ...] + # error: [invalid-argument-type] reveal_type(tuple[int]([1])) # revealed: tuple[int] diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 40451a0736..9d259b7e90 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -4,6 +4,7 @@ use itertools::{Either, Itertools}; use ruff_diagnostics::{Edit, Fix}; use std::borrow::Cow; +use std::cell::RefCell; use std::time::Duration; use bitflags::bitflags; @@ -61,8 +62,8 @@ use crate::types::function::{ }; pub(crate) use crate::types::generics::GenericContext; use crate::types::generics::{ - InferableTypeVars, PartialSpecialization, Specialization, bind_typevar, typing_self, - walk_generic_context, + InferableTypeVars, PartialSpecialization, Specialization, SpecializationBuilder, bind_typevar, + typing_self, walk_generic_context, }; use crate::types::mro::{Mro, MroError, MroIterator}; pub(crate) use crate::types::narrow::infer_narrowing_constraint; @@ -1100,7 +1101,10 @@ impl<'db> Type<'db> { } /// If this type is a class instance, returns its specialization. - pub(crate) fn class_specialization(self, db: &'db dyn Db) -> Option> { + pub(crate) fn class_specialization( + self, + db: &'db dyn Db, + ) -> Option<(ClassLiteral<'db>, Specialization<'db>)> { self.specialization_of_optional(db, None) } @@ -1111,15 +1115,17 @@ impl<'db> Type<'db> { expected_class: ClassLiteral<'_>, ) -> Option> { self.specialization_of_optional(db, Some(expected_class)) + .map(|(_, specialization)| specialization) } fn specialization_of_optional( self, db: &'db dyn Db, expected_class: Option>, - ) -> Option> { + ) -> Option<(ClassLiteral<'db>, Specialization<'db>)> { let class_type = match self { Type::NominalInstance(instance) => instance, + Type::ProtocolInstance(instance) => instance.to_nominal_instance()?, Type::TypeAlias(alias) => alias.value_type(db).as_nominal_instance()?, _ => return None, } @@ -1130,7 +1136,7 @@ impl<'db> Type<'db> { return None; } - specialization + Some((class_literal, specialization?)) } /// Returns the top materialization (or upper bound materialization) of this type, which is the @@ -4114,28 +4120,32 @@ impl<'db> Type<'db> { return; }; - let (class_literal, Some(specialization)) = instance.class(db).class_literal(db) else { return; }; + let generic_context = specialization.generic_context(db); - let tcx_specialization = tcx.annotation.and_then(|tcx| { - tcx.filter_union(db, |ty| ty.specialization_of(db, class_literal).is_some()) - .specialization_of(db, class_literal) - }); + // Collect the type mappings used to narrow the type context. + let tcx_mappings = { + let mut builder = + SpecializationBuilder::new(db, generic_context.inferable_typevars(db)); - for (typevar, ty) in specialization - .generic_context(db) - .variables(db) - .zip(specialization.types(db)) - { - let variance = typevar.variance_with_polarity(db, polarity); - let tcx = TypeContext::new(tcx_specialization.and_then(|spec| spec.get(db, typevar))); + if let Some(tcx) = tcx.annotation { + let alias_instance = Type::instance(db, class_literal.identity_specialization(db)); + let _ = builder.infer_reverse(tcx, alias_instance); + } - f(typevar, *ty, variance, tcx); + builder.into_type_mappings() + }; + + for (type_var, ty) in generic_context.variables(db).zip(specialization.types(db)) { + let variance = type_var.variance_with_polarity(db, polarity); + let narrowed_tcx = TypeContext::new(tcx_mappings.get(&type_var.identity(db)).copied()); + + f(type_var, *ty, variance, narrowed_tcx); visitor.visit(*ty, || { - ty.visit_specialization_impl(db, tcx, variance, f, visitor); + ty.visit_specialization_impl(db, narrowed_tcx, variance, f, visitor); }); } } @@ -5841,8 +5851,11 @@ impl<'db> Type<'db> { } Some(KnownFunction::AssertType) => { - let val_ty = - BoundTypeVarInstance::synthetic(db, "T", TypeVarVariance::Invariant); + let val_ty = BoundTypeVarInstance::synthetic( + db, + Name::new_static("T"), + TypeVarVariance::Invariant, + ); Binding::single( self, @@ -6336,30 +6349,38 @@ impl<'db> Type<'db> { } Some(KnownClass::Tuple) => { - let object = Type::object(); + let element_ty = BoundTypeVarInstance::synthetic( + db, + Name::new_static("T"), + TypeVarVariance::Covariant, + ); // ```py - // class tuple: + // class tuple(Sequence[_T_co]): // @overload // def __new__(cls) -> tuple[()]: ... // @overload - // def __new__(cls, iterable: Iterable[object]) -> tuple[object, ...]: ... + // def __new__(cls, iterable: Iterable[_T_co]) -> tuple[_T_co, ...]: ... // ``` CallableBinding::from_overloads( self, [ Signature::new(Parameters::empty(), Some(Type::empty_tuple(db))), - Signature::new( + Signature::new_generic( + Some(GenericContext::from_typevar_instances(db, [element_ty])), Parameters::new( db, [Parameter::positional_only(Some(Name::new_static( "iterable", ))) .with_annotated_type( - KnownClass::Iterable.to_specialized_instance(db, [object]), + KnownClass::Iterable.to_specialized_instance( + db, + [Type::TypeVar(element_ty)], + ), )], ), - Some(Type::homogeneous_tuple(db, object)), + Some(Type::homogeneous_tuple(db, Type::TypeVar(element_ty))), ), ], ) @@ -6486,7 +6507,11 @@ impl<'db> Type<'db> { } Type::DataclassDecorator(_) => { - let typevar = BoundTypeVarInstance::synthetic(db, "T", TypeVarVariance::Invariant); + let typevar = BoundTypeVarInstance::synthetic( + db, + Name::new_static("T"), + TypeVarVariance::Invariant, + ); let typevar_meta = SubclassOfType::from(db, typevar); let context = GenericContext::from_typevar_instances(db, [typevar]); let parameters = [Parameter::positional_only(Some(Name::new_static("cls"))) @@ -7866,6 +7891,7 @@ impl<'db> Type<'db> { } TypeMapping::Specialization(_) | TypeMapping::PartialSpecialization(_) | + TypeMapping::UniqueSpecialization { .. } | TypeMapping::PromoteLiterals(_) | TypeMapping::BindSelf { .. } | TypeMapping::ReplaceSelf { .. } | @@ -8029,7 +8055,13 @@ impl<'db> Type<'db> { // Do not call `value_type` here. `value_type` does the specialization internally, so `apply_type_mapping` is performed without `visitor` inheritance. // In the case of recursive type aliases, this leads to infinite recursion. // Instead, call `raw_value_type` and perform the specialization after the `visitor` cache has been created. - let value_type = visitor.visit(self, || alias.raw_value_type(db).apply_type_mapping_impl(db, type_mapping, tcx, visitor)); + let value_type = visitor.visit(self, || { + match type_mapping { + TypeMapping::UniqueSpecialization { .. } => alias.raw_value_type(db), + _ => alias.raw_value_type(db).apply_type_mapping_impl(db, type_mapping, tcx, visitor), + } + }); + alias.apply_function_specialization(db, value_type).apply_type_mapping_impl(db, type_mapping, tcx, visitor) } @@ -8042,6 +8074,7 @@ impl<'db> Type<'db> { | Type::EnumLiteral(_) => match type_mapping { TypeMapping::Specialization(_) | TypeMapping::PartialSpecialization(_) | + TypeMapping::UniqueSpecialization { .. } | TypeMapping::BindLegacyTypevars(_) | TypeMapping::BindSelf { .. } | TypeMapping::ReplaceSelf { .. } | @@ -8055,6 +8088,7 @@ impl<'db> Type<'db> { Type::Dynamic(_) => match type_mapping { TypeMapping::Specialization(_) | TypeMapping::PartialSpecialization(_) | + TypeMapping::UniqueSpecialization { .. } | TypeMapping::BindLegacyTypevars(_) | TypeMapping::BindSelf { .. } | TypeMapping::ReplaceSelf { .. } | @@ -8760,12 +8794,17 @@ impl PromoteLiteralsMode { /// This is represented as an enum (with some variants using `Cow`), and not an `FnMut` trait, /// since we sometimes have to apply type mappings lazily (e.g., to the signature of a function /// literal). -#[derive(Clone, Debug, Eq, Hash, PartialEq, get_size2::GetSize)] +#[derive(Clone, Debug, Eq, PartialEq, get_size2::GetSize)] pub enum TypeMapping<'a, 'db> { /// Applies a specialization to the type Specialization(Specialization<'db>), /// Applies a partial specialization to the type PartialSpecialization(PartialSpecialization<'a, 'db>), + /// Resets any specializations to contain unique synthetic type variables. + UniqueSpecialization { + // A list of synthetic type variables, and the types they replaced. + specialization: RefCell, Type<'db>)>>, + }, /// Replaces any literal types with their corresponding promoted type form (e.g. `Literal["string"]` /// to `str`, or `def _() -> int` to `Callable[[], int]`). PromoteLiterals(PromoteLiteralsMode), @@ -8799,6 +8838,7 @@ impl<'db> TypeMapping<'_, 'db> { match self { TypeMapping::Specialization(_) | TypeMapping::PartialSpecialization(_) + | TypeMapping::UniqueSpecialization { .. } | TypeMapping::PromoteLiterals(_) | TypeMapping::BindLegacyTypevars(_) | TypeMapping::Materialize(_) @@ -8833,6 +8873,7 @@ impl<'db> TypeMapping<'_, 'db> { TypeMapping::PromoteLiterals(mode) => TypeMapping::PromoteLiterals(mode.flip()), TypeMapping::Specialization(_) | TypeMapping::PartialSpecialization(_) + | TypeMapping::UniqueSpecialization { .. } | TypeMapping::BindLegacyTypevars(_) | TypeMapping::BindSelf { .. } | TypeMapping::ReplaceSelf { .. } @@ -10308,14 +10349,10 @@ impl<'db> BoundTypeVarInstance<'db> { /// Create a new PEP 695 type variable that can be used in signatures /// of synthetic generic functions. - pub(crate) fn synthetic( - db: &'db dyn Db, - name: &'static str, - variance: TypeVarVariance, - ) -> Self { + pub(crate) fn synthetic(db: &'db dyn Db, name: Name, variance: TypeVarVariance) -> Self { let identity = TypeVarIdentity::new( db, - Name::new_static(name), + name, None, // definition TypeVarKind::Pep695, ); @@ -10464,7 +10501,8 @@ impl<'db> BoundTypeVarInstance<'db> { Type::TypeVar(self) } } - TypeMapping::PromoteLiterals(_) + TypeMapping::UniqueSpecialization { .. } + | TypeMapping::PromoteLiterals(_) | TypeMapping::ReplaceParameterDefaults | TypeMapping::BindLegacyTypevars(_) | TypeMapping::EagerExpansion => Type::TypeVar(self), diff --git a/crates/ty_python_semantic/src/types/bound_super.rs b/crates/ty_python_semantic/src/types/bound_super.rs index 99d0ce50fe..f179aefea9 100644 --- a/crates/ty_python_semantic/src/types/bound_super.rs +++ b/crates/ty_python_semantic/src/types/bound_super.rs @@ -331,7 +331,7 @@ impl<'db> BoundSuperType<'db> { Type::NominalInstance(instance) => SuperOwnerKind::Instance(instance), Type::ProtocolInstance(protocol) => { - if let Some(nominal_instance) = protocol.as_nominal_type() { + if let Some(nominal_instance) = protocol.to_nominal_instance() { SuperOwnerKind::Instance(nominal_instance) } else { return Err(BoundSuperError::AbstractOwnerType { diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index b0134a20c3..97e47b51fb 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -3017,7 +3017,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { .class_specialization(self.db)?; builder - .infer_map(return_ty, tcx, |(_, variance, inferred_ty)| { + .infer_reverse_map(tcx, return_ty, |(_, variance, inferred_ty)| { // Avoid unnecessarily widening the return type based on a covariant // type parameter from the type context, as it can lead to argument // assignability errors if the type variable is constrained by a narrower @@ -3029,11 +3029,12 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { Some(inferred_ty) }) .ok()?; + Some(builder.type_mappings().clone()) }); - // For a given type variable, we track the variance of any assignments to that type variable - // in the argument types. + // For a given type variable, we keep track of the variance of any assignments to + // that type variable in the type context. let mut variance_in_arguments: FxHashMap, TypeVarVariance> = FxHashMap::default(); @@ -3120,7 +3121,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { return ty; } - // If the type variable is a non-covariant position in the argument, then we avoid + // If the type variable is a non-covariant position in any argument, then we avoid // promotion, respecting any literals in the parameter type. if variance_in_arguments .get(&identity) diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index dd6b192d2d..36b9109126 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -2823,8 +2823,11 @@ impl<'db> ClassLiteral<'db> { }), ); - let t_default = - BoundTypeVarInstance::synthetic(db, "T", TypeVarVariance::Covariant); + let t_default = BoundTypeVarInstance::synthetic( + db, + Name::new_static("T"), + TypeVarVariance::Covariant, + ); let get_with_default_sig = Signature::new_generic( Some(GenericContext::from_typevar_instances(db, [t_default])), @@ -2870,8 +2873,11 @@ impl<'db> ClassLiteral<'db> { ) })) .chain(std::iter::once({ - let t_default = - BoundTypeVarInstance::synthetic(db, "T", TypeVarVariance::Covariant); + let t_default = BoundTypeVarInstance::synthetic( + db, + Name::new_static("T"), + TypeVarVariance::Covariant, + ); Signature::new_generic( Some(GenericContext::from_typevar_instances(db, [t_default])), @@ -2928,8 +2934,11 @@ impl<'db> ClassLiteral<'db> { ); // `.pop()` with a default value - let t_default = - BoundTypeVarInstance::synthetic(db, "T", TypeVarVariance::Covariant); + let t_default = BoundTypeVarInstance::synthetic( + db, + Name::new_static("T"), + TypeVarVariance::Covariant, + ); let pop_with_default_sig = Signature::new_generic( Some(GenericContext::from_typevar_instances(db, [t_default])), diff --git a/crates/ty_python_semantic/src/types/constraints.rs b/crates/ty_python_semantic/src/types/constraints.rs index bd78d94c9b..0764582753 100644 --- a/crates/ty_python_semantic/src/types/constraints.rs +++ b/crates/ty_python_semantic/src/types/constraints.rs @@ -3964,6 +3964,7 @@ mod tests { use crate::db::tests::setup_db; use crate::types::{BoundTypeVarInstance, KnownClass, TypeVarVariance}; + use ruff_python_ast::name::Name; #[test] fn test_display_graph_output() { @@ -3997,8 +3998,10 @@ mod tests { .trim_end(); let db = setup_db(); - let t = BoundTypeVarInstance::synthetic(&db, "T", TypeVarVariance::Invariant); - let u = BoundTypeVarInstance::synthetic(&db, "U", TypeVarVariance::Invariant); + let t = + BoundTypeVarInstance::synthetic(&db, Name::new_static("T"), TypeVarVariance::Invariant); + let u = + BoundTypeVarInstance::synthetic(&db, Name::new_static("U"), TypeVarVariance::Invariant); let bool_type = KnownClass::Bool.to_instance(&db); let str_type = KnownClass::Str.to_instance(&db); let t_str = ConstraintSet::range(&db, str_type, t, str_type); diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 36e89553a0..dcfc9e70c3 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -4,6 +4,7 @@ use std::fmt::Display; use itertools::{Either, Itertools}; use ruff_python_ast as ast; +use ruff_python_ast::name::Name; use rustc_hash::{FxHashMap, FxHashSet}; use crate::semantic_index::definition::Definition; @@ -1046,20 +1047,39 @@ impl<'db> Specialization<'db> { return self.materialize_impl(db, *materialization_kind, visitor); } - let types: Box<[_]> = self - .types(db) - .iter() - .zip(self.generic_context(db).variables(db)) - .enumerate() - .map(|(i, (ty, typevar))| { - let tcx = TypeContext::new(tcx.get(i).copied()); - if typevar.variance(db).is_covariant() { - ty.apply_type_mapping_impl(db, type_mapping, tcx, visitor) - } else { - ty.apply_type_mapping_impl(db, &type_mapping.flip(), tcx, visitor) - } - }) - .collect(); + let types: Box<[_]> = if let TypeMapping::UniqueSpecialization { specialization } = + type_mapping + { + let mut specialization = specialization.borrow_mut(); + + self.types(db) + .iter() + .zip(self.generic_context(db).variables(db)) + .map(|(ty, typevar)| { + // Create a unique synthetic type variable. + let name = format!("_T{}", specialization.len()); + let synthetic = + BoundTypeVarInstance::synthetic(db, Name::new(name), typevar.variance(db)); + + specialization.push((synthetic, *ty)); + Type::TypeVar(synthetic) + }) + .collect() + } else { + self.types(db) + .iter() + .zip(self.generic_context(db).variables(db)) + .enumerate() + .map(|(i, (ty, typevar))| { + let tcx = TypeContext::new(tcx.get(i).copied()); + if typevar.variance(db).is_covariant() { + ty.apply_type_mapping_impl(db, type_mapping, tcx, visitor) + } else { + ty.apply_type_mapping_impl(db, &type_mapping.flip(), tcx, visitor) + } + }) + .collect() + }; let tuple_inner = self.tuple_inner(db).and_then(|tuple| { tuple.apply_type_mapping_impl(db, type_mapping, TypeContext::default(), visitor) @@ -1505,6 +1525,11 @@ impl<'db> SpecializationBuilder<'db> { &self.types } + /// Returns the current set of type mappings for this specialization. + pub(crate) fn into_type_mappings(self) -> FxHashMap, Type<'db>> { + self.types + } + /// Map the types that have been assigned in this specialization. pub(crate) fn mapped( &self, @@ -1975,6 +2000,90 @@ impl<'db> SpecializationBuilder<'db> { Ok(()) } + + /// Infer type mappings for the specialization in the reverse direction, i.e., where the + /// actual type, not the formal type, contains inferable type variables. + pub(crate) fn infer_reverse( + &mut self, + formal: Type<'db>, + actual: Type<'db>, + ) -> Result<(), SpecializationError<'db>> { + self.infer_reverse_map(formal, actual, |(_, _, ty)| Some(ty)) + } + + /// Infer type mappings for the specialization in the reverse direction, i.e., where the + /// actual type, not the formal type, contains inferable type variables. + /// + /// The provided function will be called before any type mappings are created, and can + /// optionally modify the inferred type, or filter out the type mapping entirely. + pub(crate) fn infer_reverse_map( + &mut self, + formal: Type<'db>, + actual: Type<'db>, + mut f: impl FnMut(TypeVarAssignment<'db>) -> Option>, + ) -> Result<(), SpecializationError<'db>> { + self.infer_reverse_map_impl(formal, actual, TypeVarVariance::Covariant, &mut f) + } + + fn infer_reverse_map_impl( + &mut self, + formal: Type<'db>, + actual: Type<'db>, + polarity: TypeVarVariance, + f: &mut dyn FnMut(TypeVarAssignment<'db>) -> Option>, + ) -> Result<(), SpecializationError<'db>> { + // Assign each type variable on the formal type to a unique synthetic type variable. + let type_mapping = TypeMapping::UniqueSpecialization { + specialization: RefCell::new(Vec::new()), + }; + let synthetic_formal = + formal.apply_type_mapping(self.db, &type_mapping, TypeContext::default()); + + // Recover the synthetic type variables. + let synthetic_specialization = match type_mapping { + TypeMapping::UniqueSpecialization { specialization } => specialization.into_inner(), + _ => unreachable!(), + }; + + let inferable = GenericContext::from_typevar_instances( + self.db, + synthetic_specialization.iter().map(|(typevar, _)| *typevar), + ) + .inferable_typevars(self.db); + + // Collect the actual type to which each synthetic type variable is mapped. + let forward_type_mappings = { + let mut builder = SpecializationBuilder::new(self.db, inferable); + builder.infer(synthetic_formal, actual)?; + builder.into_type_mappings() + }; + + // If there are no forward type mappings, try the other direction. + // + // This is the base case for when `actual` is an inferable type variable. + if forward_type_mappings.is_empty() { + return self.infer_map_impl(actual, formal, polarity, f); + } + + // Consider the reverse inference of `Sequence[int]` given `list[T]`. + // + // Given a forward type mapping of `T@Sequence` -> `T@list`, and a synthetic type mapping of + // `T@Sequence` -> `int`, we want to infer the reverse type mapping `T@list` -> `int`. + for (synthetic_type_var_identity, actual_type) in forward_type_mappings { + if let Some((synthetic_type_var, formal_type)) = synthetic_specialization + .iter() + .find(|(typevar, _)| synthetic_type_var_identity == typevar.identity(self.db)) + { + let variance = synthetic_type_var.variance_with_polarity(self.db, polarity); + + // Note that it is possible that we need to recurse deeper, so we continue + // to perform a reverse inference on the nested types. + self.infer_reverse_map_impl(*formal_type, actual_type, variance, f)?; + } + } + + Ok(()) + } } #[derive(Clone, Debug, Eq, PartialEq)] diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 5657653548..c54324ffc9 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -104,15 +104,16 @@ use crate::types::typed_dict::{ }; use crate::types::visitor::any_over_type; use crate::types::{ - BoundTypeVarInstance, CallDunderError, CallableBinding, CallableType, CallableTypeKind, - ClassLiteral, ClassType, DataclassParams, DynamicType, InternedType, IntersectionBuilder, - IntersectionType, KnownClass, KnownInstanceType, KnownUnion, LintDiagnosticGuard, - MemberLookupPolicy, MetaclassCandidate, PEP695TypeAliasType, ParamSpecAttrKind, Parameter, - ParameterForm, Parameters, Signature, SpecialFormType, SubclassOfType, TrackedConstraintSet, - Truthiness, Type, TypeAliasType, TypeAndQualifiers, TypeContext, TypeQualifiers, - TypeVarBoundOrConstraints, TypeVarBoundOrConstraintsEvaluation, TypeVarDefaultEvaluation, - TypeVarIdentity, TypeVarInstance, TypeVarKind, TypeVarVariance, TypedDictType, UnionBuilder, - UnionType, UnionTypeInstance, binding_type, infer_scope_types, todo_type, + BoundTypeVarIdentity, BoundTypeVarInstance, CallDunderError, CallableBinding, CallableType, + CallableTypeKind, ClassLiteral, ClassType, DataclassParams, DynamicType, InternedType, + IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType, KnownUnion, + LintDiagnosticGuard, MemberLookupPolicy, MetaclassCandidate, PEP695TypeAliasType, + ParamSpecAttrKind, Parameter, ParameterForm, Parameters, Signature, SpecialFormType, + SubclassOfType, TrackedConstraintSet, Truthiness, Type, TypeAliasType, TypeAndQualifiers, + TypeContext, TypeQualifiers, TypeVarBoundOrConstraints, TypeVarBoundOrConstraintsEvaluation, + TypeVarDefaultEvaluation, TypeVarIdentity, TypeVarInstance, TypeVarKind, TypeVarVariance, + TypedDictType, UnionBuilder, UnionType, UnionTypeInstance, binding_type, infer_scope_types, + todo_type, }; use crate::types::{CallableTypes, overrides}; use crate::types::{ClassBase, add_inferred_python_version_hint_to_diagnostic}; @@ -7849,16 +7850,23 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { { // Extract the type variable `T` from `list[T]` in typeshed. let elt_tys = |collection_class: KnownClass| { - let class_literal = collection_class.try_to_class_literal(self.db())?; - let generic_context = class_literal.generic_context(self.db())?; + let collection_alias = collection_class + .try_to_class_literal(self.db())? + .identity_specialization(self.db()) + .into_generic_alias()?; + + let generic_context = collection_alias + .specialization(self.db()) + .generic_context(self.db()); + Some(( - class_literal, + collection_alias, generic_context, generic_context.variables(self.db()), )) }; - let Some((class_literal, generic_context, elt_tys)) = elt_tys(collection_class) else { + let Some((collection_alias, generic_context, elt_tys)) = elt_tys(collection_class) else { // Infer the element types without type context, and fallback to unknown for // custom typesheds. for elt in elts.flatten().flatten() { @@ -7879,41 +7887,80 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { annotation.filter_disjoint_elements(self.db(), collection_ty, inferable) }); - // Extract the annotated type of `T`, if provided. - let annotated_elt_tys = tcx - .known_specialization(self.db(), collection_class) - .map(|specialization| specialization.types(self.db())); + // Collect type constraints from the declared element types. + let (elt_tcx_constraints, elt_tcx_variance) = { + let mut builder = SpecializationBuilder::new( + self.db(), + generic_context.inferable_typevars(self.db()), + ); + + // For a given type variable, we keep track of the variance of any assignments to + // that type variable in the type context. + let mut elt_tcx_variance: FxHashMap, TypeVarVariance> = + FxHashMap::default(); + + if let Some(tcx) = tcx.annotation + // If there are multiple potential type contexts, we fallback to `Unknown`. + // TODO: We could perform multi-inference here. + && tcx + .filter_union(self.db(), |ty| ty.class_specialization(self.db()).is_some()) + .class_specialization(self.db()) + .is_some() + { + let collection_instance = + Type::instance(self.db(), ClassType::Generic(collection_alias)); + + builder + .infer_reverse_map( + tcx, + collection_instance, + |(typevar, variance, inferred_ty)| { + elt_tcx_variance + .entry(typevar) + .and_modify(|current| *current = current.join(variance)) + .or_insert(variance); + + Some(inferred_ty) + }, + ) + .ok()?; + } + + (builder.into_type_mappings(), elt_tcx_variance) + }; // Create a set of constraints to infer a precise type for `T`. let mut builder = SpecializationBuilder::new(self.db(), inferable); - match annotated_elt_tys { - // The annotated type acts as a constraint for `T`. + for elt_ty in elt_tys.clone() { + let elt_ty_identity = elt_ty.identity(self.db()); + let elt_tcx = elt_tcx_constraints + // The annotated type acts as a constraint for `T`. + // + // Note that we infer the annotated type _before_ the elements, to more closely match + // the order of any unions as written in the type annotation. + .get(&elt_ty_identity) + .copied(); + + // Avoid unnecessarily widening the return type based on a covariant + // type parameter from the type context. // - // Note that we infer the annotated type _before_ the elements, to more closely match the - // order of any unions as written in the type annotation. - Some(annotated_elt_tys) => { - for (elt_ty, annotated_elt_ty) in iter::zip(elt_tys.clone(), annotated_elt_tys) { - builder - .infer(Type::TypeVar(elt_ty), *annotated_elt_ty) - .ok()?; - } + // Note that we also avoid unioning the inferred type with `Unknown` in this + // case, which is only necessary for invariant collections. + if elt_tcx_variance + .get(&elt_ty_identity) + .is_some_and(|variance| variance.is_covariant()) + { + continue; } - // If a valid type annotation was not provided, avoid restricting the type of the collection - // by unioning the inferred type with `Unknown`. - None => { - for elt_ty in elt_tys.clone() { - builder.infer(Type::TypeVar(elt_ty), Type::unknown()).ok()?; - } - } + // If a valid type annotation was not provided, avoid restricting the type of the + // collection by unioning the inferred type with `Unknown`. + let elt_tcx = elt_tcx.unwrap_or(Type::unknown()); + + builder.infer(Type::TypeVar(elt_ty), elt_tcx).ok()?; } - let elt_tcxs = match annotated_elt_tys { - None => Either::Left(iter::repeat(TypeContext::default())), - Some(tys) => Either::Right(tys.iter().map(|ty| TypeContext::new(Some(*ty)))), - }; - for elts in elts { // An unpacking expression for a dictionary. if let &[None, Some(value)] = elts.as_slice() { @@ -7936,14 +7983,21 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } // The inferred type of each element acts as an additional constraint on `T`. - for (elt, elt_ty, elt_tcx) in itertools::izip!(elts, elt_tys.clone(), elt_tcxs.clone()) - { + for (elt, elt_ty) in iter::zip(elts, elt_tys.clone()) { let Some(elt) = elt else { continue }; - let inferred_elt_ty = infer_elt_expression(self, elt, elt_tcx); + // Note that unlike when preferring the declared type, we use covariant type + // assignments from the type context to potentially _narrow_ the inferred type, + // by avoiding literal promotion. + let elt_ty_identity = elt_ty.identity(self.db()); + let elt_tcx = elt_tcx_constraints.get(&elt_ty_identity).copied(); - // Simplify the inference based on the declared type of the element. - if let Some(elt_tcx) = elt_tcx.annotation { + let inferred_elt_ty = infer_elt_expression(self, elt, TypeContext::new(elt_tcx)); + + // Simplify the inference based on a non-covariant declared type. + if let Some(elt_tcx) = + elt_tcx.filter(|_| !elt_tcx_variance[&elt_ty_identity].is_covariant()) + { if inferred_elt_ty.is_assignable_to(self.db(), elt_tcx) { continue; } @@ -7951,14 +8005,16 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // Convert any element literals to their promoted type form to avoid excessively large // unions for large nested list literals, which the constraint solver struggles with. - let inferred_elt_ty = inferred_elt_ty.promote_literals(self.db(), elt_tcx); + let inferred_elt_ty = + inferred_elt_ty.promote_literals(self.db(), TypeContext::new(elt_tcx)); builder.infer(Type::TypeVar(elt_ty), inferred_elt_ty).ok()?; } } - let class_type = - class_literal.apply_specialization(self.db(), |_| builder.build(generic_context)); + let class_type = collection_alias + .origin(self.db()) + .apply_specialization(self.db(), |_| builder.build(generic_context)); Type::from(class_type).to_instance(self.db()) } @@ -8445,7 +8501,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { call_expression: &ast::ExprCall, tcx: TypeContext<'db>, ) -> Type<'db> { - // TODO: Use the type context for more precise inference. let callable_type = self.infer_maybe_standalone_expression(&call_expression.func, TypeContext::default()); diff --git a/crates/ty_python_semantic/src/types/instance.rs b/crates/ty_python_semantic/src/types/instance.rs index 9e674065b9..fbad5fbcdb 100644 --- a/crates/ty_python_semantic/src/types/instance.rs +++ b/crates/ty_python_semantic/src/types/instance.rs @@ -165,7 +165,7 @@ impl<'db> Type<'db> { // This matches the behaviour of other type checkers, and is required for us to // recognise `str` as a subtype of `Container[str]`. structurally_satisfied.or(db, || { - let Some(nominal_instance) = protocol.as_nominal_type() else { + let Some(nominal_instance) = protocol.to_nominal_instance() else { return ConstraintSet::from(false); }; @@ -175,7 +175,7 @@ impl<'db> Type<'db> { // `Q`'s members in a Liskov-incompatible way. let type_to_test = self .as_protocol_instance() - .and_then(ProtocolInstanceType::as_nominal_type) + .and_then(ProtocolInstanceType::to_nominal_instance) .map(Type::NominalInstance) .unwrap_or(self); @@ -658,7 +658,7 @@ impl<'db> ProtocolInstanceType<'db> { /// If this is a synthesized protocol that does not correspond to a class definition /// in source code, return `None`. These are "pure" abstract types, that cannot be /// treated in a nominal way. - pub(super) fn as_nominal_type(self) -> Option> { + pub(super) fn to_nominal_instance(self) -> Option> { match self.inner { Protocol::FromClass(class) => { Some(NominalInstanceType(NominalInstanceInner::NonTuple(*class)))