diff --git a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md index 8a0eb110e5..8d44cd7195 100644 --- a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md +++ b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md @@ -130,13 +130,9 @@ type IntList = list[int] m: IntList = [1, 2, 3] reveal_type(m) # revealed: list[int] -# TODO: this should type-check and avoid literal promotion -# error: [invalid-assignment] "Object of type `list[Unknown | int]` is not assignable to `list[Literal[1, 2, 3]]`" n: list[typing.Literal[1, 2, 3]] = [1, 2, 3] reveal_type(n) # revealed: list[Literal[1, 2, 3]] -# TODO: this should type-check and avoid literal promotion -# error: [invalid-assignment] "Object of type `list[Unknown | str]` is not assignable to `list[LiteralString]`" o: list[typing.LiteralString] = ["a", "b", "c"] reveal_type(o) # revealed: list[LiteralString] @@ -160,6 +156,81 @@ a: list[str] = [1, 2, 3] b: set[int] = {1, 2, "3"} ``` +## Literal annnotations are respected + +```toml +[environment] +python-version = "3.12" +``` + +```py +from enum import Enum +from typing_extensions import Literal, LiteralString + +a: list[Literal[1]] = [1] +reveal_type(a) # revealed: list[Literal[1]] + +b: list[Literal[True]] = [True] +reveal_type(b) # revealed: list[Literal[True]] + +c: list[Literal["a"]] = ["a"] +reveal_type(c) # revealed: list[Literal["a"]] + +d: list[LiteralString] = ["a", "b", "c"] +reveal_type(d) # revealed: list[LiteralString] + +e: list[list[Literal[1]]] = [[1]] +reveal_type(e) # revealed: list[list[Literal[1]]] + +class Color(Enum): + RED = "red" + +f: dict[list[Literal[1]], list[Literal[Color.RED]]] = {[1]: [Color.RED, Color.RED]} +reveal_type(f) # revealed: dict[list[Literal[1]], list[Literal[Color.RED]]] + +class X[T]: + def __init__(self, value: T): ... + +g: X[Literal[1]] = X(1) +reveal_type(g) # revealed: X[Literal[1]] + +h: X[int] = X(1) +reveal_type(h) # revealed: X[int] + +i: dict[list[X[Literal[1]]], set[Literal[b"a"]]] = {[X(1)]: {b"a"}} +reveal_type(i) # revealed: dict[list[X[Literal[1]]], set[Literal[b"a"]]] + +j: list[Literal[1, 2, 3]] = [1, 2, 3] +reveal_type(j) # revealed: list[Literal[1, 2, 3]] + +k: list[Literal[1] | Literal[2] | Literal[3]] = [1, 2, 3] +reveal_type(k) # revealed: list[Literal[1, 2, 3]] + +type Y[T] = list[T] + +l: Y[Y[Literal[1]]] = [[1]] +reveal_type(l) # revealed: list[list[Literal[1]]] + +m: list[tuple[Literal[1], Literal[2], Literal[3]]] = [(1, 2, 3)] +reveal_type(m) # revealed: list[tuple[Literal[1], Literal[2], Literal[3]]] + +n: list[tuple[int, str, int]] = [(1, "2", 3), (4, "5", 6)] +reveal_type(n) # revealed: list[tuple[int, str, int]] + +o: list[tuple[Literal[1], ...]] = [(1, 1, 1)] +reveal_type(o) # revealed: list[tuple[Literal[1], ...]] + +p: list[tuple[int, ...]] = [(1, 1, 1)] +reveal_type(p) # revealed: list[tuple[int, ...]] + +# literal promotion occurs based on assignability, an exact match is not required +q: list[int | Literal[1]] = [1] +reveal_type(q) # revealed: list[int] + +r: list[Literal[1, 2, 3, 4]] = [1, 2] +reveal_type(r) # revealed: list[Literal[1, 2, 3, 4]] +``` + ## PEP-604 annotations are supported ```py diff --git a/crates/ty_python_semantic/resources/mdtest/type_compendium/tuple.md b/crates/ty_python_semantic/resources/mdtest/type_compendium/tuple.md index 1f1d52a57f..07bfa4066b 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_compendium/tuple.md +++ b/crates/ty_python_semantic/resources/mdtest/type_compendium/tuple.md @@ -525,10 +525,6 @@ from typing import Literal reveal_type(list((1, 2, 3))) # revealed: list[int] reveal_type(list(((1, 2, 3),))) # revealed: list[tuple[int, int, int]] -# TODO: we could bidirectionally infer that the user does not want literals to be promoted here, -# and avoid this diagnostic -# -# error: [invalid-assignment] "`list[int]` is not assignable to `list[Literal[1, 2, 3]]`" x: list[Literal[1, 2, 3]] = list((1, 2, 3)) reveal_type(x) # revealed: list[Literal[1, 2, 3]] ``` diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index c40570ca04..d8f99738ca 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -524,14 +524,15 @@ impl<'db> PropertyInstanceType<'db> { self, db: &'db dyn Db, type_mapping: &TypeMapping<'a, 'db>, + tcx: TypeContext<'db>, visitor: &ApplyTypeMappingVisitor<'db>, ) -> Self { let getter = self .getter(db) - .map(|ty| ty.apply_type_mapping_impl(db, type_mapping, visitor)); + .map(|ty| ty.apply_type_mapping_impl(db, type_mapping, tcx, visitor)); let setter = self .setter(db) - .map(|ty| ty.apply_type_mapping_impl(db, type_mapping, visitor)); + .map(|ty| ty.apply_type_mapping_impl(db, type_mapping, tcx, visitor)); Self::new(db, getter, setter) } @@ -825,6 +826,7 @@ impl<'db> Type<'db> { db, dunder_name, &mut CallArguments::positional([Type::unknown()]), + TypeContext::default(), MemberLookupPolicy::MRO_NO_OBJECT_FALLBACK | MemberLookupPolicy::MRO_NO_INT_OR_STR_LOOKUP, ); @@ -856,9 +858,21 @@ impl<'db> Type<'db> { // If the type is a specialized instance of the given `KnownClass`, returns the specialization. pub(crate) fn known_specialization( - self, - known_class: KnownClass, + &self, db: &'db dyn Db, + known_class: KnownClass, + ) -> Option> { + let class_literal = known_class.try_to_class_literal(db)?; + self.specialization_of(db, Some(class_literal)) + } + + // If the type is a specialized instance of the given class, returns the specialization. + // + // If no class is provided, returns the specialization of any class instance. + pub(crate) fn specialization_of( + self, + db: &'db dyn Db, + expected_class: Option>, ) -> Option> { let class_type = match self { Type::NominalInstance(instance) => instance, @@ -867,13 +881,12 @@ impl<'db> Type<'db> { } .class(db); - if !class_type.is_known(db, known_class) { + let (class_literal, specialization) = class_type.class_literal(db); + if expected_class.is_some_and(|expected_class| expected_class != class_literal) { return None; } - class_type - .into_generic_alias() - .map(|generic_alias| generic_alias.specialization(db)) + specialization } /// Returns the top materialization (or upper bound materialization) of this type, which is the @@ -945,7 +958,12 @@ impl<'db> Type<'db> { materialization_kind: MaterializationKind, visitor: &ApplyTypeMappingVisitor<'db>, ) -> Type<'db> { - self.apply_type_mapping_impl(db, &TypeMapping::Materialize(materialization_kind), visitor) + self.apply_type_mapping_impl( + db, + &TypeMapping::Materialize(materialization_kind), + TypeContext::default(), + visitor, + ) } pub(crate) const fn is_type_var(self) -> bool { @@ -1180,22 +1198,35 @@ impl<'db> Type<'db> { /// Note that this function tries to promote literals to a more user-friendly form than their /// fallback instance type. For example, `def _() -> int` is promoted to `Callable[[], int]`, /// as opposed to `FunctionType`. - pub(crate) fn promote_literals(self, db: &'db dyn Db) -> Type<'db> { - self.apply_type_mapping(db, &TypeMapping::PromoteLiterals) + /// + /// It also avoids literal promotion if a literal type annotation was provided as type context. + pub(crate) fn promote_literals(self, db: &'db dyn Db, tcx: TypeContext<'db>) -> Type<'db> { + self.apply_type_mapping(db, &TypeMapping::PromoteLiterals, tcx) } /// Like [`Type::promote_literals`], but does not recurse into nested types. - fn promote_literals_impl(self, db: &'db dyn Db) -> Type<'db> { - match self { - Type::StringLiteral(_) | Type::LiteralString => KnownClass::Str.to_instance(db), + fn promote_literals_impl(self, db: &'db dyn Db, tcx: TypeContext<'db>) -> Type<'db> { + let promoted = match self { + Type::StringLiteral(_) => KnownClass::Str.to_instance(db), + Type::LiteralString => KnownClass::Str.to_instance(db), Type::BooleanLiteral(_) => KnownClass::Bool.to_instance(db), Type::IntLiteral(_) => KnownClass::Int.to_instance(db), Type::BytesLiteral(_) => KnownClass::Bytes.to_instance(db), Type::ModuleLiteral(_) => KnownClass::ModuleType.to_instance(db), Type::EnumLiteral(literal) => literal.enum_class_instance(db), Type::FunctionLiteral(literal) => Type::Callable(literal.into_callable_type(db)), - _ => self, + _ => return self, + }; + + // Avoid literal promotion if it leads to an unassignable type. + if tcx + .annotation + .is_none_or(|annotation| promoted.is_assignable_to(db, annotation)) + { + return promoted; } + + self } /// Return a "normalized" version of `self` that ensures that equivalent types have the same Salsa ID. @@ -3973,6 +4004,7 @@ impl<'db> Type<'db> { db, "__getattr__", CallArguments::positional([Type::string_literal(db, &name)]), + TypeContext::default(), ) .map(|outcome| Place::bound(outcome.return_type(db))) // TODO: Handle call errors here. @@ -3992,6 +4024,7 @@ impl<'db> Type<'db> { db, "__getattribute__", &mut CallArguments::positional([Type::string_literal(db, &name)]), + TypeContext::default(), MemberLookupPolicy::MRO_NO_OBJECT_FALLBACK, ) .map(|outcome| Place::bound(outcome.return_type(db))) @@ -4139,7 +4172,12 @@ impl<'db> Type<'db> { // runtime there is a fallback to `__len__`, since `__bool__` takes precedence // and a subclass could add a `__bool__` method. - match self.try_call_dunder(db, "__bool__", CallArguments::none()) { + match self.try_call_dunder( + db, + "__bool__", + CallArguments::none(), + TypeContext::default(), + ) { Ok(outcome) => { let return_type = outcome.return_type(db); if !return_type.is_assignable_to(db, KnownClass::Bool.to_instance(db)) { @@ -4354,7 +4392,12 @@ impl<'db> Type<'db> { return usize_len.try_into().ok().map(Type::IntLiteral); } - let return_ty = match self.try_call_dunder(db, "__len__", CallArguments::none()) { + let return_ty = match self.try_call_dunder( + db, + "__len__", + CallArguments::none(), + TypeContext::default(), + ) { Ok(bindings) => bindings.return_type(db), Err(CallDunderError::PossiblyUnbound(bindings)) => bindings.return_type(db), @@ -5181,11 +5224,13 @@ impl<'db> Type<'db> { db: &'db dyn Db, name: &str, mut argument_types: CallArguments<'_, 'db>, + tcx: TypeContext<'db>, ) -> Result, CallDunderError<'db>> { self.try_call_dunder_with_policy( db, name, &mut argument_types, + tcx, MemberLookupPolicy::default(), ) } @@ -5202,6 +5247,7 @@ impl<'db> Type<'db> { db: &'db dyn Db, name: &str, argument_types: &mut CallArguments<'_, 'db>, + tcx: TypeContext<'db>, policy: MemberLookupPolicy, ) -> Result, CallDunderError<'db>> { // Implicit calls to dunder methods never access instance members, so we pass @@ -5218,7 +5264,7 @@ impl<'db> Type<'db> { let bindings = dunder_callable .bindings(db) .match_parameters(db, argument_types) - .check_types(db, argument_types, &TypeContext::default())?; + .check_types(db, argument_types, &tcx)?; if boundness == Boundness::PossiblyUnbound { return Err(CallDunderError::PossiblyUnbound(Box::new(bindings))); } @@ -5266,11 +5312,21 @@ impl<'db> Type<'db> { CallDunderError<'db>, > { iterator - .try_call_dunder(db, "__anext__", CallArguments::none()) + .try_call_dunder( + db, + "__anext__", + CallArguments::none(), + TypeContext::default(), + ) .map(|dunder_anext_outcome| dunder_anext_outcome.return_type(db).try_await(db)) }; - return match self.try_call_dunder(db, "__aiter__", CallArguments::none()) { + return match self.try_call_dunder( + db, + "__aiter__", + CallArguments::none(), + TypeContext::default(), + ) { Ok(dunder_aiter_bindings) => { let iterator = dunder_aiter_bindings.return_type(db); match try_call_dunder_anext_on_iterator(iterator) { @@ -5414,18 +5470,29 @@ impl<'db> Type<'db> { db, "__getitem__", CallArguments::positional([KnownClass::Int.to_instance(db)]), + TypeContext::default(), ) .map(|dunder_getitem_outcome| dunder_getitem_outcome.return_type(db)) }; let try_call_dunder_next_on_iterator = |iterator: Type<'db>| { iterator - .try_call_dunder(db, "__next__", CallArguments::none()) + .try_call_dunder( + db, + "__next__", + CallArguments::none(), + TypeContext::default(), + ) .map(|dunder_next_outcome| dunder_next_outcome.return_type(db)) }; let dunder_iter_result = self - .try_call_dunder(db, "__iter__", CallArguments::none()) + .try_call_dunder( + db, + "__iter__", + CallArguments::none(), + TypeContext::default(), + ) .map(|dunder_iter_outcome| dunder_iter_outcome.return_type(db)); match dunder_iter_result { @@ -5533,11 +5600,17 @@ impl<'db> Type<'db> { EvaluationMode::Sync => ("__enter__", "__exit__"), }; - let enter = self.try_call_dunder(db, enter_method, CallArguments::none()); + let enter = self.try_call_dunder( + db, + enter_method, + CallArguments::none(), + TypeContext::default(), + ); let exit = self.try_call_dunder( db, exit_method, CallArguments::positional([Type::none(db), Type::none(db), Type::none(db)]), + TypeContext::default(), ); // TODO: Make use of Protocols when we support it (the manager be assignable to `contextlib.AbstractContextManager`). @@ -5574,7 +5647,12 @@ impl<'db> Type<'db> { /// Resolve the type of an `await …` expression where `self` is the type of the awaitable. fn try_await(self, db: &'db dyn Db) -> Result, AwaitError<'db>> { - let await_result = self.try_call_dunder(db, "__await__", CallArguments::none()); + let await_result = self.try_call_dunder( + db, + "__await__", + CallArguments::none(), + TypeContext::default(), + ); match await_result { Ok(bindings) => { let return_type = bindings.return_type(db); @@ -5642,6 +5720,7 @@ impl<'db> Type<'db> { self, db: &'db dyn Db, argument_types: CallArguments<'_, 'db>, + tcx: TypeContext<'db>, ) -> Result, ConstructorCallError<'db>> { debug_assert!(matches!( self, @@ -5729,7 +5808,7 @@ impl<'db> Type<'db> { .place .is_unbound() { - Some(init_ty.try_call_dunder(db, "__init__", argument_types)) + Some(init_ty.try_call_dunder(db, "__init__", argument_types, tcx)) } else { None }; @@ -6274,8 +6353,11 @@ impl<'db> Type<'db> { db: &'db dyn Db, specialization: Specialization<'db>, ) -> Type<'db> { - let new_specialization = - self.apply_type_mapping(db, &TypeMapping::Specialization(specialization)); + let new_specialization = self.apply_type_mapping( + db, + &TypeMapping::Specialization(specialization), + TypeContext::default(), + ); match specialization.materialization_kind(db) { None => new_specialization, Some(materialization_kind) => new_specialization.materialize( @@ -6290,14 +6372,16 @@ impl<'db> Type<'db> { self, db: &'db dyn Db, type_mapping: &TypeMapping<'a, 'db>, + tcx: TypeContext<'db>, ) -> Type<'db> { - self.apply_type_mapping_impl(db, type_mapping, &ApplyTypeMappingVisitor::default()) + self.apply_type_mapping_impl(db, type_mapping, tcx, &ApplyTypeMappingVisitor::default()) } fn apply_type_mapping_impl<'a>( self, db: &'db dyn Db, type_mapping: &TypeMapping<'a, 'db>, + tcx: TypeContext<'db>, visitor: &ApplyTypeMappingVisitor<'db>, ) -> Type<'db> { match self { @@ -6373,22 +6457,23 @@ impl<'db> Type<'db> { } Type::FunctionLiteral(function) => { - let function = Type::FunctionLiteral(function.apply_type_mapping_impl(db, type_mapping, visitor)); + let function = Type::FunctionLiteral(function.apply_type_mapping_impl(db, type_mapping, tcx, visitor)); match type_mapping { - TypeMapping::PromoteLiterals => function.promote_literals_impl(db), + TypeMapping::PromoteLiterals => function.promote_literals_impl(db, tcx), _ => function } } Type::BoundMethod(method) => Type::BoundMethod(BoundMethodType::new( db, - method.function(db).apply_type_mapping_impl(db, type_mapping, visitor), - method.self_instance(db).apply_type_mapping_impl(db, type_mapping, visitor), + method.function(db).apply_type_mapping_impl(db, type_mapping, tcx, visitor), + method.self_instance(db).apply_type_mapping_impl(db, type_mapping, tcx, visitor), )), - Type::NominalInstance(instance) => - instance.apply_type_mapping_impl(db, type_mapping, visitor), + Type::NominalInstance(instance) => { + instance.apply_type_mapping_impl(db, type_mapping, tcx, visitor) + }, Type::ProtocolInstance(instance) => { // TODO: Add tests for materialization once subtyping/assignability is implemented for @@ -6398,59 +6483,59 @@ impl<'db> Type<'db> { // > read-only property members, and method members, on protocols act covariantly; // > write-only property members act contravariantly; and read/write attribute // > members on protocols act invariantly - Type::ProtocolInstance(instance.apply_type_mapping_impl(db, type_mapping, visitor)) + Type::ProtocolInstance(instance.apply_type_mapping_impl(db, type_mapping, tcx, visitor)) } Type::KnownBoundMethod(KnownBoundMethodType::FunctionTypeDunderGet(function)) => { Type::KnownBoundMethod(KnownBoundMethodType::FunctionTypeDunderGet( - function.apply_type_mapping_impl(db, type_mapping, visitor), + function.apply_type_mapping_impl(db, type_mapping, tcx, visitor), )) } Type::KnownBoundMethod(KnownBoundMethodType::FunctionTypeDunderCall(function)) => { Type::KnownBoundMethod(KnownBoundMethodType::FunctionTypeDunderCall( - function.apply_type_mapping_impl(db, type_mapping, visitor), + function.apply_type_mapping_impl(db, type_mapping, tcx, visitor), )) } Type::KnownBoundMethod(KnownBoundMethodType::PropertyDunderGet(property)) => { Type::KnownBoundMethod(KnownBoundMethodType::PropertyDunderGet( - property.apply_type_mapping_impl(db, type_mapping, visitor), + property.apply_type_mapping_impl(db, type_mapping, tcx, visitor), )) } Type::KnownBoundMethod(KnownBoundMethodType::PropertyDunderSet(property)) => { Type::KnownBoundMethod(KnownBoundMethodType::PropertyDunderSet( - property.apply_type_mapping_impl(db, type_mapping, visitor), + property.apply_type_mapping_impl(db, type_mapping, tcx, visitor), )) } Type::Callable(callable) => { - Type::Callable(callable.apply_type_mapping_impl(db, type_mapping, visitor)) + Type::Callable(callable.apply_type_mapping_impl(db, type_mapping, tcx, visitor)) } Type::GenericAlias(generic) => { - Type::GenericAlias(generic.apply_type_mapping_impl(db, type_mapping, visitor)) + Type::GenericAlias(generic.apply_type_mapping_impl(db, type_mapping, tcx, visitor)) } Type::TypedDict(typed_dict) => { - Type::TypedDict(typed_dict.apply_type_mapping_impl(db, type_mapping, visitor)) + Type::TypedDict(typed_dict.apply_type_mapping_impl(db, type_mapping, tcx, visitor)) } - Type::SubclassOf(subclass_of) => subclass_of.apply_type_mapping_impl(db, type_mapping, visitor), + Type::SubclassOf(subclass_of) => subclass_of.apply_type_mapping_impl(db, type_mapping, tcx, visitor), Type::PropertyInstance(property) => { - Type::PropertyInstance(property.apply_type_mapping_impl(db, type_mapping, visitor)) + Type::PropertyInstance(property.apply_type_mapping_impl(db, type_mapping, tcx, visitor)) } Type::Union(union) => union.map(db, |element| { - element.apply_type_mapping_impl(db, type_mapping, visitor) + element.apply_type_mapping_impl(db, type_mapping, tcx, visitor) }), Type::Intersection(intersection) => { let mut builder = IntersectionBuilder::new(db); for positive in intersection.positive(db) { builder = - builder.add_positive(positive.apply_type_mapping_impl(db, type_mapping, visitor)); + builder.add_positive(positive.apply_type_mapping_impl(db, type_mapping, tcx, visitor)); } let flipped_mapping = match type_mapping { TypeMapping::Materialize(materialization_kind) => &TypeMapping::Materialize(materialization_kind.flip()), @@ -6458,16 +6543,16 @@ impl<'db> Type<'db> { }; for negative in intersection.negative(db) { builder = - builder.add_negative(negative.apply_type_mapping_impl(db, flipped_mapping, visitor)); + builder.add_negative(negative.apply_type_mapping_impl(db, flipped_mapping, tcx, visitor)); } builder.build() } // TODO(jelle): Materialize should be handled differently, since TypeIs is invariant - Type::TypeIs(type_is) => type_is.with_type(db, type_is.return_type(db).apply_type_mapping(db, type_mapping)), + Type::TypeIs(type_is) => type_is.with_type(db, type_is.return_type(db).apply_type_mapping(db, type_mapping, tcx)), Type::TypeAlias(alias) => { - visitor.visit(self, || alias.value_type(db).apply_type_mapping_impl(db, type_mapping, visitor)) + visitor.visit(self, || alias.value_type(db).apply_type_mapping_impl(db, type_mapping, tcx, visitor)) } Type::ModuleLiteral(_) @@ -6484,7 +6569,7 @@ impl<'db> Type<'db> { TypeMapping::ReplaceSelf { .. } | TypeMapping::MarkTypeVarsInferable(_) | TypeMapping::Materialize(_) => self, - TypeMapping::PromoteLiterals => self.promote_literals_impl(db) + TypeMapping::PromoteLiterals => self.promote_literals_impl(db, tcx) } Type::Dynamic(_) => match type_mapping { @@ -7931,12 +8016,14 @@ impl<'db> TypeVarInstance<'db> { }), self.explicit_variance(db), self._default(db).and_then(|default| match default { - TypeVarDefaultEvaluation::Eager(ty) => { - Some(ty.apply_type_mapping_impl(db, type_mapping, visitor).into()) - } - TypeVarDefaultEvaluation::Lazy => self - .lazy_default(db) - .map(|ty| ty.apply_type_mapping_impl(db, type_mapping, visitor).into()), + TypeVarDefaultEvaluation::Eager(ty) => Some( + ty.apply_type_mapping_impl(db, type_mapping, TypeContext::default(), visitor) + .into(), + ), + TypeVarDefaultEvaluation::Lazy => self.lazy_default(db).map(|ty| { + ty.apply_type_mapping_impl(db, type_mapping, TypeContext::default(), visitor) + .into() + }), }), self.kind(db), ) @@ -8151,9 +8238,13 @@ impl<'db> BoundTypeVarInstance<'db> { /// (resulting in `T@C`). pub(crate) fn default_type(self, db: &'db dyn Db) -> Option> { let binding_context = self.binding_context(db); - self.typevar(db) - .default_type(db) - .map(|ty| ty.apply_type_mapping(db, &TypeMapping::BindLegacyTypevars(binding_context))) + self.typevar(db).default_type(db).map(|ty| { + ty.apply_type_mapping( + db, + &TypeMapping::BindLegacyTypevars(binding_context), + TypeContext::default(), + ) + }) } pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { @@ -8305,7 +8396,7 @@ impl<'db> TypeVarBoundOrConstraints<'db> { match self { TypeVarBoundOrConstraints::UpperBound(bound) => TypeVarBoundOrConstraints::UpperBound( - bound.apply_type_mapping_impl(db, type_mapping, visitor), + bound.apply_type_mapping_impl(db, type_mapping, TypeContext::default(), visitor), ), TypeVarBoundOrConstraints::Constraints(constraints) => { TypeVarBoundOrConstraints::Constraints(UnionType::new( @@ -8313,7 +8404,14 @@ impl<'db> TypeVarBoundOrConstraints<'db> { constraints .elements(db) .iter() - .map(|ty| ty.apply_type_mapping_impl(db, type_mapping, visitor)) + .map(|ty| { + ty.apply_type_mapping_impl( + db, + type_mapping, + TypeContext::default(), + visitor, + ) + }) .collect::>() .into_boxed_slice(), )) @@ -8553,12 +8651,17 @@ impl<'db> ContextManagerError<'db> { EvaluationMode::Async => ("sync", "__enter__", "__exit__", "with"), }; - let alt_enter = - context_expression_type.try_call_dunder(db, alt_enter_method, CallArguments::none()); + let alt_enter = context_expression_type.try_call_dunder( + db, + alt_enter_method, + CallArguments::none(), + TypeContext::default(), + ); let alt_exit = context_expression_type.try_call_dunder( db, alt_exit_method, CallArguments::positional([Type::unknown(), Type::unknown(), Type::unknown()]), + TypeContext::default(), ); if (alt_enter.is_ok() || matches!(alt_enter, Err(CallDunderError::CallError(..)))) @@ -8654,6 +8757,7 @@ impl<'db> IterationError<'db> { db, "__anext__", CallArguments::none(), + TypeContext::default(), )) .and_then(|ty| ty.try_await(db).ok()) } else { @@ -8661,6 +8765,7 @@ impl<'db> IterationError<'db> { db, "__next__", CallArguments::none(), + TypeContext::default(), )) } } @@ -9704,12 +9809,13 @@ impl<'db> CallableType<'db> { self, db: &'db dyn Db, type_mapping: &TypeMapping<'a, 'db>, + tcx: TypeContext<'db>, visitor: &ApplyTypeMappingVisitor<'db>, ) -> Self { CallableType::new( db, self.signatures(db) - .apply_type_mapping_impl(db, type_mapping, visitor), + .apply_type_mapping_impl(db, type_mapping, tcx, visitor), self.is_function_like(db), ) } diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 6ea7fd8752..7854982a73 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -2563,7 +2563,10 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { } } - self.specialization = self.signature.generic_context.map(|gc| builder.build(gc)); + self.specialization = self + .signature + .generic_context + .map(|gc| builder.build(gc, *self.call_expression_tcx)); } fn check_argument_type( diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 93a8d8fcf6..2b580e9c29 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -280,13 +280,20 @@ impl<'db> GenericAlias<'db> { self, db: &'db dyn Db, type_mapping: &TypeMapping<'a, 'db>, + tcx: TypeContext<'db>, visitor: &ApplyTypeMappingVisitor<'db>, ) -> Self { + let tcx = tcx + .annotation + .and_then(|ty| ty.specialization_of(db, Some(self.origin(db)))) + .map(|specialization| specialization.types(db)) + .unwrap_or(&[]); + Self::new( db, self.origin(db), self.specialization(db) - .apply_type_mapping_impl(db, type_mapping, visitor), + .apply_type_mapping_impl(db, type_mapping, tcx, visitor), ) } @@ -469,12 +476,13 @@ impl<'db> ClassType<'db> { self, db: &'db dyn Db, type_mapping: &TypeMapping<'a, 'db>, + tcx: TypeContext<'db>, visitor: &ApplyTypeMappingVisitor<'db>, ) -> Self { match self { Self::NonGeneric(_) => self, Self::Generic(generic) => { - Self::Generic(generic.apply_type_mapping_impl(db, type_mapping, visitor)) + Self::Generic(generic.apply_type_mapping_impl(db, type_mapping, tcx, visitor)) } } } @@ -2030,6 +2038,7 @@ impl<'db> ClassLiteral<'db> { ClassBase::is_typed_dict, ), }, + TypeContext::default(), ) }); } @@ -2337,6 +2346,7 @@ impl<'db> ClassLiteral<'db> { }, ), }, + TypeContext::default(), ) }) } @@ -2671,7 +2681,8 @@ impl<'db> ClassLiteral<'db> { specialization, ClassBase::is_typed_dict ) - } + }, + TypeContext::default(), ) ) } @@ -2921,6 +2932,7 @@ impl<'db> ClassLiteral<'db> { self.unknown_specialization(db), ), }, + TypeContext::default(), ) }); } diff --git a/crates/ty_python_semantic/src/types/class_base.rs b/crates/ty_python_semantic/src/types/class_base.rs index 547ace5923..204cbfb350 100644 --- a/crates/ty_python_semantic/src/types/class_base.rs +++ b/crates/ty_python_semantic/src/types/class_base.rs @@ -5,7 +5,7 @@ use crate::types::tuple::TupleType; use crate::types::{ ApplyTypeMappingVisitor, ClassLiteral, ClassType, DynamicType, KnownClass, KnownInstanceType, MaterializationKind, MroError, MroIterator, NormalizedVisitor, SpecialFormType, Type, - TypeMapping, todo_type, + TypeContext, TypeMapping, todo_type, }; /// Enumeration of the possible kinds of types we allow in class bases. @@ -277,11 +277,12 @@ impl<'db> ClassBase<'db> { self, db: &'db dyn Db, type_mapping: &TypeMapping<'a, 'db>, + tcx: TypeContext<'db>, visitor: &ApplyTypeMappingVisitor<'db>, ) -> Self { match self { Self::Class(class) => { - Self::Class(class.apply_type_mapping_impl(db, type_mapping, visitor)) + Self::Class(class.apply_type_mapping_impl(db, type_mapping, tcx, visitor)) } Self::Dynamic(_) | Self::Generic | Self::Protocol | Self::TypedDict => self, } @@ -296,6 +297,7 @@ impl<'db> ClassBase<'db> { let new_self = self.apply_type_mapping_impl( db, &TypeMapping::Specialization(specialization), + TypeContext::default(), &ApplyTypeMappingVisitor::default(), ); match specialization.materialization_kind(db) { @@ -311,6 +313,7 @@ impl<'db> ClassBase<'db> { self.apply_type_mapping_impl( db, &TypeMapping::Materialize(kind), + TypeContext::default(), &ApplyTypeMappingVisitor::default(), ) } diff --git a/crates/ty_python_semantic/src/types/diagnostic.rs b/crates/ty_python_semantic/src/types/diagnostic.rs index be59be8f41..1b38ecd1f3 100644 --- a/crates/ty_python_semantic/src/types/diagnostic.rs +++ b/crates/ty_python_semantic/src/types/diagnostic.rs @@ -19,7 +19,7 @@ use crate::types::string_annotation::{ }; use crate::types::{ ClassType, DynamicType, LintDiagnosticGuard, Protocol, ProtocolInstanceType, SubclassOfInner, - binding_type, infer_isolated_expression, + TypeContext, binding_type, infer_isolated_expression, }; use crate::types::{SpecialFormType, Type, protocol_class::ProtocolClass}; use crate::util::diagnostics::format_enumeration; @@ -2719,7 +2719,7 @@ pub(crate) fn report_undeclared_protocol_member( if definition.kind(db).is_unannotated_assignment() { let binding_type = binding_type(db, definition); - let suggestion = binding_type.promote_literals(db); + let suggestion = binding_type.promote_literals(db, TypeContext::default()); if should_give_hint(db, suggestion) { diagnostic.set_primary_message(format_args!( @@ -2826,6 +2826,7 @@ pub(crate) fn report_invalid_or_unsupported_base( db, "__mro_entries__", CallArguments::positional([tuple_of_types]), + TypeContext::default(), ) { Ok(ret) => { if ret.return_type(db).is_assignable_to(db, tuple_of_types) { diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index bb357d4afa..2f6b5858d7 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -82,8 +82,8 @@ use crate::types::{ ApplyTypeMappingVisitor, BoundMethodType, BoundTypeVarInstance, CallableType, ClassBase, ClassLiteral, ClassType, DeprecatedInstance, DynamicType, FindLegacyTypeVarsVisitor, HasRelationToVisitor, IsEquivalentVisitor, KnownClass, KnownInstanceType, NormalizedVisitor, - SpecialFormType, TrackedConstraintSet, Truthiness, Type, TypeMapping, TypeRelation, - UnionBuilder, binding_type, todo_type, walk_signature, + SpecialFormType, TrackedConstraintSet, Truthiness, Type, TypeContext, TypeMapping, + TypeRelation, UnionBuilder, binding_type, todo_type, walk_signature, }; use crate::{Db, FxOrderSet, ModuleName, resolve_module}; @@ -667,14 +667,15 @@ impl<'db> FunctionType<'db> { self, db: &'db dyn Db, type_mapping: &TypeMapping<'a, 'db>, + tcx: TypeContext<'db>, visitor: &ApplyTypeMappingVisitor<'db>, ) -> Self { let updated_signature = self.signature(db) - .apply_type_mapping_impl(db, type_mapping, visitor); + .apply_type_mapping_impl(db, type_mapping, tcx, visitor); let updated_last_definition_signature = self .last_definition_signature(db) - .apply_type_mapping_impl(db, type_mapping, visitor); + .apply_type_mapping_impl(db, type_mapping, tcx, visitor); Self::new( db, self.literal(db), diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 2431eba4c0..f84cf8b4b1 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -16,7 +16,7 @@ use crate::types::tuple::{TupleSpec, TupleType, walk_tuple_type}; use crate::types::{ ApplyTypeMappingVisitor, BoundTypeVarInstance, ClassLiteral, FindLegacyTypeVarsVisitor, HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor, KnownClass, KnownInstanceType, - MaterializationKind, NormalizedVisitor, Type, TypeMapping, TypeRelation, + MaterializationKind, NormalizedVisitor, Type, TypeContext, TypeMapping, TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, TypeVarVariance, UnionType, binding_type, declaration_type, }; @@ -460,8 +460,11 @@ impl<'db> GenericContext<'db> { generic_context: self, types: &expanded[0..idx], }; - let default = - default.apply_type_mapping(db, &TypeMapping::PartialSpecialization(partial)); + let default = default.apply_type_mapping( + db, + &TypeMapping::PartialSpecialization(partial), + TypeContext::default(), + ); expanded[idx] = default; } @@ -791,27 +794,34 @@ impl<'db> Specialization<'db> { db: &'db dyn Db, type_mapping: &TypeMapping<'a, 'db>, ) -> Self { - self.apply_type_mapping_impl(db, type_mapping, &ApplyTypeMappingVisitor::default()) + self.apply_type_mapping_impl(db, type_mapping, &[], &ApplyTypeMappingVisitor::default()) } pub(crate) fn apply_type_mapping_impl<'a>( self, db: &'db dyn Db, type_mapping: &TypeMapping<'a, 'db>, + tcx: &[Type<'db>], visitor: &ApplyTypeMappingVisitor<'db>, ) -> Self { if let TypeMapping::Materialize(materialization_kind) = type_mapping { return self.materialize_impl(db, *materialization_kind, visitor); } + let types: Box<[_]> = self .types(db) .iter() - .map(|ty| ty.apply_type_mapping_impl(db, type_mapping, visitor)) + .enumerate() + .map(|(i, ty)| { + let tcx = TypeContext::new(tcx.get(i).copied()); + ty.apply_type_mapping_impl(db, type_mapping, tcx, visitor) + }) .collect(); - let tuple_inner = self - .tuple_inner(db) - .and_then(|tuple| tuple.apply_type_mapping_impl(db, type_mapping, visitor)); + let tuple_inner = self.tuple_inner(db).and_then(|tuple| { + tuple.apply_type_mapping_impl(db, type_mapping, TypeContext::default(), visitor) + }); + Specialization::new( db, self.generic_context(db), @@ -924,6 +934,7 @@ impl<'db> Specialization<'db> { tuple.apply_type_mapping_impl( db, &TypeMapping::Materialize(materialization_kind), + TypeContext::default(), visitor, ) }); @@ -1122,19 +1133,30 @@ impl<'db> SpecializationBuilder<'db> { } } - pub(crate) fn build(&mut self, generic_context: GenericContext<'db>) -> Specialization<'db> { + pub(crate) fn build( + &mut self, + generic_context: GenericContext<'db>, + tcx: TypeContext<'db>, + ) -> Specialization<'db> { + let tcx_specialization = tcx + .annotation + .and_then(|annotation| annotation.specialization_of(self.db, None)); + let types = (generic_context.variables_inner(self.db).iter()).map(|(variable, options)| { let mut ty = self.types.get(variable).copied(); // When inferring a specialization for a generic class typevar from a constructor call, - // promote any typevars that are inferred as a literal to the corresponding instance - // type. + // promote any typevars that are inferred as a literal to the corresponding instance type. if options.should_promote_literals { - ty = ty.map(|ty| ty.promote_literals(self.db)); + let tcx = tcx_specialization + .and_then(|specialization| specialization.get(self.db, *variable)); + + ty = ty.map(|ty| ty.promote_literals(self.db, TypeContext::new(tcx))); } ty }); + // TODO Infer the tuple spec for a tuple type generic_context.specialize_partial(self.db, types) } diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 22aeaba771..9c0d38a601 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -383,11 +383,11 @@ impl<'db> TypeContext<'db> { // specialization. fn known_specialization( &self, - known_class: KnownClass, db: &'db dyn Db, + known_class: KnownClass, ) -> Option> { self.annotation - .and_then(|ty| ty.known_specialization(known_class, db)) + .and_then(|ty| ty.known_specialization(db, known_class)) } } diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 9239718093..5f4ae38a81 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -3261,6 +3261,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { db, "__setitem__", CallArguments::positional([slice_ty, assigned_ty]), + TypeContext::default(), ) { Ok(_) => true, Err(err) => match err { @@ -3533,6 +3534,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { db, "__setattr__", &mut CallArguments::positional([Type::string_literal(db, attribute), value_ty]), + TypeContext::default(), MemberLookupPolicy::MRO_NO_OBJECT_FALLBACK, ); @@ -4239,6 +4241,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { db, op.in_place_dunder(), CallArguments::positional([value_type]), + TypeContext::default(), ); match call { @@ -5451,7 +5454,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } = tuple; let annotated_tuple = tcx - .known_specialization(KnownClass::Tuple, self.db()) + .known_specialization(self.db(), KnownClass::Tuple) .and_then(|specialization| { specialization .tuple(self.db()) @@ -5586,14 +5589,19 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { Some((class_literal, generic_context.variables(self.db()))) }; - let (class_literal, elt_tys) = elt_tys(collection_class).unwrap_or_else(|| { - let name = collection_class.name(self.db()); - panic!("Typeshed should always have a `{name}` class in `builtins.pyi`") - }); + let Some((class_literal, elt_tys)) = elt_tys(collection_class) else { + // Infer the element types without type context, and fallback to unknown for + // custom typesheds. + for elt in elts.flatten().flatten() { + self.get_or_infer_expression(elt, TypeContext::default()); + } + + return None; + }; // Extract the annotated type of `T`, if provided. let annotated_elt_tys = tcx - .known_specialization(collection_class, self.db()) + .known_specialization(self.db(), collection_class) .map(|specialization| specialization.types(self.db())); // Create a set of constraints to infer a precise type for `T`. @@ -5633,7 +5641,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // Merge the inferred type of the nested dictionary. if let Some(specialization) = - inferred_value_ty.known_specialization(KnownClass::Dict, self.db()) + inferred_value_ty.known_specialization(self.db(), KnownClass::Dict) { for (elt_ty, inferred_elt_ty) in iter::zip(elt_tys.clone(), specialization.types(self.db())) @@ -5656,14 +5664,15 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // Convert any element literals to their promoted type form to avoid excessively large // unions for large nested list literals, which the constraint solver struggles with. - let inferred_elt_ty = inferred_elt_ty.promote_literals(self.db()); + let inferred_elt_ty = inferred_elt_ty.promote_literals(self.db(), elt_tcx); builder.infer(Type::TypeVar(elt_ty), inferred_elt_ty).ok()?; } } - let class_type = class_literal - .apply_specialization(self.db(), |generic_context| builder.build(generic_context)); + let class_type = class_literal.apply_specialization(self.db(), |generic_context| { + builder.build(generic_context, TypeContext::default()) + }); Type::from(class_type).to_instance(self.db()) } @@ -6186,7 +6195,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { self.infer_argument_types(arguments, &mut call_arguments, &argument_forms); return callable_type - .try_call_constructor(self.db(), call_arguments) + .try_call_constructor(self.db(), call_arguments, tcx) .unwrap_or_else(|err| { err.report_diagnostic(&self.context, callable_type, call_expression.into()); err.return_type() @@ -7225,6 +7234,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { self.db(), unary_dunder_method, CallArguments::none(), + TypeContext::default(), ) { Ok(outcome) => outcome.return_type(self.db()), Err(e) => { @@ -7644,6 +7654,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { self.db(), reflected_dunder, CallArguments::positional([left_ty]), + TypeContext::default(), ) .map(|outcome| outcome.return_type(self.db())) .or_else(|_| { @@ -7652,6 +7663,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { self.db(), op.dunder(), CallArguments::positional([right_ty]), + TypeContext::default(), ) .map(|outcome| outcome.return_type(self.db())) }) @@ -7664,6 +7676,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { self.db(), op.dunder(), CallArguments::positional([right_ty]), + TypeContext::default(), ) .map(|outcome| outcome.return_type(self.db())) .ok(); @@ -7677,6 +7690,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { self.db(), op.reflected_dunder(), CallArguments::positional([left_ty]), + TypeContext::default(), ) .map(|outcome| outcome.return_type(self.db())) .ok() @@ -8430,6 +8444,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { db, op.dunder(), &mut CallArguments::positional([right]), + TypeContext::default(), policy, ) .map(|outcome| outcome.return_type(db)) @@ -9025,7 +9040,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // If the class defines `__getitem__`, return its return type. // // See: https://docs.python.org/3/reference/datamodel.html#class-getitem-versus-getitem - match value_ty.try_call_dunder(db, "__getitem__", CallArguments::positional([slice_ty])) { + match value_ty.try_call_dunder( + db, + "__getitem__", + CallArguments::positional([slice_ty]), + TypeContext::default(), + ) { Ok(outcome) => { return outcome.return_type(db); } diff --git a/crates/ty_python_semantic/src/types/instance.rs b/crates/ty_python_semantic/src/types/instance.rs index e02d12decc..7d7998530e 100644 --- a/crates/ty_python_semantic/src/types/instance.rs +++ b/crates/ty_python_semantic/src/types/instance.rs @@ -14,8 +14,8 @@ use crate::types::protocol_class::walk_protocol_interface; use crate::types::tuple::{TupleSpec, TupleType}; use crate::types::{ ApplyTypeMappingVisitor, ClassBase, ClassLiteral, FindLegacyTypeVarsVisitor, - HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor, NormalizedVisitor, TypeMapping, - TypeRelation, VarianceInferable, + HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor, NormalizedVisitor, TypeContext, + TypeMapping, TypeRelation, VarianceInferable, }; use crate::{Db, FxOrderSet}; @@ -475,15 +475,16 @@ impl<'db> NominalInstanceType<'db> { self, db: &'db dyn Db, type_mapping: &TypeMapping<'a, 'db>, + tcx: TypeContext<'db>, visitor: &ApplyTypeMappingVisitor<'db>, ) -> Type<'db> { match self.0 { NominalInstanceInner::ExactTuple(tuple) => { - Type::tuple(tuple.apply_type_mapping_impl(db, type_mapping, visitor)) + Type::tuple(tuple.apply_type_mapping_impl(db, type_mapping, tcx, visitor)) } NominalInstanceInner::NonTuple(class) => Type::non_tuple_instance( db, - class.apply_type_mapping_impl(db, type_mapping, visitor), + class.apply_type_mapping_impl(db, type_mapping, tcx, visitor), ), NominalInstanceInner::Object => Type::object(), } @@ -734,15 +735,16 @@ impl<'db> ProtocolInstanceType<'db> { self, db: &'db dyn Db, type_mapping: &TypeMapping<'a, 'db>, + tcx: TypeContext<'db>, visitor: &ApplyTypeMappingVisitor<'db>, ) -> Self { match self.inner { Protocol::FromClass(class) => { - Self::from_class(class.apply_type_mapping_impl(db, type_mapping, visitor)) - } - Protocol::Synthesized(synthesized) => { - Self::synthesized(synthesized.apply_type_mapping_impl(db, type_mapping, visitor)) + Self::from_class(class.apply_type_mapping_impl(db, type_mapping, tcx, visitor)) } + Protocol::Synthesized(synthesized) => Self::synthesized( + synthesized.apply_type_mapping_impl(db, type_mapping, tcx, visitor), + ), } } @@ -813,7 +815,7 @@ mod synthesized_protocol { use crate::types::protocol_class::ProtocolInterface; use crate::types::{ ApplyTypeMappingVisitor, BoundTypeVarInstance, FindLegacyTypeVarsVisitor, - NormalizedVisitor, TypeMapping, TypeVarVariance, VarianceInferable, + NormalizedVisitor, TypeContext, TypeMapping, TypeVarVariance, VarianceInferable, }; use crate::{Db, FxOrderSet}; @@ -844,9 +846,10 @@ mod synthesized_protocol { self, db: &'db dyn Db, type_mapping: &TypeMapping<'a, 'db>, + tcx: TypeContext<'db>, _visitor: &ApplyTypeMappingVisitor<'db>, ) -> Self { - Self(self.0.specialized_and_normalized(db, type_mapping)) + Self(self.0.specialized_and_normalized(db, type_mapping, tcx)) } pub(super) fn find_legacy_typevars_impl( diff --git a/crates/ty_python_semantic/src/types/protocol_class.rs b/crates/ty_python_semantic/src/types/protocol_class.rs index 57e8275120..529824b1ea 100644 --- a/crates/ty_python_semantic/src/types/protocol_class.rs +++ b/crates/ty_python_semantic/src/types/protocol_class.rs @@ -6,6 +6,7 @@ use itertools::Itertools; use ruff_python_ast::name::Name; use rustc_hash::FxHashMap; +use crate::types::TypeContext; use crate::{ Db, FxOrderSet, place::{Boundness, Place, PlaceAndQualifiers, place_from_bindings, place_from_declarations}, @@ -339,6 +340,7 @@ impl<'db> ProtocolInterface<'db> { self, db: &'db dyn Db, type_mapping: &TypeMapping<'a, 'db>, + tcx: TypeContext<'db>, ) -> Self { Self::new( db, @@ -350,6 +352,7 @@ impl<'db> ProtocolInterface<'db> { data.apply_type_mapping_impl( db, type_mapping, + tcx, &ApplyTypeMappingVisitor::default(), ) .normalized(db), @@ -428,10 +431,13 @@ impl<'db> ProtocolMemberData<'db> { &self, db: &'db dyn Db, type_mapping: &TypeMapping<'a, 'db>, + tcx: TypeContext<'db>, visitor: &ApplyTypeMappingVisitor<'db>, ) -> Self { Self { - kind: self.kind.apply_type_mapping_impl(db, type_mapping, visitor), + kind: self + .kind + .apply_type_mapping_impl(db, type_mapping, tcx, visitor), qualifiers: self.qualifiers, } } @@ -516,18 +522,22 @@ impl<'db> ProtocolMemberKind<'db> { &self, db: &'db dyn Db, type_mapping: &TypeMapping<'a, 'db>, + tcx: TypeContext<'db>, visitor: &ApplyTypeMappingVisitor<'db>, ) -> Self { match self { ProtocolMemberKind::Method(callable) => ProtocolMemberKind::Method( - callable.apply_type_mapping_impl(db, type_mapping, visitor), + callable.apply_type_mapping_impl(db, type_mapping, tcx, visitor), ), ProtocolMemberKind::Property(property) => ProtocolMemberKind::Property( - property.apply_type_mapping_impl(db, type_mapping, visitor), + property.apply_type_mapping_impl(db, type_mapping, tcx, visitor), ), - ProtocolMemberKind::Other(ty) => { - ProtocolMemberKind::Other(ty.apply_type_mapping_impl(db, type_mapping, visitor)) - } + ProtocolMemberKind::Other(ty) => ProtocolMemberKind::Other(ty.apply_type_mapping_impl( + db, + type_mapping, + tcx, + visitor, + )), } } diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index f97dfa4fe2..41b955f4ec 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -28,7 +28,7 @@ use crate::types::infer::nearest_enclosing_class; use crate::types::{ ApplyTypeMappingVisitor, BoundTypeVarInstance, ClassLiteral, FindLegacyTypeVarsVisitor, HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor, KnownClass, MaterializationKind, - NormalizedVisitor, TypeMapping, TypeRelation, VarianceInferable, todo_type, + NormalizedVisitor, TypeContext, TypeMapping, TypeRelation, VarianceInferable, todo_type, }; use crate::{Db, FxOrderSet}; use ruff_python_ast::{self as ast, name::Name}; @@ -138,12 +138,13 @@ impl<'db> CallableSignature<'db> { &self, db: &'db dyn Db, type_mapping: &TypeMapping<'a, 'db>, + tcx: TypeContext<'db>, visitor: &ApplyTypeMappingVisitor<'db>, ) -> Self { Self::from_overloads( self.overloads .iter() - .map(|signature| signature.apply_type_mapping_impl(db, type_mapping, visitor)), + .map(|signature| signature.apply_type_mapping_impl(db, type_mapping, tcx, visitor)), ) } @@ -435,6 +436,7 @@ impl<'db> Signature<'db> { .apply_type_mapping( db, &TypeMapping::MarkTypeVarsInferable(Some(definition.into())), + TypeContext::default(), ); if function_node.is_async && !is_generator { KnownClass::CoroutineType @@ -523,6 +525,7 @@ impl<'db> Signature<'db> { &self, db: &'db dyn Db, type_mapping: &TypeMapping<'a, 'db>, + tcx: TypeContext<'db>, visitor: &ApplyTypeMappingVisitor<'db>, ) -> Self { let flipped_mapping = match type_mapping { @@ -538,10 +541,10 @@ impl<'db> Signature<'db> { definition: self.definition, parameters: self .parameters - .apply_type_mapping_impl(db, flipped_mapping, visitor), + .apply_type_mapping_impl(db, flipped_mapping, tcx, visitor), return_ty: self .return_ty - .map(|ty| ty.apply_type_mapping_impl(db, type_mapping, visitor)), + .map(|ty| ty.apply_type_mapping_impl(db, type_mapping, tcx, visitor)), } } @@ -582,10 +585,16 @@ impl<'db> Signature<'db> { parameters = parameters.apply_type_mapping_impl( db, &TypeMapping::BindSelf(self_type), + TypeContext::default(), &ApplyTypeMappingVisitor::default(), ); - return_ty = - return_ty.map(|ty| ty.apply_type_mapping(db, &TypeMapping::BindSelf(self_type))); + return_ty = return_ty.map(|ty| { + ty.apply_type_mapping( + db, + &TypeMapping::BindSelf(self_type), + TypeContext::default(), + ) + }); } Self { generic_context: self.generic_context, @@ -1294,6 +1303,7 @@ impl<'db> Parameters<'db> { .expect("We should always find the surrounding class for an implicit self: Self annotation").apply_type_mapping( db, &TypeMapping::MarkTypeVarsInferable(None), + TypeContext::default() ) ) } else { @@ -1423,6 +1433,7 @@ impl<'db> Parameters<'db> { &self, db: &'db dyn Db, type_mapping: &TypeMapping<'a, 'db>, + tcx: TypeContext<'db>, visitor: &ApplyTypeMappingVisitor<'db>, ) -> Self { match type_mapping { @@ -1442,7 +1453,7 @@ impl<'db> Parameters<'db> { value: self .value .iter() - .map(|param| param.apply_type_mapping_impl(db, type_mapping, visitor)) + .map(|param| param.apply_type_mapping_impl(db, type_mapping, tcx, visitor)) .collect(), is_gradual: self.is_gradual, }, @@ -1634,13 +1645,16 @@ impl<'db> Parameter<'db> { &self, db: &'db dyn Db, type_mapping: &TypeMapping<'a, 'db>, + tcx: TypeContext<'db>, visitor: &ApplyTypeMappingVisitor<'db>, ) -> Self { Self { annotated_type: self .annotated_type - .map(|ty| ty.apply_type_mapping_impl(db, type_mapping, visitor)), - kind: self.kind.apply_type_mapping_impl(db, type_mapping, visitor), + .map(|ty| ty.apply_type_mapping_impl(db, type_mapping, tcx, visitor)), + kind: self + .kind + .apply_type_mapping_impl(db, type_mapping, tcx, visitor), inferred_annotation: self.inferred_annotation, form: self.form, } @@ -1718,6 +1732,7 @@ impl<'db> Parameter<'db> { definition_expression_type(db, definition, annotation).apply_type_mapping( db, &TypeMapping::MarkTypeVarsInferable(Some(definition.into())), + TypeContext::default(), ) }), kind, @@ -1857,25 +1872,26 @@ impl<'db> ParameterKind<'db> { &self, db: &'db dyn Db, type_mapping: &TypeMapping<'a, 'db>, + tcx: TypeContext<'db>, visitor: &ApplyTypeMappingVisitor<'db>, ) -> Self { match self { Self::PositionalOnly { default_type, name } => Self::PositionalOnly { default_type: default_type .as_ref() - .map(|ty| ty.apply_type_mapping_impl(db, type_mapping, visitor)), + .map(|ty| ty.apply_type_mapping_impl(db, type_mapping, tcx, visitor)), name: name.clone(), }, Self::PositionalOrKeyword { default_type, name } => Self::PositionalOrKeyword { default_type: default_type .as_ref() - .map(|ty| ty.apply_type_mapping_impl(db, type_mapping, visitor)), + .map(|ty| ty.apply_type_mapping_impl(db, type_mapping, tcx, visitor)), name: name.clone(), }, Self::KeywordOnly { default_type, name } => Self::KeywordOnly { default_type: default_type .as_ref() - .map(|ty| ty.apply_type_mapping_impl(db, type_mapping, visitor)), + .map(|ty| ty.apply_type_mapping_impl(db, type_mapping, tcx, visitor)), name: name.clone(), }, Self::Variadic { .. } | Self::KeywordVariadic { .. } => self.clone(), diff --git a/crates/ty_python_semantic/src/types/subclass_of.rs b/crates/ty_python_semantic/src/types/subclass_of.rs index fdb450cca1..c6a16620a9 100644 --- a/crates/ty_python_semantic/src/types/subclass_of.rs +++ b/crates/ty_python_semantic/src/types/subclass_of.rs @@ -5,8 +5,8 @@ use crate::types::variance::VarianceInferable; use crate::types::{ ApplyTypeMappingVisitor, BoundTypeVarInstance, ClassType, DynamicType, FindLegacyTypeVarsVisitor, HasRelationToVisitor, IsDisjointVisitor, KnownClass, - MaterializationKind, MemberLookupPolicy, NormalizedVisitor, SpecialFormType, Type, TypeMapping, - TypeRelation, + MaterializationKind, MemberLookupPolicy, NormalizedVisitor, SpecialFormType, Type, TypeContext, + TypeMapping, TypeRelation, }; use crate::{Db, FxOrderSet}; @@ -84,6 +84,7 @@ impl<'db> SubclassOfType<'db> { self, db: &'db dyn Db, type_mapping: &TypeMapping<'a, 'db>, + tcx: TypeContext<'db>, visitor: &ApplyTypeMappingVisitor<'db>, ) -> Type<'db> { match self.subclass_of { @@ -91,6 +92,7 @@ impl<'db> SubclassOfType<'db> { subclass_of: SubclassOfInner::Class(class.apply_type_mapping_impl( db, type_mapping, + tcx, visitor, )), }), diff --git a/crates/ty_python_semantic/src/types/tuple.rs b/crates/ty_python_semantic/src/types/tuple.rs index a6c45374f1..a94a0610e7 100644 --- a/crates/ty_python_semantic/src/types/tuple.rs +++ b/crates/ty_python_semantic/src/types/tuple.rs @@ -22,7 +22,6 @@ use std::hash::Hash; use itertools::{Either, EitherOrBoth, Itertools}; use crate::semantic_index::definition::Definition; -use crate::types::Truthiness; use crate::types::class::{ClassType, KnownClass}; use crate::types::constraints::{ConstraintSet, IteratorConstraintsExtension}; use crate::types::{ @@ -30,6 +29,7 @@ use crate::types::{ IsDisjointVisitor, IsEquivalentVisitor, NormalizedVisitor, Type, TypeMapping, TypeRelation, UnionBuilder, UnionType, }; +use crate::types::{Truthiness, TypeContext}; use crate::util::subscript::{Nth, OutOfBoundsError, PyIndex, PySlice, StepSizeZeroError}; use crate::{Db, FxOrderSet, Program}; @@ -232,13 +232,14 @@ impl<'db> TupleType<'db> { self, db: &'db dyn Db, type_mapping: &TypeMapping<'a, 'db>, + tcx: TypeContext<'db>, visitor: &ApplyTypeMappingVisitor<'db>, ) -> Option { TupleType::new( db, &self .tuple(db) - .apply_type_mapping_impl(db, type_mapping, visitor), + .apply_type_mapping_impl(db, type_mapping, tcx, visitor), ) } @@ -396,12 +397,32 @@ impl<'db> FixedLengthTuple> { &self, db: &'db dyn Db, type_mapping: &TypeMapping<'a, 'db>, + tcx: TypeContext<'db>, visitor: &ApplyTypeMappingVisitor<'db>, ) -> Self { + let tcx_tuple = tcx + .annotation + .and_then(|annotation| annotation.known_specialization(db, KnownClass::Tuple)) + .and_then(|specialization| { + specialization + .tuple(db) + .expect("the specialization of `KnownClass::Tuple` must have a tuple spec") + .resize(db, TupleLength::Fixed(self.0.len())) + .ok() + }); + + let tcx_elements = match tcx_tuple.as_ref() { + None => Either::Right(std::iter::repeat(TypeContext::default())), + Some(tuple) => { + Either::Left(tuple.all_elements().map(|tcx| TypeContext::new(Some(*tcx)))) + } + }; + Self::from_elements( self.0 .iter() - .map(|ty| ty.apply_type_mapping_impl(db, type_mapping, visitor)), + .zip(tcx_elements) + .map(|(ty, tcx)| ty.apply_type_mapping_impl(db, type_mapping, tcx, visitor)), ) } @@ -736,17 +757,18 @@ impl<'db> VariableLengthTuple> { &self, db: &'db dyn Db, type_mapping: &TypeMapping<'a, 'db>, + tcx: TypeContext<'db>, visitor: &ApplyTypeMappingVisitor<'db>, ) -> TupleSpec<'db> { Self::mixed( self.prefix .iter() - .map(|ty| ty.apply_type_mapping_impl(db, type_mapping, visitor)), + .map(|ty| ty.apply_type_mapping_impl(db, type_mapping, tcx, visitor)), self.variable - .apply_type_mapping_impl(db, type_mapping, visitor), + .apply_type_mapping_impl(db, type_mapping, tcx, visitor), self.suffix .iter() - .map(|ty| ty.apply_type_mapping_impl(db, type_mapping, visitor)), + .map(|ty| ty.apply_type_mapping_impl(db, type_mapping, tcx, visitor)), ) } @@ -1116,13 +1138,14 @@ impl<'db> Tuple> { &self, db: &'db dyn Db, type_mapping: &TypeMapping<'a, 'db>, + tcx: TypeContext<'db>, visitor: &ApplyTypeMappingVisitor<'db>, ) -> Self { match self { Tuple::Fixed(tuple) => { - Tuple::Fixed(tuple.apply_type_mapping_impl(db, type_mapping, visitor)) + Tuple::Fixed(tuple.apply_type_mapping_impl(db, type_mapping, tcx, visitor)) } - Tuple::Variable(tuple) => tuple.apply_type_mapping_impl(db, type_mapping, visitor), + Tuple::Variable(tuple) => tuple.apply_type_mapping_impl(db, type_mapping, tcx, visitor), } } diff --git a/crates/ty_python_semantic/src/types/typed_dict.rs b/crates/ty_python_semantic/src/types/typed_dict.rs index 833c3bdded..424ec1e6da 100644 --- a/crates/ty_python_semantic/src/types/typed_dict.rs +++ b/crates/ty_python_semantic/src/types/typed_dict.rs @@ -12,6 +12,7 @@ use super::diagnostic::{ report_missing_typed_dict_key, }; use super::{ApplyTypeMappingVisitor, Type, TypeMapping, visitor}; +use crate::types::TypeContext; use crate::{Db, FxOrderMap}; use ordermap::OrderSet; @@ -62,13 +63,17 @@ impl<'db> TypedDictType<'db> { self, db: &'db dyn Db, type_mapping: &TypeMapping<'a, 'db>, + tcx: TypeContext<'db>, visitor: &ApplyTypeMappingVisitor<'db>, ) -> Self { // TODO: Materialization of gradual TypedDicts needs more logic Self { - defining_class: self - .defining_class - .apply_type_mapping_impl(db, type_mapping, visitor), + defining_class: self.defining_class.apply_type_mapping_impl( + db, + type_mapping, + tcx, + visitor, + ), } } }