This commit is contained in:
Douglas Creager 2025-12-16 16:37:01 -05:00 committed by GitHub
commit 3603efbc74
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 139 additions and 103 deletions

View File

@ -594,7 +594,8 @@ x6: Covariant[Any] = covariant(1)
x7: Contravariant[Any] = contravariant(1) x7: Contravariant[Any] = contravariant(1)
x8: Invariant[Any] = invariant(1) x8: Invariant[Any] = invariant(1)
reveal_type(x5) # revealed: Bivariant[Any] # TODO: revealed: Bivariant[Any]
reveal_type(x5) # revealed: Bivariant[Literal[1]]
reveal_type(x6) # revealed: Covariant[Any] reveal_type(x6) # revealed: Covariant[Any]
reveal_type(x7) # revealed: Contravariant[Any] reveal_type(x7) # revealed: Contravariant[Any]
reveal_type(x8) # revealed: Invariant[Any] reveal_type(x8) # revealed: Invariant[Any]

View File

@ -17,8 +17,7 @@ from datetime import time
t = time(12, 0, 0) t = time(12, 0, 0)
t = replace(t, minute=30) t = replace(t, minute=30)
# TODO: this should be `time`, once we support specialization of generic protocols reveal_type(t) # revealed: time
reveal_type(t) # revealed: Unknown
``` ```
## The `__replace__` protocol ## The `__replace__` protocol
@ -48,8 +47,7 @@ b = a.__replace__(x=3, y=4)
reveal_type(b) # revealed: Point reveal_type(b) # revealed: Point
b = replace(a, x=3, y=4) b = replace(a, x=3, y=4)
# TODO: this should be `Point`, once we support specialization of generic protocols reveal_type(b) # revealed: Point
reveal_type(b) # revealed: Unknown
``` ```
A call to `replace` does not require all keyword arguments: A call to `replace` does not require all keyword arguments:
@ -59,8 +57,7 @@ c = a.__replace__(y=4)
reveal_type(c) # revealed: Point reveal_type(c) # revealed: Point
d = replace(a, y=4) d = replace(a, y=4)
# TODO: this should be `Point`, once we support specialization of generic protocols reveal_type(d) # revealed: Point
reveal_type(d) # revealed: Unknown
``` ```
Invalid calls to `__replace__` or `replace` will raise an error: Invalid calls to `__replace__` or `replace` will raise an error:

View File

@ -526,7 +526,10 @@ def test_seq(x: Sequence[T]) -> Sequence[T]:
def func8(t1: tuple[complex, list[int]], t2: tuple[int, *tuple[str, ...]], t3: tuple[()]): def func8(t1: tuple[complex, list[int]], t2: tuple[int, *tuple[str, ...]], t3: tuple[()]):
reveal_type(test_seq(t1)) # revealed: Sequence[int | float | complex | list[int]] reveal_type(test_seq(t1)) # revealed: Sequence[int | float | complex | list[int]]
reveal_type(test_seq(t2)) # revealed: Sequence[int | str] reveal_type(test_seq(t2)) # revealed: Sequence[int | str]
reveal_type(test_seq(t3)) # revealed: Sequence[Never] # TODO: The return type here is wrong, because we end up creating a constraint (Never ≤ T),
# which we confuse with "T has no lower bound".
# TODO: revealed: Sequence[Never]
reveal_type(test_seq(t3)) # revealed: Sequence[Unknown]
``` ```
### `__init__` is itself generic ### `__init__` is itself generic
@ -561,7 +564,7 @@ from typing_extensions import overload, Generic, TypeVar
from ty_extensions import generic_context, into_callable from ty_extensions import generic_context, into_callable
T = TypeVar("T") T = TypeVar("T")
U = TypeVar("U") U = TypeVar("U", covariant=True)
class C(Generic[T]): class C(Generic[T]):
@overload @overload
@ -611,9 +614,9 @@ reveal_type(generic_context(D))
# revealed: ty_extensions.GenericContext[T@D, U@D] # revealed: ty_extensions.GenericContext[T@D, U@D]
reveal_type(generic_context(into_callable(D))) reveal_type(generic_context(into_callable(D)))
reveal_type(D("string")) # revealed: D[str, str] reveal_type(D("string")) # revealed: D[str, Literal["string"]]
reveal_type(D(1)) # revealed: D[str, int] reveal_type(D(1)) # revealed: D[str, Literal[1]]
reveal_type(D(1, "string")) # revealed: D[int, str] reveal_type(D(1, "string")) # revealed: D[int, Literal["string"]]
``` ```
### Synthesized methods with dataclasses ### Synthesized methods with dataclasses

View File

@ -89,13 +89,11 @@ def takes_in_protocol(x: CanIndex[T]) -> T:
def deep_list(x: list[str]) -> None: def deep_list(x: list[str]) -> None:
reveal_type(takes_in_list(x)) # revealed: list[str] reveal_type(takes_in_list(x)) # revealed: list[str]
# TODO: revealed: str reveal_type(takes_in_protocol(x)) # revealed: str
reveal_type(takes_in_protocol(x)) # revealed: Unknown
def deeper_list(x: list[set[str]]) -> None: def deeper_list(x: list[set[str]]) -> None:
reveal_type(takes_in_list(x)) # revealed: list[set[str]] reveal_type(takes_in_list(x)) # revealed: list[set[str]]
# TODO: revealed: set[str] reveal_type(takes_in_protocol(x)) # revealed: set[str]
reveal_type(takes_in_protocol(x)) # revealed: Unknown
def deep_explicit(x: ExplicitlyImplements[str]) -> None: def deep_explicit(x: ExplicitlyImplements[str]) -> None:
reveal_type(takes_in_protocol(x)) # revealed: str reveal_type(takes_in_protocol(x)) # revealed: str
@ -116,12 +114,10 @@ class Sub(list[int]): ...
class GenericSub(list[T]): ... class GenericSub(list[T]): ...
reveal_type(takes_in_list(Sub())) # revealed: list[int] reveal_type(takes_in_list(Sub())) # revealed: list[int]
# TODO: revealed: int reveal_type(takes_in_protocol(Sub())) # revealed: int
reveal_type(takes_in_protocol(Sub())) # revealed: Unknown
reveal_type(takes_in_list(GenericSub[str]())) # revealed: list[str] reveal_type(takes_in_list(GenericSub[str]())) # revealed: list[str]
# TODO: revealed: str reveal_type(takes_in_protocol(GenericSub[str]())) # revealed: str
reveal_type(takes_in_protocol(GenericSub[str]())) # revealed: Unknown
class ExplicitSub(ExplicitlyImplements[int]): ... class ExplicitSub(ExplicitlyImplements[int]): ...
class ExplicitGenericSub(ExplicitlyImplements[T]): ... class ExplicitGenericSub(ExplicitlyImplements[T]): ...
@ -409,6 +405,10 @@ reveal_type(extract_t(Q[str]())) # revealed: str
Passing anything else results in an error: Passing anything else results in an error:
```py ```py
# TODO: We currently get an error for the specialization failure, and then another because the
# argument is not assignable to the (default-specialized) parameter annotation. We really only need
# one of them.
# error: [invalid-argument-type]
# error: [invalid-argument-type] # error: [invalid-argument-type]
reveal_type(extract_t([1, 2])) # revealed: Unknown reveal_type(extract_t([1, 2])) # revealed: Unknown
``` ```
@ -470,6 +470,10 @@ reveal_type(extract_optional_t(P[int]())) # revealed: int
Passing anything else results in an error: Passing anything else results in an error:
```py ```py
# TODO: We currently get an error for the specialization failure, and then another because the
# argument is not assignable to the (default-specialized) parameter annotation. We really only need
# one of them.
# error: [invalid-argument-type]
# error: [invalid-argument-type] # error: [invalid-argument-type]
reveal_type(extract_optional_t(Q[str]())) # revealed: Unknown reveal_type(extract_optional_t(Q[str]())) # revealed: Unknown
``` ```

View File

@ -464,7 +464,10 @@ def test_seq[T](x: Sequence[T]) -> Sequence[T]:
def func8(t1: tuple[complex, list[int]], t2: tuple[int, *tuple[str, ...]], t3: tuple[()]): def func8(t1: tuple[complex, list[int]], t2: tuple[int, *tuple[str, ...]], t3: tuple[()]):
reveal_type(test_seq(t1)) # revealed: Sequence[int | float | complex | list[int]] reveal_type(test_seq(t1)) # revealed: Sequence[int | float | complex | list[int]]
reveal_type(test_seq(t2)) # revealed: Sequence[int | str] reveal_type(test_seq(t2)) # revealed: Sequence[int | str]
reveal_type(test_seq(t3)) # revealed: Sequence[Never] # TODO: The return type here is wrong, because we end up creating a constraint (Never ≤ T),
# which we confuse with "T has no lower bound".
# TODO: revealed: Sequence[Never]
reveal_type(test_seq(t3)) # revealed: Sequence[Unknown]
``` ```
### `__init__` is itself generic ### `__init__` is itself generic
@ -538,6 +541,10 @@ C[None](b"bytes") # error: [no-matching-overload]
C[None](12) C[None](12)
class D[T, U]: class D[T, U]:
# we need to use the type variable or else the class is bivariant in T, and
# specializations become meaningless
x: T
@overload @overload
def __init__(self: "D[str, U]", u: U) -> None: ... def __init__(self: "D[str, U]", u: U) -> None: ...
@overload @overload
@ -551,7 +558,7 @@ reveal_type(generic_context(into_callable(D)))
reveal_type(D("string")) # revealed: D[str, Literal["string"]] reveal_type(D("string")) # revealed: D[str, Literal["string"]]
reveal_type(D(1)) # revealed: D[str, Literal[1]] reveal_type(D(1)) # revealed: D[str, Literal[1]]
reveal_type(D(1, "string")) # revealed: D[Literal[1], Literal["string"]] reveal_type(D(1, "string")) # revealed: D[int, Literal["string"]]
``` ```
### Synthesized methods with dataclasses ### Synthesized methods with dataclasses

View File

@ -84,13 +84,11 @@ def takes_in_protocol[T](x: CanIndex[T]) -> T:
def deep_list(x: list[str]) -> None: def deep_list(x: list[str]) -> None:
reveal_type(takes_in_list(x)) # revealed: list[str] reveal_type(takes_in_list(x)) # revealed: list[str]
# TODO: revealed: str reveal_type(takes_in_protocol(x)) # revealed: str
reveal_type(takes_in_protocol(x)) # revealed: Unknown
def deeper_list(x: list[set[str]]) -> None: def deeper_list(x: list[set[str]]) -> None:
reveal_type(takes_in_list(x)) # revealed: list[set[str]] reveal_type(takes_in_list(x)) # revealed: list[set[str]]
# TODO: revealed: set[str] reveal_type(takes_in_protocol(x)) # revealed: set[str]
reveal_type(takes_in_protocol(x)) # revealed: Unknown
def deep_explicit(x: ExplicitlyImplements[str]) -> None: def deep_explicit(x: ExplicitlyImplements[str]) -> None:
reveal_type(takes_in_protocol(x)) # revealed: str reveal_type(takes_in_protocol(x)) # revealed: str
@ -111,12 +109,10 @@ class Sub(list[int]): ...
class GenericSub[T](list[T]): ... class GenericSub[T](list[T]): ...
reveal_type(takes_in_list(Sub())) # revealed: list[int] reveal_type(takes_in_list(Sub())) # revealed: list[int]
# TODO: revealed: int reveal_type(takes_in_protocol(Sub())) # revealed: int
reveal_type(takes_in_protocol(Sub())) # revealed: Unknown
reveal_type(takes_in_list(GenericSub[str]())) # revealed: list[str] reveal_type(takes_in_list(GenericSub[str]())) # revealed: list[str]
# TODO: revealed: str reveal_type(takes_in_protocol(GenericSub[str]())) # revealed: str
reveal_type(takes_in_protocol(GenericSub[str]())) # revealed: Unknown
class ExplicitSub(ExplicitlyImplements[int]): ... class ExplicitSub(ExplicitlyImplements[int]): ...
class ExplicitGenericSub[T](ExplicitlyImplements[T]): ... class ExplicitGenericSub[T](ExplicitlyImplements[T]): ...
@ -362,6 +358,10 @@ reveal_type(extract_t(Q[str]())) # revealed: str
Passing anything else results in an error: Passing anything else results in an error:
```py ```py
# TODO: We currently get an error for the specialization failure, and then another because the
# argument is not assignable to the (default-specialized) parameter annotation. We really only need
# one of them.
# error: [invalid-argument-type]
# error: [invalid-argument-type] # error: [invalid-argument-type]
reveal_type(extract_t([1, 2])) # revealed: Unknown reveal_type(extract_t([1, 2])) # revealed: Unknown
``` ```
@ -421,6 +421,10 @@ reveal_type(extract_optional_t(P[int]())) # revealed: int
Passing anything else results in an error: Passing anything else results in an error:
```py ```py
# TODO: We currently get an error for the specialization failure, and then another because the
# argument is not assignable to the (default-specialized) parameter annotation. We really only need
# one of them.
# error: [invalid-argument-type]
# error: [invalid-argument-type] # error: [invalid-argument-type]
reveal_type(extract_optional_t(Q[str]())) # revealed: Unknown reveal_type(extract_optional_t(Q[str]())) # revealed: Unknown
``` ```

View File

@ -4445,7 +4445,6 @@ impl<'db> BindingError<'db> {
return; return;
}; };
let typevar = error.bound_typevar().typevar(context.db());
let argument_type = error.argument_type(); let argument_type = error.argument_type();
let argument_ty_display = argument_type.display(context.db()); let argument_ty_display = argument_type.display(context.db());
@ -4458,21 +4457,51 @@ impl<'db> BindingError<'db> {
} }
)); ));
let typevar_name = typevar.name(context.db());
match error { match error {
SpecializationError::MismatchedBound { .. } => { SpecializationError::NoSolution { parameter, .. } => {
diag.set_primary_message(format_args!("Argument type `{argument_ty_display}` does not satisfy upper bound `{}` of type variable `{typevar_name}`", diag.set_primary_message(format_args!(
typevar.upper_bound(context.db()).expect("type variable should have an upper bound if this error occurs").display(context.db()) "Argument type `{argument_ty_display}` does not \
satisfy generic parameter annotation `{}",
parameter.display(context.db()),
)); ));
} }
SpecializationError::MismatchedConstraint { .. } => { SpecializationError::MismatchedBound { bound_typevar, .. } => {
diag.set_primary_message(format_args!("Argument type `{argument_ty_display}` does not satisfy constraints ({}) of type variable `{typevar_name}`", let typevar = bound_typevar.typevar(context.db());
typevar.constraints(context.db()).expect("type variable should have constraints if this error occurs").iter().map(|ty| format!("`{}`", ty.display(context.db()))).join(", ") let typevar_name = typevar.name(context.db());
diag.set_primary_message(format_args!(
"Argument type `{argument_ty_display}` does not \
satisfy upper bound `{}` of type variable `{typevar_name}`",
typevar
.upper_bound(context.db())
.expect(
"type variable should have an upper bound if this error occurs"
)
.display(context.db())
));
}
SpecializationError::MismatchedConstraint { bound_typevar, .. } => {
let typevar = bound_typevar.typevar(context.db());
let typevar_name = typevar.name(context.db());
diag.set_primary_message(format_args!(
"Argument type `{argument_ty_display}` does not \
satisfy constraints ({}) of type variable `{typevar_name}`",
typevar
.constraints(context.db())
.expect(
"type variable should have constraints if this error occurs"
)
.iter()
.format_with(", ", |ty, f| f(&format_args!(
"`{}`",
ty.display(context.db())
)))
)); ));
} }
} }
if let Some(typevar_definition) = typevar.definition(context.db()) { if let Some(typevar_definition) = error.bound_typevar().and_then(|bound_typevar| {
bound_typevar.typevar(context.db()).definition(context.db())
}) {
let module = parsed_module(context.db(), typevar_definition.file(context.db())) let module = parsed_module(context.db(), typevar_definition.file(context.db()))
.load(context.db()); .load(context.db());
let typevar_range = typevar_definition.full_range(context.db(), &module); let typevar_range = typevar_definition.full_range(context.db(), &module);

View File

@ -72,6 +72,7 @@ use std::fmt::Display;
use std::ops::Range; use std::ops::Range;
use itertools::Itertools; use itertools::Itertools;
use ordermap::map::Entry;
use rustc_hash::{FxHashMap, FxHashSet}; use rustc_hash::{FxHashMap, FxHashSet};
use salsa::plumbing::AsId; use salsa::plumbing::AsId;
@ -3415,9 +3416,11 @@ impl<'db> PathAssignments<'db> {
); );
return Err(PathAssignmentConflict); return Err(PathAssignmentConflict);
} }
if self.assignments.insert(assignment, source_order).is_some() {
return Ok(()); match self.assignments.entry(assignment) {
} Entry::Vacant(entry) => entry.insert(source_order),
Entry::Occupied(_) => return Ok(()),
};
// Then use our sequents to add additional facts that we know to be true. We currently // Then use our sequents to add additional facts that we know to be true. We currently
// reuse the `source_order` of the "real" constraint passed into `walk_edge` when we add // reuse the `source_order` of the "real" constraint passed into `walk_edge` when we add

View File

@ -9,10 +9,7 @@ use rustc_hash::{FxHashMap, FxHashSet};
use crate::semantic_index::definition::Definition; use crate::semantic_index::definition::Definition;
use crate::semantic_index::scope::{FileScopeId, NodeWithScopeKind, ScopeId}; use crate::semantic_index::scope::{FileScopeId, NodeWithScopeKind, ScopeId};
use crate::semantic_index::{SemanticIndex, semantic_index}; use crate::semantic_index::{SemanticIndex, semantic_index};
use crate::types::class::ClassType;
use crate::types::class_base::ClassBase;
use crate::types::constraints::{ConstraintSet, IteratorConstraintsExtension}; use crate::types::constraints::{ConstraintSet, IteratorConstraintsExtension};
use crate::types::instance::{Protocol, ProtocolInstanceType};
use crate::types::signatures::Parameters; use crate::types::signatures::Parameters;
use crate::types::tuple::{TupleSpec, TupleType, walk_tuple_type}; use crate::types::tuple::{TupleSpec, TupleType, walk_tuple_type};
use crate::types::variance::VarianceInferable; use crate::types::variance::VarianceInferable;
@ -1590,7 +1587,6 @@ impl<'db> SpecializationBuilder<'db> {
upper: Vec<Type<'db>>, upper: Vec<Type<'db>>,
} }
let constraints = constraints.limit_to_valid_specializations(self.db);
let mut sorted_paths = Vec::new(); let mut sorted_paths = Vec::new();
constraints.for_each_path(self.db, |path| { constraints.for_each_path(self.db, |path| {
let mut path: Vec<_> = path.positive_constraints().collect(); let mut path: Vec<_> = path.positive_constraints().collect();
@ -1887,49 +1883,18 @@ impl<'db> SpecializationBuilder<'db> {
return Ok(()); return Ok(());
} }
// Extract formal_alias if this is a generic class let when = actual
let formal_alias = match formal { .when_constraint_set_assignable_to(self.db, formal, self.inferable)
Type::NominalInstance(formal_nominal) => { .limit_to_valid_specializations(self.db);
formal_nominal.class(self.db).into_generic_alias() if when.is_never_satisfied(self.db)
} && (formal.has_typevar(self.db) || actual.has_typevar(self.db))
// TODO: This will only handle classes that explicit implement a generic protocol {
// by listing it as a base class. To handle classes that implicitly implement a return Err(SpecializationError::NoSolution {
// generic protocol, we will need to check the types of the protocol members to be parameter: formal,
// able to infer the specialization of the protocol that the class implements. argument: actual,
Type::ProtocolInstance(ProtocolInstanceType { });
inner: Protocol::FromClass(class),
..
}) => class.into_generic_alias(),
_ => None,
};
if let Some(formal_alias) = formal_alias {
let formal_origin = formal_alias.origin(self.db);
for base in actual_nominal.class(self.db).iter_mro(self.db) {
let ClassBase::Class(ClassType::Generic(base_alias)) = base else {
continue;
};
if formal_origin != base_alias.origin(self.db) {
continue;
}
let generic_context = formal_alias
.specialization(self.db)
.generic_context(self.db)
.variables(self.db);
let formal_specialization =
formal_alias.specialization(self.db).types(self.db);
let base_specialization = base_alias.specialization(self.db).types(self.db);
for (typevar, formal_ty, base_ty) in itertools::izip!(
generic_context,
formal_specialization,
base_specialization
) {
let variance = typevar.variance_with_polarity(self.db, polarity);
self.infer_map_impl(*formal_ty, *base_ty, variance, &mut f)?;
}
return Ok(());
}
} }
self.add_type_mappings_from_constraint_set(formal, when, &mut f);
} }
(Type::Callable(formal_callable), _) => { (Type::Callable(formal_callable), _) => {
@ -1948,7 +1913,8 @@ impl<'db> SpecializationBuilder<'db> {
self.db, self.db,
formal_callable, formal_callable,
self.inferable, self.inferable,
); )
.limit_to_valid_specializations(self.db);
self.add_type_mappings_from_constraint_set(formal, when, &mut f); self.add_type_mappings_from_constraint_set(formal, when, &mut f);
} else { } else {
for actual_signature in &actual_callable.signatures(self.db).overloads { for actual_signature in &actual_callable.signatures(self.db).overloads {
@ -1957,7 +1923,8 @@ impl<'db> SpecializationBuilder<'db> {
self.db, self.db,
formal_callable, formal_callable,
self.inferable, self.inferable,
); )
.limit_to_valid_specializations(self.db);
self.add_type_mappings_from_constraint_set(formal, when, &mut f); self.add_type_mappings_from_constraint_set(formal, when, &mut f);
} }
} }
@ -1974,6 +1941,10 @@ impl<'db> SpecializationBuilder<'db> {
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum SpecializationError<'db> { pub(crate) enum SpecializationError<'db> {
NoSolution {
parameter: Type<'db>,
argument: Type<'db>,
},
MismatchedBound { MismatchedBound {
bound_typevar: BoundTypeVarInstance<'db>, bound_typevar: BoundTypeVarInstance<'db>,
argument: Type<'db>, argument: Type<'db>,
@ -1985,15 +1956,17 @@ pub(crate) enum SpecializationError<'db> {
} }
impl<'db> SpecializationError<'db> { impl<'db> SpecializationError<'db> {
pub(crate) fn bound_typevar(&self) -> BoundTypeVarInstance<'db> { pub(crate) fn bound_typevar(&self) -> Option<BoundTypeVarInstance<'db>> {
match self { match self {
Self::MismatchedBound { bound_typevar, .. } => *bound_typevar, Self::NoSolution { .. } => None,
Self::MismatchedConstraint { bound_typevar, .. } => *bound_typevar, Self::MismatchedBound { bound_typevar, .. } => Some(*bound_typevar),
Self::MismatchedConstraint { bound_typevar, .. } => Some(*bound_typevar),
} }
} }
pub(crate) fn argument_type(&self) -> Type<'db> { pub(crate) fn argument_type(&self) -> Type<'db> {
match self { match self {
Self::NoSolution { argument, .. } => *argument,
Self::MismatchedBound { argument, .. } => *argument, Self::MismatchedBound { argument, .. } => *argument,
Self::MismatchedConstraint { argument, .. } => *argument, Self::MismatchedConstraint { argument, .. } => *argument,
} }

View File

@ -133,6 +133,20 @@ impl<'db> Type<'db> {
disjointness_visitor: &IsDisjointVisitor<'db>, disjointness_visitor: &IsDisjointVisitor<'db>,
) -> ConstraintSet<'db> { ) -> ConstraintSet<'db> {
let structurally_satisfied = if let Type::ProtocolInstance(self_protocol) = self { let structurally_satisfied = if let Type::ProtocolInstance(self_protocol) = self {
let self_as_nominal = self_protocol.as_nominal_type();
let other_as_nominal = protocol.as_nominal_type();
let nominal_match = match self_as_nominal.zip(other_as_nominal) {
Some((self_as_nominal, other_as_nominal)) => self_as_nominal.has_relation_to_impl(
db,
other_as_nominal,
inferable,
relation,
relation_visitor,
disjointness_visitor,
),
_ => ConstraintSet::from(false),
};
nominal_match.or(db, || {
self_protocol.interface(db).has_relation_to_impl( self_protocol.interface(db).has_relation_to_impl(
db, db,
protocol.interface(db), protocol.interface(db),
@ -141,6 +155,7 @@ impl<'db> Type<'db> {
relation_visitor, relation_visitor,
disjointness_visitor, disjointness_visitor,
) )
})
} else { } else {
protocol protocol
.inner .inner