From 78f21945c2bf6f0863fbc4c253d936922f72513e Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Tue, 2 Dec 2025 00:09:39 -0500 Subject: [PATCH] subtyping for bidirectional inference --- .../mdtest/assignment/annotations.md | 42 +++++ .../resources/mdtest/literal_promotion.md | 55 ++++++ .../resources/mdtest/type_compendium/tuple.md | 3 + crates/ty_python_semantic/src/types.rs | 157 ++++++++++++------ .../src/types/bound_super.rs | 2 +- .../ty_python_semantic/src/types/call/bind.rs | 2 +- crates/ty_python_semantic/src/types/cyclic.rs | 10 +- .../ty_python_semantic/src/types/generics.rs | 58 +++++++ .../src/types/infer/builder.rs | 84 +++++----- .../ty_python_semantic/src/types/instance.rs | 6 +- 10 files changed, 323 insertions(+), 96 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md index 36f53afe4d..d95cc52fd1 100644 --- a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md +++ b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md @@ -600,6 +600,48 @@ reveal_type(x7) # revealed: Contravariant[Any] reveal_type(x8) # revealed: Invariant[Any] ``` +## Declared type preference sees through subtyping + +```toml +[environment] +python-version = "3.12" +``` + +```py +from typing import Any, Iterable, Literal, MutableSequence, Sequence + +x1: Sequence[Any] = [1, 2, 3] +reveal_type(x1) # revealed: list[Any] + +x2: MutableSequence[Any] = [1, 2, 3] +reveal_type(x2) # revealed: list[Any] + +x3: Iterable[Any] = [1, 2, 3] +reveal_type(x3) # revealed: list[Any] + +class X[T]: + value: T + + def __init__(self, value: T): ... + +class A[T](X[T]): ... + +def a[T](value: T) -> A[T]: + return A(value) + +x4: A[object] = A(1) +reveal_type(x4) # revealed: A[object] + +x5: X[object] = A(1) +reveal_type(x5) # revealed: A[object] + +x6: X[object] | None = A(1) +reveal_type(x6) # revealed: A[object] + +x7: X[object] | None = a(1) +reveal_type(x7) # revealed: A[object] +``` + ## Narrow generic unions ```toml diff --git a/crates/ty_python_semantic/resources/mdtest/literal_promotion.md b/crates/ty_python_semantic/resources/mdtest/literal_promotion.md index eb79c44b6c..61d08bd676 100644 --- a/crates/ty_python_semantic/resources/mdtest/literal_promotion.md +++ b/crates/ty_python_semantic/resources/mdtest/literal_promotion.md @@ -341,3 +341,58 @@ reveal_type(x21) # revealed: X[Literal[1]] x22: X[Literal[1]] | None = x(1) reveal_type(x22) # revealed: X[Literal[1]] ``` + +## Literal annotations see through subtyping + +```py +from typing import Any, Iterable, Literal, MutableSequence, Sequence + +x1: Sequence[Literal[1, 2, 3]] = [1, 2, 3] +reveal_type(x1) # revealed: list[Literal[1, 2, 3]] + +x2: MutableSequence[Literal[1, 2, 3]] = [1, 2, 3] +reveal_type(x2) # revealed: list[Literal[1, 2, 3]] + +x3: Iterable[Literal[1, 2, 3]] = [1, 2, 3] +reveal_type(x3) # revealed: list[Literal[1, 2, 3]] + +class Sup1[T]: + value: T + +class Sub1[T](Sup1[T]): ... + +def sub1[T](value: T) -> Sub1[T]: + return Sub1() + +x4: Sub1[Literal[1]] = sub1(1) +reveal_type(x4) # revealed: Sub1[Literal[1]] + +x5: Sup1[Literal[1]] = sub1(1) +reveal_type(x5) # revealed: Sub1[Literal[1]] + +x6: Sup1[Literal[1]] | None = sub1(1) +reveal_type(x6) # revealed: Sub1[Literal[1]] + +x7: Sup1[Literal[1]] | None = sub1(1) +reveal_type(x7) # revealed: Sub1[Literal[1]] + +class Sup2A[T, U]: + value: tuple[T, U] + +class Sup2B[T, U]: + value: tuple[T, U] + +class Sub2[T, U](Sup2A[T, Any], Sup2B[Any, U]): ... + +def sub2[T, U](x: T, y: U) -> Sub2[T, U]: + return Sub2() + +x8 = sub2(1, 2) +reveal_type(x8) # revealed: Sub2[int, int] + +x9: Sup2A[Literal[1], Literal[2]] = sub2(1, 2) +reveal_type(x9) # revealed: Sub2[Literal[1], int] + +x10: Sup2B[Literal[1], Literal[2]] = sub2(1, 2) +reveal_type(x10) # revealed: Sub2[int, Literal[2]] +``` diff --git a/crates/ty_python_semantic/resources/mdtest/type_compendium/tuple.md b/crates/ty_python_semantic/resources/mdtest/type_compendium/tuple.md index e323d25a17..e2c45cc7f1 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_compendium/tuple.md +++ b/crates/ty_python_semantic/resources/mdtest/type_compendium/tuple.md @@ -57,6 +57,9 @@ reveal_type(tuple((1, 2))) # revealed: tuple[Literal[1], Literal[2]] reveal_type(tuple([1])) # revealed: tuple[Unknown | int, ...] +x1: tuple[int, ...] = tuple([1]) +reveal_type(x1) # revealed: tuple[int, ...] + # error: [invalid-argument-type] reveal_type(tuple[int]([1])) # revealed: tuple[int] diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index f531fb604e..e13487676a 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -61,8 +61,8 @@ use crate::types::function::{ }; pub(crate) use crate::types::generics::GenericContext; use crate::types::generics::{ - InferableTypeVars, PartialSpecialization, Specialization, bind_typevar, typing_self, - walk_generic_context, + InferableTypeVars, PartialSpecialization, Specialization, SpecializationBuilder, bind_typevar, + typing_self, walk_generic_context, }; use crate::types::mro::{Mro, MroError, MroIterator}; pub(crate) use crate::types::narrow::infer_narrowing_constraint; @@ -1019,7 +1019,10 @@ impl<'db> Type<'db> { } /// If this type is a class instance, returns its specialization. - pub(crate) fn class_specialization(self, db: &'db dyn Db) -> Option> { + pub(crate) fn class_specialization( + self, + db: &'db dyn Db, + ) -> Option<(ClassLiteral<'db>, Specialization<'db>)> { self.specialization_of_optional(db, None) } @@ -1030,15 +1033,17 @@ impl<'db> Type<'db> { expected_class: ClassLiteral<'_>, ) -> Option> { self.specialization_of_optional(db, Some(expected_class)) + .map(|(_, specialization)| specialization) } fn specialization_of_optional( self, db: &'db dyn Db, expected_class: Option>, - ) -> Option> { + ) -> Option<(ClassLiteral<'db>, Specialization<'db>)> { let class_type = match self { Type::NominalInstance(instance) => instance, + Type::ProtocolInstance(instance) => instance.to_nominal_instance()?, Type::TypeAlias(alias) => alias.value_type(db).as_nominal_instance()?, _ => return None, } @@ -1049,7 +1054,7 @@ impl<'db> Type<'db> { return None; } - specialization + Some((class_literal, specialization?)) } /// Returns the top materialization (or upper bound materialization) of this type, which is the @@ -3916,69 +3921,110 @@ impl<'db> Type<'db> { where F: FnMut(BoundTypeVarInstance<'db>, Type<'db>, TypeVarVariance, TypeContext<'db>), { - self.visit_specialization_impl( + let try_visit = &mut |type_var, ty, variance, tcx| -> Result<(), ()> { + f(type_var, ty, variance, tcx); + Ok(()) + }; + + let _ = self.try_visit_specialization(db, tcx, try_visit); + } + + pub(crate) fn try_visit_specialization( + self, + db: &'db dyn Db, + tcx: TypeContext<'db>, + mut f: F, + ) -> Result<(), E> + where + F: FnMut( + BoundTypeVarInstance<'db>, + Type<'db>, + TypeVarVariance, + TypeContext<'db>, + ) -> Result<(), E>, + { + self.try_visit_specialization_impl( db, tcx, TypeVarVariance::Covariant, &mut f, &SpecializationVisitor::default(), - ); + ) } - fn visit_specialization_impl( + fn try_visit_specialization_impl( self, db: &'db dyn Db, tcx: TypeContext<'db>, polarity: TypeVarVariance, - f: &mut dyn FnMut(BoundTypeVarInstance<'db>, Type<'db>, TypeVarVariance, TypeContext<'db>), + f: &mut dyn FnMut( + BoundTypeVarInstance<'db>, + Type<'db>, + TypeVarVariance, + TypeContext<'db>, + ) -> Result<(), E>, visitor: &SpecializationVisitor<'db>, - ) { - let Type::NominalInstance(instance) = self else { - match self { - Type::Union(union) => { - for element in union.elements(db) { - element.visit_specialization_impl(db, tcx, polarity, f, visitor); - } + ) -> Result<(), E> { + let instance = match self { + Type::Union(union) => { + for element in union.elements(db) { + element.try_visit_specialization_impl(db, tcx, polarity, f, visitor)?; } - Type::Intersection(intersection) => { - for element in intersection.positive(db) { - element.visit_specialization_impl(db, tcx, polarity, f, visitor); - } + return Ok(()); + } + Type::Intersection(intersection) => { + for element in intersection.positive(db) { + element.try_visit_specialization_impl(db, tcx, polarity, f, visitor)?; } - Type::TypeAlias(alias) => visitor.visit(self, || { + return Ok(()); + } + Type::TypeAlias(alias) => { + visitor.try_visit(self, || { alias .value_type(db) - .visit_specialization_impl(db, tcx, polarity, f, visitor); - }), - _ => {} - } + .try_visit_specialization_impl(db, tcx, polarity, f, visitor) + })?; - return; + return Ok(()); + } + Type::NominalInstance(instance) => instance, + Type::ProtocolInstance(protocol) => match protocol.to_nominal_instance() { + Some(instance) => instance, + None => return Ok(()), + }, + _ => return Ok(()), }; let (class_literal, Some(specialization)) = instance.class(db).class_literal(db) else { - return; + return Ok(()); + }; + let generic_context = specialization.generic_context(db); + + // Collect the type mappings used to narrow the type context. + let tcx_mappings = { + let mut builder = + SpecializationBuilder::new(db, generic_context.inferable_typevars(db)); + + if let Some(tcx) = tcx.annotation { + let alias_instance = Type::instance(db, class_literal.identity_specialization(db)); + let _ = builder.infer_reverse(tcx, alias_instance); + } + + builder.into_type_mappings() }; - let tcx_specialization = tcx.annotation.and_then(|tcx| { - tcx.filter_union(db, |ty| ty.specialization_of(db, class_literal).is_some()) - .specialization_of(db, class_literal) - }); + for (type_var, ty) in generic_context.variables(db).zip(specialization.types(db)) { + let variance = type_var.variance_with_polarity(db, polarity); + let narrowed_tcx = TypeContext::new(tcx_mappings.get(&type_var.identity(db)).copied()); - for (typevar, ty) in specialization - .generic_context(db) - .variables(db) - .zip(specialization.types(db)) - { - let variance = typevar.variance_with_polarity(db, polarity); - let tcx = TypeContext::new(tcx_specialization.and_then(|spec| spec.get(db, typevar))); + f(type_var, *ty, variance, narrowed_tcx)?; - f(typevar, *ty, variance, tcx); - - visitor.visit(*ty, || { - ty.visit_specialization_impl(db, tcx, variance, f, visitor); - }); + visitor.try_visit(*ty, || { + ty.try_visit_specialization_impl(db, narrowed_tcx, variance, f, visitor) + })?; } + + Ok(()) } /// Return true if there is just a single inhabitant for this type. @@ -6173,30 +6219,35 @@ impl<'db> Type<'db> { } Some(KnownClass::Tuple) => { - let object = Type::object(); + let element_ty = + BoundTypeVarInstance::synthetic(db, "T", TypeVarVariance::Covariant); // ```py - // class tuple: + // class tuple(Sequence[_T_co]): // @overload // def __new__(cls) -> tuple[()]: ... // @overload - // def __new__(cls, iterable: Iterable[object]) -> tuple[object, ...]: ... + // def __new__(cls, iterable: Iterable[_T_co]) -> tuple[_T_co, ...]: ... // ``` CallableBinding::from_overloads( self, [ Signature::new(Parameters::empty(), Some(Type::empty_tuple(db))), - Signature::new( + Signature::new_generic( + Some(GenericContext::from_typevar_instances(db, [element_ty])), Parameters::new( db, [Parameter::positional_only(Some(Name::new_static( "iterable", ))) .with_annotated_type( - KnownClass::Iterable.to_specialized_instance(db, [object]), + KnownClass::Iterable.to_specialized_instance( + db, + [Type::TypeVar(element_ty)], + ), )], ), - Some(Type::homogeneous_tuple(db, object)), + Some(Type::homogeneous_tuple(db, Type::TypeVar(element_ty))), ), ], ) @@ -7702,6 +7753,7 @@ impl<'db> Type<'db> { } TypeMapping::Specialization(_) | TypeMapping::PartialSpecialization(_) | + TypeMapping::IdentitySpecialization | TypeMapping::PromoteLiterals(_) | TypeMapping::BindSelf { .. } | TypeMapping::ReplaceSelf { .. } | @@ -7878,6 +7930,7 @@ impl<'db> Type<'db> { | Type::EnumLiteral(_) => match type_mapping { TypeMapping::Specialization(_) | TypeMapping::PartialSpecialization(_) | + TypeMapping::IdentitySpecialization | TypeMapping::BindLegacyTypevars(_) | TypeMapping::BindSelf { .. } | TypeMapping::ReplaceSelf { .. } | @@ -7891,6 +7944,7 @@ impl<'db> Type<'db> { Type::Dynamic(_) => match type_mapping { TypeMapping::Specialization(_) | TypeMapping::PartialSpecialization(_) | + TypeMapping::IdentitySpecialization | TypeMapping::BindLegacyTypevars(_) | TypeMapping::BindSelf { .. } | TypeMapping::ReplaceSelf { .. } | @@ -8601,6 +8655,8 @@ pub enum TypeMapping<'a, 'db> { Specialization(Specialization<'db>), /// Applies a partial specialization to the type PartialSpecialization(PartialSpecialization<'a, 'db>), + /// Resets any specializations to their identity. + IdentitySpecialization, /// Replaces any literal types with their corresponding promoted type form (e.g. `Literal["string"]` /// to `str`, or `def _() -> int` to `Callable[[], int]`). PromoteLiterals(PromoteLiteralsMode), @@ -8634,6 +8690,7 @@ impl<'db> TypeMapping<'_, 'db> { match self { TypeMapping::Specialization(_) | TypeMapping::PartialSpecialization(_) + | TypeMapping::IdentitySpecialization | TypeMapping::PromoteLiterals(_) | TypeMapping::BindLegacyTypevars(_) | TypeMapping::Materialize(_) @@ -8668,6 +8725,7 @@ impl<'db> TypeMapping<'_, 'db> { TypeMapping::PromoteLiterals(mode) => TypeMapping::PromoteLiterals(mode.flip()), TypeMapping::Specialization(_) | TypeMapping::PartialSpecialization(_) + | TypeMapping::IdentitySpecialization | TypeMapping::BindLegacyTypevars(_) | TypeMapping::BindSelf { .. } | TypeMapping::ReplaceSelf { .. } @@ -10225,6 +10283,7 @@ impl<'db> BoundTypeVarInstance<'db> { }) .unwrap_or(Type::TypeVar(self)) } + TypeMapping::IdentitySpecialization => Type::TypeVar(self), TypeMapping::PartialSpecialization(partial) => { let typevar = if self.is_paramspec(db) { self.without_paramspec_attr(db) diff --git a/crates/ty_python_semantic/src/types/bound_super.rs b/crates/ty_python_semantic/src/types/bound_super.rs index 442ae0d0b9..11f99d9f27 100644 --- a/crates/ty_python_semantic/src/types/bound_super.rs +++ b/crates/ty_python_semantic/src/types/bound_super.rs @@ -321,7 +321,7 @@ impl<'db> BoundSuperType<'db> { Type::NominalInstance(instance) => SuperOwnerKind::Instance(instance), Type::ProtocolInstance(protocol) => { - if let Some(nominal_instance) = protocol.as_nominal_type() { + if let Some(nominal_instance) = protocol.to_nominal_instance() { SuperOwnerKind::Instance(nominal_instance) } else { return Err(BoundSuperError::AbstractOwnerType { diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index e81d26d8b8..5f234b2ce3 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -3004,7 +3004,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { tcx.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some()) .class_specialization(self.db)?; - builder.infer(return_ty, tcx).ok()?; + builder.infer_reverse(tcx, return_ty).ok()?; Some(builder.type_mappings().clone()) }); diff --git a/crates/ty_python_semantic/src/types/cyclic.rs b/crates/ty_python_semantic/src/types/cyclic.rs index 6f179b1a72..151222c721 100644 --- a/crates/ty_python_semantic/src/types/cyclic.rs +++ b/crates/ty_python_semantic/src/types/cyclic.rs @@ -122,14 +122,14 @@ impl CycleDetector { ret } - pub fn try_visit(&self, item: T, func: impl FnOnce() -> Option) -> Option { + pub fn try_visit(&self, item: T, func: impl FnOnce() -> Result) -> Result { if let Some(val) = self.cache.borrow().get(&item) { - return Some(val.clone()); + return Ok(val.clone()); } // We hit a cycle if !self.seen.borrow_mut().insert(item.clone()) { - return Some(self.fallback.clone()); + return Ok(self.fallback.clone()); } // Check depth limit to prevent stack overflow from recursive generic protocols @@ -137,7 +137,7 @@ impl CycleDetector { let current_depth = self.depth.get(); if current_depth >= MAX_RECURSION_DEPTH { self.seen.borrow_mut().pop(); - return Some(self.fallback.clone()); + return Ok(self.fallback.clone()); } self.depth.set(current_depth + 1); @@ -147,7 +147,7 @@ impl CycleDetector { self.seen.borrow_mut().pop(); self.cache.borrow_mut().insert(item, ret.clone()); - Some(ret) + Ok(ret) } } diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 58ccd28ca0..13725d9358 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -1014,6 +1014,10 @@ impl<'db> Specialization<'db> { return self.materialize_impl(db, *materialization_kind, visitor); } + if *type_mapping == TypeMapping::IdentitySpecialization { + return self.generic_context(db).identity_specialization(db); + } + let types: Box<[_]> = self .types(db) .iter() @@ -1467,6 +1471,11 @@ impl<'db> SpecializationBuilder<'db> { &self.types } + /// Returns the current set of type mappings for this specialization. + pub(crate) fn into_type_mappings(self) -> FxHashMap, Type<'db>> { + self.types + } + /// Map the types that have been assigned in this specialization. pub(crate) fn mapped( &self, @@ -1852,12 +1861,61 @@ impl<'db> SpecializationBuilder<'db> { } } + (_, Type::TypeAlias(alias)) => { + return self.infer_map_impl(formal, alias.value_type(self.db), polarity, f); + } + // TODO: Add more forms that we can structurally induct into: type[C], callables _ => {} } Ok(()) } + + /// Infer type mappings for the specialization in the reverse direction, i.e., where the given type contains + /// inferable type variables. + pub(crate) fn infer_reverse( + &mut self, + formal: Type<'db>, + actual: Type<'db>, + ) -> Result<(), SpecializationError<'db>> { + let identity_formal = formal.apply_type_mapping( + self.db, + &TypeMapping::IdentitySpecialization, + TypeContext::default(), + ); + + // Collect all type variables on the actual type. + let mut formal_type_vars = Vec::new(); + formal.visit_specialization(self.db, TypeContext::default(), |typevar, _, _, _| { + formal_type_vars.push(typevar); + }); + + let inferable_type_vars = GenericContext::from_typevar_instances(self.db, formal_type_vars) + .inferable_typevars(self.db); + + // Perform type inference in the forward direction with the inferable identity types, + // collecting the forward type mappings. + let forward_type_mappings = { + let mut builder = SpecializationBuilder::new(self.db, inferable_type_vars); + builder.infer(identity_formal, actual)?; + builder.type_mappings().clone() + }; + + // If there are no forward type mappings, try the other direction. + if forward_type_mappings.is_empty() { + return self.infer(actual, formal); + } + + formal.try_visit_specialization(self.db, TypeContext::default(), |type_var, ty, _, _| { + // Reverse the type mappings and specialize them to their assigned types. + if let Some(formal) = forward_type_mappings.get(&type_var.identity(self.db)) { + self.infer(*formal, ty)?; + } + + Ok(()) + }) + } } #[derive(Clone, Debug, Eq, PartialEq)] diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 00141d4ca8..dec104afb0 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -7694,16 +7694,23 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { { // Extract the type variable `T` from `list[T]` in typeshed. let elt_tys = |collection_class: KnownClass| { - let class_literal = collection_class.try_to_class_literal(self.db())?; - let generic_context = class_literal.generic_context(self.db())?; + let collection_alias = collection_class + .try_to_class_literal(self.db())? + .identity_specialization(self.db()) + .into_generic_alias()?; + + let generic_context = collection_alias + .specialization(self.db()) + .generic_context(self.db()); + Some(( - class_literal, + collection_alias, generic_context, generic_context.variables(self.db()), )) }; - let Some((class_literal, generic_context, elt_tys)) = elt_tys(collection_class) else { + let Some((collection_alias, generic_context, elt_tys)) = elt_tys(collection_class) else { // Infer the element types without type context, and fallback to unknown for // custom typesheds. for elt in elts.flatten().flatten() { @@ -7724,41 +7731,40 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { annotation.filter_disjoint_elements(self.db(), collection_ty, inferable) }); - // Extract the annotated type of `T`, if provided. - let annotated_elt_tys = tcx - .known_specialization(self.db(), collection_class) - .map(|specialization| specialization.types(self.db())); + // Collect type constraints from the declared element types. + let elt_tcx_constraints = { + let mut builder = SpecializationBuilder::new( + self.db(), + generic_context.inferable_typevars(self.db()), + ); + + if let Some(tcx) = tcx.annotation { + let collection_instance = + Type::instance(self.db(), ClassType::Generic(collection_alias)); + builder.infer_reverse(tcx, collection_instance).ok()?; + } + + builder.into_type_mappings() + }; // Create a set of constraints to infer a precise type for `T`. let mut builder = SpecializationBuilder::new(self.db(), inferable); - match annotated_elt_tys { - // The annotated type acts as a constraint for `T`. - // - // Note that we infer the annotated type _before_ the elements, to more closely match the - // order of any unions as written in the type annotation. - Some(annotated_elt_tys) => { - for (elt_ty, annotated_elt_ty) in iter::zip(elt_tys.clone(), annotated_elt_tys) { - builder - .infer(Type::TypeVar(elt_ty), *annotated_elt_ty) - .ok()?; - } - } + for elt_ty in elt_tys.clone() { + let elt_tcx = elt_tcx_constraints + // The annotated type acts as a constraint for `T`. + // + // Note that we infer the annotated type _before_ the elements, to more closely match the + // order of any unions as written in the type annotation. + .get(&elt_ty.identity(self.db())) + .copied() + // If a valid type annotation was not provided, avoid restricting the type of the collection + // by unioning the inferred type with `Unknown`. + .unwrap_or(Type::unknown()); - // If a valid type annotation was not provided, avoid restricting the type of the collection - // by unioning the inferred type with `Unknown`. - None => { - for elt_ty in elt_tys.clone() { - builder.infer(Type::TypeVar(elt_ty), Type::unknown()).ok()?; - } - } + builder.infer(Type::TypeVar(elt_ty), elt_tcx).ok()?; } - let elt_tcxs = match annotated_elt_tys { - None => Either::Left(iter::repeat(TypeContext::default())), - Some(tys) => Either::Right(tys.iter().map(|ty| TypeContext::new(Some(*ty)))), - }; - for elts in elts { // An unpacking expression for a dictionary. if let &[None, Some(value)] = elts.as_slice() { @@ -7781,10 +7787,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } // The inferred type of each element acts as an additional constraint on `T`. - for (elt, elt_ty, elt_tcx) in itertools::izip!(elts, elt_tys.clone(), elt_tcxs.clone()) - { + for (elt, elt_ty) in iter::zip(elts, elt_tys.clone()) { let Some(elt) = elt else { continue }; + let elt_tcx = TypeContext::new( + elt_tcx_constraints + .get(&elt_ty.identity(self.db())) + .copied(), + ); let inferred_elt_ty = infer_elt_expression(self, elt, elt_tcx); // Simplify the inference based on the declared type of the element. @@ -7802,8 +7812,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } - let class_type = - class_literal.apply_specialization(self.db(), |_| builder.build(generic_context)); + let class_type = collection_alias + .origin(self.db()) + .apply_specialization(self.db(), |_| builder.build(generic_context)); Type::from(class_type).to_instance(self.db()) } @@ -8272,7 +8283,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { call_expression: &ast::ExprCall, tcx: TypeContext<'db>, ) -> Type<'db> { - // TODO: Use the type context for more precise inference. let callable_type = self.infer_maybe_standalone_expression(&call_expression.func, TypeContext::default()); diff --git a/crates/ty_python_semantic/src/types/instance.rs b/crates/ty_python_semantic/src/types/instance.rs index fb53f10ef4..f6e198d591 100644 --- a/crates/ty_python_semantic/src/types/instance.rs +++ b/crates/ty_python_semantic/src/types/instance.rs @@ -165,7 +165,7 @@ impl<'db> Type<'db> { // This matches the behaviour of other type checkers, and is required for us to // recognise `str` as a subtype of `Container[str]`. structurally_satisfied.or(db, || { - let Some(nominal_instance) = protocol.as_nominal_type() else { + let Some(nominal_instance) = protocol.to_nominal_instance() else { return ConstraintSet::from(false); }; @@ -175,7 +175,7 @@ impl<'db> Type<'db> { // `Q`'s members in a Liskov-incompatible way. let type_to_test = self .as_protocol_instance() - .and_then(ProtocolInstanceType::as_nominal_type) + .and_then(ProtocolInstanceType::to_nominal_instance) .map(Type::NominalInstance) .unwrap_or(self); @@ -650,7 +650,7 @@ impl<'db> ProtocolInstanceType<'db> { /// If this is a synthesized protocol that does not correspond to a class definition /// in source code, return `None`. These are "pure" abstract types, that cannot be /// treated in a nominal way. - pub(super) fn as_nominal_type(self) -> Option> { + pub(super) fn to_nominal_instance(self) -> Option> { match self.inner { Protocol::FromClass(class) => { Some(NominalInstanceType(NominalInstanceInner::NonTuple(*class)))