From 65021fcee918b9d426888bbe128fada2bff72083 Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Tue, 23 Dec 2025 03:24:01 -0500 Subject: [PATCH] [ty] Support type inference between protocol instances (#22120) --- .../mdtest/generics/legacy/functions.md | 13 +++++++++++++ .../mdtest/generics/pep695/functions.md | 13 +++++++++++++ crates/ty_python_semantic/src/types/generics.rs | 16 ++++++++++++++++ 3 files changed, 42 insertions(+) diff --git a/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md b/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md index c674f7a9a1..7f786b8765 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md @@ -80,6 +80,7 @@ class CanIndex(Protocol[T]): def __getitem__(self, index: int, /) -> T: ... class ExplicitlyImplements(CanIndex[T]): ... +class SubProtocol(CanIndex[T], Protocol): ... def takes_in_list(x: list[T]) -> list[T]: return x @@ -103,6 +104,18 @@ def deep_explicit(x: ExplicitlyImplements[str]) -> None: def deeper_explicit(x: ExplicitlyImplements[set[str]]) -> None: reveal_type(takes_in_protocol(x)) # revealed: set[str] +def deep_subprotocol(x: SubProtocol[str]) -> None: + reveal_type(takes_in_protocol(x)) # revealed: str + +def deeper_subprotocol(x: SubProtocol[set[str]]) -> None: + reveal_type(takes_in_protocol(x)) # revealed: set[str] + +def itself(x: CanIndex[str]) -> None: + reveal_type(takes_in_protocol(x)) # revealed: str + +def deep_itself(x: CanIndex[set[str]]) -> None: + reveal_type(takes_in_protocol(x)) # revealed: set[str] + def takes_in_type(x: type[T]) -> type[T]: return x diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md index 3cdebe848e..b6732fb6f0 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md @@ -75,6 +75,7 @@ class CanIndex(Protocol[S]): def __getitem__(self, index: int, /) -> S: ... class ExplicitlyImplements[T](CanIndex[T]): ... +class SubProtocol[T](CanIndex[T], Protocol): ... def takes_in_list[T](x: list[T]) -> list[T]: return x @@ -98,6 +99,18 @@ def deep_explicit(x: ExplicitlyImplements[str]) -> None: def deeper_explicit(x: ExplicitlyImplements[set[str]]) -> None: reveal_type(takes_in_protocol(x)) # revealed: set[str] +def deep_subprotocol(x: SubProtocol[str]) -> None: + reveal_type(takes_in_protocol(x)) # revealed: str + +def deeper_subprotocol(x: SubProtocol[set[str]]) -> None: + reveal_type(takes_in_protocol(x)) # revealed: set[str] + +def itself(x: CanIndex[str]) -> None: + reveal_type(takes_in_protocol(x)) # revealed: str + +def deep_itself(x: CanIndex[set[str]]) -> None: + reveal_type(takes_in_protocol(x)) # revealed: set[str] + def takes_in_type[T](x: type[T]) -> type[T]: return x diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 36e89553a0..3e6364e6d3 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -1866,6 +1866,22 @@ impl<'db> SpecializationBuilder<'db> { } } + (formal, Type::ProtocolInstance(actual_protocol)) => { + // TODO: This will only handle protocol classes that explicit inherit + // from other generic protocol classes by listing it as a base class. + // To handle classes that implicitly implement a generic protocol, we + // will need to check the types of the protocol members to be able to + // infer the specialization of the protocol that the class implements. + if let Some(actual_nominal) = actual_protocol.as_nominal_type() { + return self.infer_map_impl( + formal, + Type::NominalInstance(actual_nominal), + polarity, + f, + ); + } + } + (formal, Type::NominalInstance(actual_nominal)) => { // Special case: `formal` and `actual` are both tuples. if let (Some(formal_tuple), Some(actual_tuple)) = (