diff --git a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md index 36f53afe4d..826a336526 100644 --- a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md +++ b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md @@ -594,7 +594,8 @@ x6: Covariant[Any] = covariant(1) x7: Contravariant[Any] = contravariant(1) x8: Invariant[Any] = invariant(1) -reveal_type(x5) # revealed: Bivariant[Any] +# TODO: revealed: Bivariant[Any] +reveal_type(x5) # revealed: Bivariant[Literal[1]] reveal_type(x6) # revealed: Covariant[Any] reveal_type(x7) # revealed: Contravariant[Any] reveal_type(x8) # revealed: Invariant[Any] diff --git a/crates/ty_python_semantic/resources/mdtest/call/replace.md b/crates/ty_python_semantic/resources/mdtest/call/replace.md index 8d5a6b55bd..3e9a0be744 100644 --- a/crates/ty_python_semantic/resources/mdtest/call/replace.md +++ b/crates/ty_python_semantic/resources/mdtest/call/replace.md @@ -17,8 +17,7 @@ from datetime import time t = time(12, 0, 0) t = replace(t, minute=30) -# TODO: this should be `time`, once we support specialization of generic protocols -reveal_type(t) # revealed: Unknown +reveal_type(t) # revealed: time ``` ## The `__replace__` protocol @@ -48,8 +47,7 @@ b = a.__replace__(x=3, y=4) reveal_type(b) # revealed: Point b = replace(a, x=3, y=4) -# TODO: this should be `Point`, once we support specialization of generic protocols -reveal_type(b) # revealed: Unknown +reveal_type(b) # revealed: Point ``` A call to `replace` does not require all keyword arguments: @@ -59,8 +57,7 @@ c = a.__replace__(y=4) reveal_type(c) # revealed: Point d = replace(a, y=4) -# TODO: this should be `Point`, once we support specialization of generic protocols -reveal_type(d) # revealed: Unknown +reveal_type(d) # revealed: Point ``` Invalid calls to `__replace__` or `replace` will raise an error: diff --git a/crates/ty_python_semantic/resources/mdtest/generics/legacy/classes.md b/crates/ty_python_semantic/resources/mdtest/generics/legacy/classes.md index 30d6a89ec0..f78bff09ca 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/legacy/classes.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/legacy/classes.md @@ -526,7 +526,10 @@ def test_seq(x: Sequence[T]) -> Sequence[T]: def func8(t1: tuple[complex, list[int]], t2: tuple[int, *tuple[str, ...]], t3: tuple[()]): reveal_type(test_seq(t1)) # revealed: Sequence[int | float | complex | list[int]] reveal_type(test_seq(t2)) # revealed: Sequence[int | str] - reveal_type(test_seq(t3)) # revealed: Sequence[Never] + # TODO: The return type here is wrong, because we end up creating a constraint (Never ≤ T), + # which we confuse with "T has no lower bound". + # TODO: revealed: Sequence[Never] + reveal_type(test_seq(t3)) # revealed: Sequence[Unknown] ``` ### `__init__` is itself generic @@ -561,7 +564,7 @@ from typing_extensions import overload, Generic, TypeVar from ty_extensions import generic_context, into_callable T = TypeVar("T") -U = TypeVar("U") +U = TypeVar("U", covariant=True) class C(Generic[T]): @overload @@ -611,9 +614,9 @@ reveal_type(generic_context(D)) # revealed: ty_extensions.GenericContext[T@D, U@D] reveal_type(generic_context(into_callable(D))) -reveal_type(D("string")) # revealed: D[str, str] -reveal_type(D(1)) # revealed: D[str, int] -reveal_type(D(1, "string")) # revealed: D[int, str] +reveal_type(D("string")) # revealed: D[str, Literal["string"]] +reveal_type(D(1)) # revealed: D[str, Literal[1]] +reveal_type(D(1, "string")) # revealed: D[int, Literal["string"]] ``` ### Synthesized methods with dataclasses diff --git a/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md b/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md index c674f7a9a1..d0a2d65e71 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md @@ -89,13 +89,11 @@ def takes_in_protocol(x: CanIndex[T]) -> T: def deep_list(x: list[str]) -> None: reveal_type(takes_in_list(x)) # revealed: list[str] - # TODO: revealed: str - reveal_type(takes_in_protocol(x)) # revealed: Unknown + reveal_type(takes_in_protocol(x)) # revealed: str def deeper_list(x: list[set[str]]) -> None: reveal_type(takes_in_list(x)) # revealed: list[set[str]] - # TODO: revealed: set[str] - reveal_type(takes_in_protocol(x)) # revealed: Unknown + reveal_type(takes_in_protocol(x)) # revealed: set[str] def deep_explicit(x: ExplicitlyImplements[str]) -> None: reveal_type(takes_in_protocol(x)) # revealed: str @@ -116,12 +114,10 @@ class Sub(list[int]): ... class GenericSub(list[T]): ... reveal_type(takes_in_list(Sub())) # revealed: list[int] -# TODO: revealed: int -reveal_type(takes_in_protocol(Sub())) # revealed: Unknown +reveal_type(takes_in_protocol(Sub())) # revealed: int reveal_type(takes_in_list(GenericSub[str]())) # revealed: list[str] -# TODO: revealed: str -reveal_type(takes_in_protocol(GenericSub[str]())) # revealed: Unknown +reveal_type(takes_in_protocol(GenericSub[str]())) # revealed: str class ExplicitSub(ExplicitlyImplements[int]): ... class ExplicitGenericSub(ExplicitlyImplements[T]): ... @@ -409,6 +405,10 @@ reveal_type(extract_t(Q[str]())) # revealed: str Passing anything else results in an error: ```py +# TODO: We currently get an error for the specialization failure, and then another because the +# argument is not assignable to the (default-specialized) parameter annotation. We really only need +# one of them. +# error: [invalid-argument-type] # error: [invalid-argument-type] reveal_type(extract_t([1, 2])) # revealed: Unknown ``` @@ -470,6 +470,10 @@ reveal_type(extract_optional_t(P[int]())) # revealed: int Passing anything else results in an error: ```py +# TODO: We currently get an error for the specialization failure, and then another because the +# argument is not assignable to the (default-specialized) parameter annotation. We really only need +# one of them. +# error: [invalid-argument-type] # error: [invalid-argument-type] reveal_type(extract_optional_t(Q[str]())) # revealed: Unknown ``` diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/classes.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/classes.md index 06680d2168..6acfa5b4d6 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/classes.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/classes.md @@ -464,7 +464,10 @@ def test_seq[T](x: Sequence[T]) -> Sequence[T]: def func8(t1: tuple[complex, list[int]], t2: tuple[int, *tuple[str, ...]], t3: tuple[()]): reveal_type(test_seq(t1)) # revealed: Sequence[int | float | complex | list[int]] reveal_type(test_seq(t2)) # revealed: Sequence[int | str] - reveal_type(test_seq(t3)) # revealed: Sequence[Never] + # TODO: The return type here is wrong, because we end up creating a constraint (Never ≤ T), + # which we confuse with "T has no lower bound". + # TODO: revealed: Sequence[Never] + reveal_type(test_seq(t3)) # revealed: Sequence[Unknown] ``` ### `__init__` is itself generic @@ -538,6 +541,10 @@ C[None](b"bytes") # error: [no-matching-overload] C[None](12) class D[T, U]: + # we need to use the type variable or else the class is bivariant in T, and + # specializations become meaningless + x: T + @overload def __init__(self: "D[str, U]", u: U) -> None: ... @overload @@ -551,7 +558,7 @@ reveal_type(generic_context(into_callable(D))) reveal_type(D("string")) # revealed: D[str, Literal["string"]] reveal_type(D(1)) # revealed: D[str, Literal[1]] -reveal_type(D(1, "string")) # revealed: D[Literal[1], Literal["string"]] +reveal_type(D(1, "string")) # revealed: D[int, Literal["string"]] ``` ### Synthesized methods with dataclasses diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md index 8121ce5d26..ffb046c580 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md @@ -84,13 +84,11 @@ def takes_in_protocol[T](x: CanIndex[T]) -> T: def deep_list(x: list[str]) -> None: reveal_type(takes_in_list(x)) # revealed: list[str] - # TODO: revealed: str - reveal_type(takes_in_protocol(x)) # revealed: Unknown + reveal_type(takes_in_protocol(x)) # revealed: str def deeper_list(x: list[set[str]]) -> None: reveal_type(takes_in_list(x)) # revealed: list[set[str]] - # TODO: revealed: set[str] - reveal_type(takes_in_protocol(x)) # revealed: Unknown + reveal_type(takes_in_protocol(x)) # revealed: set[str] def deep_explicit(x: ExplicitlyImplements[str]) -> None: reveal_type(takes_in_protocol(x)) # revealed: str @@ -111,12 +109,10 @@ class Sub(list[int]): ... class GenericSub[T](list[T]): ... reveal_type(takes_in_list(Sub())) # revealed: list[int] -# TODO: revealed: int -reveal_type(takes_in_protocol(Sub())) # revealed: Unknown +reveal_type(takes_in_protocol(Sub())) # revealed: int reveal_type(takes_in_list(GenericSub[str]())) # revealed: list[str] -# TODO: revealed: str -reveal_type(takes_in_protocol(GenericSub[str]())) # revealed: Unknown +reveal_type(takes_in_protocol(GenericSub[str]())) # revealed: str class ExplicitSub(ExplicitlyImplements[int]): ... class ExplicitGenericSub[T](ExplicitlyImplements[T]): ... @@ -362,6 +358,10 @@ reveal_type(extract_t(Q[str]())) # revealed: str Passing anything else results in an error: ```py +# TODO: We currently get an error for the specialization failure, and then another because the +# argument is not assignable to the (default-specialized) parameter annotation. We really only need +# one of them. +# error: [invalid-argument-type] # error: [invalid-argument-type] reveal_type(extract_t([1, 2])) # revealed: Unknown ``` @@ -421,6 +421,10 @@ reveal_type(extract_optional_t(P[int]())) # revealed: int Passing anything else results in an error: ```py +# TODO: We currently get an error for the specialization failure, and then another because the +# argument is not assignable to the (default-specialized) parameter annotation. We really only need +# one of them. +# error: [invalid-argument-type] # error: [invalid-argument-type] reveal_type(extract_optional_t(Q[str]())) # revealed: Unknown ``` diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 005013e70b..86ba312ba8 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -4445,7 +4445,6 @@ impl<'db> BindingError<'db> { return; }; - let typevar = error.bound_typevar().typevar(context.db()); let argument_type = error.argument_type(); let argument_ty_display = argument_type.display(context.db()); @@ -4458,21 +4457,51 @@ impl<'db> BindingError<'db> { } )); - let typevar_name = typevar.name(context.db()); match error { - SpecializationError::MismatchedBound { .. } => { - diag.set_primary_message(format_args!("Argument type `{argument_ty_display}` does not satisfy upper bound `{}` of type variable `{typevar_name}`", - typevar.upper_bound(context.db()).expect("type variable should have an upper bound if this error occurs").display(context.db()) - )); + SpecializationError::NoSolution { parameter, .. } => { + diag.set_primary_message(format_args!( + "Argument type `{argument_ty_display}` does not \ + satisfy generic parameter annotation `{}", + parameter.display(context.db()), + )); } - SpecializationError::MismatchedConstraint { .. } => { - diag.set_primary_message(format_args!("Argument type `{argument_ty_display}` does not satisfy constraints ({}) of type variable `{typevar_name}`", - typevar.constraints(context.db()).expect("type variable should have constraints if this error occurs").iter().map(|ty| format!("`{}`", ty.display(context.db()))).join(", ") - )); + SpecializationError::MismatchedBound { bound_typevar, .. } => { + let typevar = bound_typevar.typevar(context.db()); + let typevar_name = typevar.name(context.db()); + diag.set_primary_message(format_args!( + "Argument type `{argument_ty_display}` does not \ + satisfy upper bound `{}` of type variable `{typevar_name}`", + typevar + .upper_bound(context.db()) + .expect( + "type variable should have an upper bound if this error occurs" + ) + .display(context.db()) + )); + } + SpecializationError::MismatchedConstraint { bound_typevar, .. } => { + let typevar = bound_typevar.typevar(context.db()); + let typevar_name = typevar.name(context.db()); + diag.set_primary_message(format_args!( + "Argument type `{argument_ty_display}` does not \ + satisfy constraints ({}) of type variable `{typevar_name}`", + typevar + .constraints(context.db()) + .expect( + "type variable should have constraints if this error occurs" + ) + .iter() + .format_with(", ", |ty, f| f(&format_args!( + "`{}`", + ty.display(context.db()) + ))) + )); } } - if let Some(typevar_definition) = typevar.definition(context.db()) { + if let Some(typevar_definition) = error.bound_typevar().and_then(|bound_typevar| { + bound_typevar.typevar(context.db()).definition(context.db()) + }) { let module = parsed_module(context.db(), typevar_definition.file(context.db())) .load(context.db()); let typevar_range = typevar_definition.full_range(context.db(), &module); diff --git a/crates/ty_python_semantic/src/types/constraints.rs b/crates/ty_python_semantic/src/types/constraints.rs index 77b96bd74b..fa60820296 100644 --- a/crates/ty_python_semantic/src/types/constraints.rs +++ b/crates/ty_python_semantic/src/types/constraints.rs @@ -72,6 +72,7 @@ use std::fmt::Display; use std::ops::Range; use itertools::Itertools; +use ordermap::map::Entry; use rustc_hash::{FxHashMap, FxHashSet}; use salsa::plumbing::AsId; @@ -3415,9 +3416,11 @@ impl<'db> PathAssignments<'db> { ); return Err(PathAssignmentConflict); } - if self.assignments.insert(assignment, source_order).is_some() { - return Ok(()); - } + + match self.assignments.entry(assignment) { + Entry::Vacant(entry) => entry.insert(source_order), + Entry::Occupied(_) => return Ok(()), + }; // Then use our sequents to add additional facts that we know to be true. We currently // reuse the `source_order` of the "real" constraint passed into `walk_edge` when we add diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index c012ab09f6..90ab8ed824 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -9,10 +9,7 @@ use rustc_hash::{FxHashMap, FxHashSet}; use crate::semantic_index::definition::Definition; use crate::semantic_index::scope::{FileScopeId, NodeWithScopeKind, ScopeId}; use crate::semantic_index::{SemanticIndex, semantic_index}; -use crate::types::class::ClassType; -use crate::types::class_base::ClassBase; use crate::types::constraints::{ConstraintSet, IteratorConstraintsExtension}; -use crate::types::instance::{Protocol, ProtocolInstanceType}; use crate::types::signatures::Parameters; use crate::types::tuple::{TupleSpec, TupleType, walk_tuple_type}; use crate::types::variance::VarianceInferable; @@ -1590,7 +1587,6 @@ impl<'db> SpecializationBuilder<'db> { upper: Vec>, } - let constraints = constraints.limit_to_valid_specializations(self.db); let mut sorted_paths = Vec::new(); constraints.for_each_path(self.db, |path| { let mut path: Vec<_> = path.positive_constraints().collect(); @@ -1887,49 +1883,18 @@ impl<'db> SpecializationBuilder<'db> { return Ok(()); } - // Extract formal_alias if this is a generic class - let formal_alias = match formal { - Type::NominalInstance(formal_nominal) => { - formal_nominal.class(self.db).into_generic_alias() - } - // TODO: This will only handle classes that explicit implement a generic protocol - // by listing it as a base class. To handle classes that implicitly implement a - // generic protocol, we will need to check the types of the protocol members to be - // able to infer the specialization of the protocol that the class implements. - Type::ProtocolInstance(ProtocolInstanceType { - inner: Protocol::FromClass(class), - .. - }) => class.into_generic_alias(), - _ => None, - }; - - if let Some(formal_alias) = formal_alias { - let formal_origin = formal_alias.origin(self.db); - for base in actual_nominal.class(self.db).iter_mro(self.db) { - let ClassBase::Class(ClassType::Generic(base_alias)) = base else { - continue; - }; - if formal_origin != base_alias.origin(self.db) { - continue; - } - let generic_context = formal_alias - .specialization(self.db) - .generic_context(self.db) - .variables(self.db); - let formal_specialization = - formal_alias.specialization(self.db).types(self.db); - let base_specialization = base_alias.specialization(self.db).types(self.db); - for (typevar, formal_ty, base_ty) in itertools::izip!( - generic_context, - formal_specialization, - base_specialization - ) { - let variance = typevar.variance_with_polarity(self.db, polarity); - self.infer_map_impl(*formal_ty, *base_ty, variance, &mut f)?; - } - return Ok(()); - } + let when = actual + .when_constraint_set_assignable_to(self.db, formal, self.inferable) + .limit_to_valid_specializations(self.db); + if when.is_never_satisfied(self.db) + && (formal.has_typevar(self.db) || actual.has_typevar(self.db)) + { + return Err(SpecializationError::NoSolution { + parameter: formal, + argument: actual, + }); } + self.add_type_mappings_from_constraint_set(formal, when, &mut f); } (Type::Callable(formal_callable), _) => { @@ -1948,7 +1913,8 @@ impl<'db> SpecializationBuilder<'db> { self.db, formal_callable, self.inferable, - ); + ) + .limit_to_valid_specializations(self.db); self.add_type_mappings_from_constraint_set(formal, when, &mut f); } else { for actual_signature in &actual_callable.signatures(self.db).overloads { @@ -1957,7 +1923,8 @@ impl<'db> SpecializationBuilder<'db> { self.db, formal_callable, self.inferable, - ); + ) + .limit_to_valid_specializations(self.db); self.add_type_mappings_from_constraint_set(formal, when, &mut f); } } @@ -1974,6 +1941,10 @@ impl<'db> SpecializationBuilder<'db> { #[derive(Clone, Debug, Eq, PartialEq)] pub(crate) enum SpecializationError<'db> { + NoSolution { + parameter: Type<'db>, + argument: Type<'db>, + }, MismatchedBound { bound_typevar: BoundTypeVarInstance<'db>, argument: Type<'db>, @@ -1985,15 +1956,17 @@ pub(crate) enum SpecializationError<'db> { } impl<'db> SpecializationError<'db> { - pub(crate) fn bound_typevar(&self) -> BoundTypeVarInstance<'db> { + pub(crate) fn bound_typevar(&self) -> Option> { match self { - Self::MismatchedBound { bound_typevar, .. } => *bound_typevar, - Self::MismatchedConstraint { bound_typevar, .. } => *bound_typevar, + Self::NoSolution { .. } => None, + Self::MismatchedBound { bound_typevar, .. } => Some(*bound_typevar), + Self::MismatchedConstraint { bound_typevar, .. } => Some(*bound_typevar), } } pub(crate) fn argument_type(&self) -> Type<'db> { match self { + Self::NoSolution { argument, .. } => *argument, Self::MismatchedBound { argument, .. } => *argument, Self::MismatchedConstraint { argument, .. } => *argument, } diff --git a/crates/ty_python_semantic/src/types/instance.rs b/crates/ty_python_semantic/src/types/instance.rs index 9e674065b9..bec519759d 100644 --- a/crates/ty_python_semantic/src/types/instance.rs +++ b/crates/ty_python_semantic/src/types/instance.rs @@ -133,14 +133,29 @@ impl<'db> Type<'db> { disjointness_visitor: &IsDisjointVisitor<'db>, ) -> ConstraintSet<'db> { let structurally_satisfied = if let Type::ProtocolInstance(self_protocol) = self { - self_protocol.interface(db).has_relation_to_impl( - db, - protocol.interface(db), - inferable, - relation, - relation_visitor, - disjointness_visitor, - ) + let self_as_nominal = self_protocol.as_nominal_type(); + let other_as_nominal = protocol.as_nominal_type(); + let nominal_match = match self_as_nominal.zip(other_as_nominal) { + Some((self_as_nominal, other_as_nominal)) => self_as_nominal.has_relation_to_impl( + db, + other_as_nominal, + inferable, + relation, + relation_visitor, + disjointness_visitor, + ), + _ => ConstraintSet::from(false), + }; + nominal_match.or(db, || { + self_protocol.interface(db).has_relation_to_impl( + db, + protocol.interface(db), + inferable, + relation, + relation_visitor, + disjointness_visitor, + ) + }) } else { protocol .inner