This commit is contained in:
Ibraheem Ahmed 2025-12-16 16:10:43 +01:00 committed by GitHub
commit 0b872f8f92
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 232 additions and 59 deletions

View File

@ -542,8 +542,7 @@ e: list[Any] | None = [1]
reveal_type(e) # revealed: list[Any]
f: list[Any] | None = f2(1)
# TODO: Better constraint solver.
reveal_type(f) # revealed: list[int] | None
reveal_type(f) # revealed: list[Any] | None
g: list[Any] | dict[Any, Any] = f3(1)
# TODO: Better constraint solver.
@ -600,6 +599,48 @@ reveal_type(x7) # revealed: Contravariant[Any]
reveal_type(x8) # revealed: Invariant[Any]
```
## Declared type preference sees through subtyping
```toml
[environment]
python-version = "3.12"
```
```py
from typing import Any, Iterable, Literal, MutableSequence, Sequence
x1: Sequence[Any] = [1, 2, 3]
reveal_type(x1) # revealed: list[Any]
x2: MutableSequence[Any] = [1, 2, 3]
reveal_type(x2) # revealed: list[Any]
x3: Iterable[Any] = [1, 2, 3]
reveal_type(x3) # revealed: list[Any]
class X[T]:
value: T
def __init__(self, value: T): ...
class A[T](X[T]): ...
def a[T](value: T) -> A[T]:
return A(value)
x4: A[object] = A(1)
reveal_type(x4) # revealed: A[object]
x5: X[object] = A(1)
reveal_type(x5) # revealed: A[object]
x6: X[object] | None = A(1)
reveal_type(x6) # revealed: A[object]
x7: X[object] | None = a(1)
reveal_type(x7) # revealed: A[object]
```
## Narrow generic unions
```toml

View File

@ -341,3 +341,58 @@ reveal_type(x21) # revealed: X[Literal[1]]
x22: X[Literal[1]] | None = x(1)
reveal_type(x22) # revealed: X[Literal[1]]
```
## Literal annotations see through subtyping
```py
from typing import Any, Iterable, Literal, MutableSequence, Sequence
x1: Sequence[Literal[1, 2, 3]] = [1, 2, 3]
reveal_type(x1) # revealed: list[Literal[1, 2, 3]]
x2: MutableSequence[Literal[1, 2, 3]] = [1, 2, 3]
reveal_type(x2) # revealed: list[Literal[1, 2, 3]]
x3: Iterable[Literal[1, 2, 3]] = [1, 2, 3]
reveal_type(x3) # revealed: list[Literal[1, 2, 3]]
class Sup1[T]:
value: T
class Sub1[T](Sup1[T]): ...
def sub1[T](value: T) -> Sub1[T]:
return Sub1()
x4: Sub1[Literal[1]] = sub1(1)
reveal_type(x4) # revealed: Sub1[Literal[1]]
x5: Sup1[Literal[1]] = sub1(1)
reveal_type(x5) # revealed: Sub1[Literal[1]]
x6: Sup1[Literal[1]] | None = sub1(1)
reveal_type(x6) # revealed: Sub1[Literal[1]]
x7: Sup1[Literal[1]] | None = sub1(1)
reveal_type(x7) # revealed: Sub1[Literal[1]]
class Sup2A[T, U]:
value: tuple[T, U]
class Sup2B[T, U]:
value: tuple[T, U]
class Sub2[T, U](Sup2A[T, Any], Sup2B[Any, U]): ...
def sub2[T, U](x: T, y: U) -> Sub2[T, U]:
return Sub2()
x8 = sub2(1, 2)
reveal_type(x8) # revealed: Sub2[int, int]
x9: Sup2A[Literal[1], Literal[2]] = sub2(1, 2)
reveal_type(x9) # revealed: Sub2[Literal[1], int]
x10: Sup2B[Literal[1], Literal[2]] = sub2(1, 2)
reveal_type(x10) # revealed: Sub2[int, Literal[2]]
```

View File

@ -57,6 +57,9 @@ reveal_type(tuple((1, 2))) # revealed: tuple[Literal[1], Literal[2]]
reveal_type(tuple([1])) # revealed: tuple[Unknown | int, ...]
x1: tuple[int, ...] = tuple([1])
reveal_type(x1) # revealed: tuple[int, ...]
# error: [invalid-argument-type]
reveal_type(tuple[int]([1])) # revealed: tuple[int]

View File

@ -1083,7 +1083,10 @@ impl<'db> Type<'db> {
}
/// If this type is a class instance, returns its specialization.
pub(crate) fn class_specialization(self, db: &'db dyn Db) -> Option<Specialization<'db>> {
pub(crate) fn class_specialization(
self,
db: &'db dyn Db,
) -> Option<(ClassLiteral<'db>, Specialization<'db>)> {
self.specialization_of_optional(db, None)
}
@ -1094,15 +1097,17 @@ impl<'db> Type<'db> {
expected_class: ClassLiteral<'_>,
) -> Option<Specialization<'db>> {
self.specialization_of_optional(db, Some(expected_class))
.map(|(_, specialization)| specialization)
}
fn specialization_of_optional(
self,
db: &'db dyn Db,
expected_class: Option<ClassLiteral<'_>>,
) -> Option<Specialization<'db>> {
) -> Option<(ClassLiteral<'db>, Specialization<'db>)> {
let class_type = match self {
Type::NominalInstance(instance) => instance,
Type::ProtocolInstance(instance) => instance.to_nominal_instance()?,
Type::TypeAlias(alias) => alias.value_type(db).as_nominal_instance()?,
_ => return None,
}
@ -1113,7 +1118,49 @@ impl<'db> Type<'db> {
return None;
}
specialization
Some((class_literal, specialization?))
}
/// Given a type variable `T` from the generic context of a class `C`:
/// - If `self` is a specialized instance of `C`, returns the type assigned to `T` on `self`.
/// - If `self` is a specialized instance of some class `A[T]`, and `C[T]` is a subclass of
/// `A[T]`, returns the type assigned to `T` on `self`.
pub(crate) fn find_type_var_from(
self,
db: &'db dyn Db,
bound_typevar: BoundTypeVarInstance<'db>,
class: ClassLiteral<'db>,
) -> Option<Type<'db>> {
if let Some(specialization) = self.specialization_of(db, class) {
return specialization.get(db, bound_typevar);
}
// TODO: We should use the constraint solver here to determine the type mappings for more
// complex subtyping relationships, e.g., callables, protocols, or unions containing multiple
// generic elements.
for base in class.iter_mro(db, None).skip(1) {
let Some((base, Some(base_specialization))) =
base.into_class().map(|class| class.class_literal(db))
else {
continue;
};
if let Some(specialization) = self.specialization_of(db, base) {
for (base_typevar, base_ty) in base_specialization
.generic_context(db)
.variables(db)
.zip(base_specialization.types(db))
{
if *base_ty == Type::TypeVar(bound_typevar) {
return specialization.get(db, base_typevar);
}
}
return None;
}
}
None
}
/// Returns the top materialization (or upper bound materialization) of this type, which is the
@ -4035,23 +4082,26 @@ impl<'db> Type<'db> {
return;
};
let tcx_specialization = tcx.annotation.and_then(|tcx| {
tcx.filter_union(db, |ty| ty.specialization_of(db, class_literal).is_some())
.specialization_of(db, class_literal)
});
for (typevar, ty) in specialization
for (type_var, ty) in specialization
.generic_context(db)
.variables(db)
.zip(specialization.types(db))
{
let variance = typevar.variance_with_polarity(db, polarity);
let tcx = TypeContext::new(tcx_specialization.and_then(|spec| spec.get(db, typevar)));
let variance = type_var.variance_with_polarity(db, polarity);
let narrowed_tcx = tcx.and_then(|annotation| match annotation {
Type::Union(union) => union
.elements(db)
.iter()
.filter_map(|ty| ty.find_type_var_from(db, type_var, class_literal))
.exactly_one()
.ok(),
_ => annotation.find_type_var_from(db, type_var, class_literal),
});
f(typevar, *ty, variance, tcx);
f(type_var, *ty, variance, narrowed_tcx);
visitor.visit(*ty, || {
ty.visit_specialization_impl(db, tcx, variance, f, visitor);
ty.visit_specialization_impl(db, narrowed_tcx, variance, f, visitor);
});
}
}
@ -6249,30 +6299,35 @@ impl<'db> Type<'db> {
}
Some(KnownClass::Tuple) => {
let object = Type::object();
let element_ty =
BoundTypeVarInstance::synthetic(db, "T", TypeVarVariance::Covariant);
// ```py
// class tuple:
// class tuple(Sequence[_T_co]):
// @overload
// def __new__(cls) -> tuple[()]: ...
// @overload
// def __new__(cls, iterable: Iterable[object]) -> tuple[object, ...]: ...
// def __new__(cls, iterable: Iterable[_T_co]) -> tuple[_T_co, ...]: ...
// ```
CallableBinding::from_overloads(
self,
[
Signature::new(Parameters::empty(), Some(Type::empty_tuple(db))),
Signature::new(
Signature::new_generic(
Some(GenericContext::from_typevar_instances(db, [element_ty])),
Parameters::new(
db,
[Parameter::positional_only(Some(Name::new_static(
"iterable",
)))
.with_annotated_type(
KnownClass::Iterable.to_specialized_instance(db, [object]),
KnownClass::Iterable.to_specialized_instance(
db,
[Type::TypeVar(element_ty)],
),
)],
),
Some(Type::homogeneous_tuple(db, object)),
Some(Type::homogeneous_tuple(db, Type::TypeVar(element_ty))),
),
],
)

View File

@ -321,7 +321,7 @@ impl<'db> BoundSuperType<'db> {
Type::NominalInstance(instance) => SuperOwnerKind::Instance(instance),
Type::ProtocolInstance(protocol) => {
if let Some(nominal_instance) = protocol.as_nominal_type() {
if let Some(nominal_instance) = protocol.to_nominal_instance() {
SuperOwnerKind::Instance(nominal_instance)
} else {
return Err(BoundSuperError::AbstractOwnerType {

View File

@ -3001,10 +3001,40 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
// Prefer the declared type of generic classes.
let preferred_type_mappings = return_with_tcx.and_then(|(return_ty, tcx)| {
tcx.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some())
let (tcx_class, tcx_specialization) = tcx
.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some())
.class_specialization(self.db)?;
builder.infer(return_ty, tcx).ok()?;
let Some((return_class, return_specialization)) = return_ty
.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some())
.class_specialization(self.db)
else {
builder.infer(return_ty, tcx).ok()?;
return Some(builder.type_mappings().clone());
};
// TODO: We should use the constraint solver here to determine the type mappings for more
// complex subtyping relationships, e.g., callables, protocols, or unions containing multiple
// generic elements.
for base in return_class.iter_mro(self.db, Some(return_specialization)) {
let Some((base_class, Some(base_specialization))) =
base.into_class().map(|class| class.class_literal(self.db))
else {
continue;
};
if base_class == tcx_class {
for (base_ty, tcx_ty) in std::iter::zip(
base_specialization.types(self.db),
tcx_specialization.types(self.db),
) {
builder.infer(*base_ty, *tcx_ty).ok()?;
}
break;
}
}
Some(builder.type_mappings().clone())
});

View File

@ -381,6 +381,12 @@ impl<'db> TypeContext<'db> {
}
}
pub(crate) fn and_then(self, f: impl FnOnce(Type<'db>) -> Option<Type<'db>>) -> Self {
Self {
annotation: self.annotation.and_then(f),
}
}
pub(crate) fn is_typealias(&self) -> bool {
self.annotation
.is_some_and(|ty| ty.is_typealias_special_form())

View File

@ -7832,41 +7832,24 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
annotation.filter_disjoint_elements(self.db(), collection_ty, inferable)
});
// Extract the annotated type of `T`, if provided.
let annotated_elt_tys = tcx
.known_specialization(self.db(), collection_class)
.map(|specialization| specialization.types(self.db()));
// Create a set of constraints to infer a precise type for `T`.
let mut builder = SpecializationBuilder::new(self.db(), inferable);
match annotated_elt_tys {
// The annotated type acts as a constraint for `T`.
//
// Note that we infer the annotated type _before_ the elements, to more closely match the
// order of any unions as written in the type annotation.
Some(annotated_elt_tys) => {
for (elt_ty, annotated_elt_ty) in iter::zip(elt_tys.clone(), annotated_elt_tys) {
builder
.infer(Type::TypeVar(elt_ty), *annotated_elt_ty)
.ok()?;
}
}
for elt_ty in elt_tys.clone() {
let elt_tcx = tcx
.annotation
// The annotated type acts as a constraint for `T`.
//
// Note that we infer the annotated type _before_ the elements, to more closely match the
// order of any unions as written in the type annotation.
.and_then(|tcx| tcx.find_type_var_from(self.db(), elt_ty, class_literal))
// If a valid type annotation was not provided, avoid restricting the type of the collection
// by unioning the inferred type with `Unknown`.
.unwrap_or(Type::unknown());
// If a valid type annotation was not provided, avoid restricting the type of the collection
// by unioning the inferred type with `Unknown`.
None => {
for elt_ty in elt_tys.clone() {
builder.infer(Type::TypeVar(elt_ty), Type::unknown()).ok()?;
}
}
builder.infer(Type::TypeVar(elt_ty), elt_tcx).ok()?;
}
let elt_tcxs = match annotated_elt_tys {
None => Either::Left(iter::repeat(TypeContext::default())),
Some(tys) => Either::Right(tys.iter().map(|ty| TypeContext::new(Some(*ty)))),
};
for elts in elts {
// An unpacking expression for a dictionary.
if let &[None, Some(value)] = elts.as_slice() {
@ -7889,10 +7872,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
// The inferred type of each element acts as an additional constraint on `T`.
for (elt, elt_ty, elt_tcx) in itertools::izip!(elts, elt_tys.clone(), elt_tcxs.clone())
{
for (elt, elt_ty) in iter::zip(elts, elt_tys.clone()) {
let Some(elt) = elt else { continue };
let elt_tcx =
tcx.and_then(|tcx| tcx.find_type_var_from(self.db(), elt_ty, class_literal));
let inferred_elt_ty = infer_elt_expression(self, elt, elt_tcx);
// Simplify the inference based on the declared type of the element.
@ -8392,7 +8376,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
call_expression: &ast::ExprCall,
tcx: TypeContext<'db>,
) -> Type<'db> {
// TODO: Use the type context for more precise inference.
let callable_type =
self.infer_maybe_standalone_expression(&call_expression.func, TypeContext::default());

View File

@ -165,7 +165,7 @@ impl<'db> Type<'db> {
// This matches the behaviour of other type checkers, and is required for us to
// recognise `str` as a subtype of `Container[str]`.
structurally_satisfied.or(db, || {
let Some(nominal_instance) = protocol.as_nominal_type() else {
let Some(nominal_instance) = protocol.to_nominal_instance() else {
return ConstraintSet::from(false);
};
@ -175,7 +175,7 @@ impl<'db> Type<'db> {
// `Q`'s members in a Liskov-incompatible way.
let type_to_test = self
.as_protocol_instance()
.and_then(ProtocolInstanceType::as_nominal_type)
.and_then(ProtocolInstanceType::to_nominal_instance)
.map(Type::NominalInstance)
.unwrap_or(self);
@ -658,7 +658,7 @@ impl<'db> ProtocolInstanceType<'db> {
/// If this is a synthesized protocol that does not correspond to a class definition
/// in source code, return `None`. These are "pure" abstract types, that cannot be
/// treated in a nominal way.
pub(super) fn as_nominal_type(self) -> Option<NominalInstanceType<'db>> {
pub(super) fn to_nominal_instance(self) -> Option<NominalInstanceType<'db>> {
match self.inner {
Protocol::FromClass(class) => {
Some(NominalInstanceType(NominalInstanceInner::NonTuple(*class)))