From e6a798b9625b9456dfd5b1d109830981b33ad954 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Sat, 3 May 2025 16:43:37 +0100 Subject: [PATCH] [red-knot] Recurse into the types of protocol members when normalizing a protocol's interface (#17808) ## Summary Currently red-knot does not understand `Foo` and `Bar` here as being equivalent: ```py from typing import Protocol class A: ... class B: ... class C: ... class Foo(Protocol): x: A | B | C class Bar(Protocol): x: B | A | C ``` Nor does it understand `A | B | Foo` as being equivalent to `Bar | B | A`. This PR fixes that. ## Test Plan new mdtest assertions added that fail on `main` --- .../resources/mdtest/protocols.md | 16 ++++++ .../src/types/instance.rs | 52 ++++++++++++++----- .../src/types/protocol_class.rs | 18 +++++++ 3 files changed, 72 insertions(+), 14 deletions(-) diff --git a/crates/red_knot_python_semantic/resources/mdtest/protocols.md b/crates/red_knot_python_semantic/resources/mdtest/protocols.md index 1b33205e47..cc3f8143f8 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/protocols.md +++ b/crates/red_knot_python_semantic/resources/mdtest/protocols.md @@ -816,6 +816,22 @@ class B: ... static_assert(is_equivalent_to(A | HasX | B | HasY, B | AlsoHasY | AlsoHasX | A)) ``` +Protocols are considered equivalent if their members are equivalent, even if those members are +differently ordered unions: + +```py +class C: ... + +class UnionProto1(Protocol): + x: A | B | C + +class UnionProto2(Protocol): + x: C | A | B + +static_assert(is_equivalent_to(UnionProto1, UnionProto2)) +static_assert(is_equivalent_to(UnionProto1 | A | B, B | UnionProto2 | A)) +``` + ## Intersections of protocols An intersection of two protocol types `X` and `Y` is equivalent to a protocol type `Z` that inherits diff --git a/crates/red_knot_python_semantic/src/types/instance.rs b/crates/red_knot_python_semantic/src/types/instance.rs index a6e3f805d1..762dc9a341 100644 --- a/crates/red_knot_python_semantic/src/types/instance.rs +++ b/crates/red_knot_python_semantic/src/types/instance.rs @@ -5,6 +5,8 @@ use super::{ClassType, KnownClass, SubclassOfType, Type}; use crate::symbol::{Symbol, SymbolAndQualifiers}; use crate::Db; +pub(super) use synthesized_protocol::SynthesizedProtocolType; + impl<'db> Type<'db> { pub(crate) fn instance(db: &'db dyn Db, class: ClassType<'db>) -> Self { if class.class_literal(db).0.is_protocol(db) { @@ -164,7 +166,7 @@ impl<'db> ProtocolInstanceType<'db> { } match self.0 { Protocol::FromClass(_) => Type::ProtocolInstance(Self(Protocol::Synthesized( - SynthesizedProtocolType::new(db, self.0.interface(db)), + SynthesizedProtocolType::new(db, self.0.interface(db).clone()), ))), Protocol::Synthesized(_) => Type::ProtocolInstance(self), } @@ -237,9 +239,7 @@ impl<'db> ProtocolInstanceType<'db> { /// An enumeration of the two kinds of protocol types: those that originate from a class /// definition in source code, and those that are synthesized from a set of members. -#[derive( - Copy, Clone, Debug, Eq, PartialEq, Hash, salsa::Update, salsa::Supertype, PartialOrd, Ord, -)] +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, salsa::Update, PartialOrd, Ord)] pub(super) enum Protocol<'db> { FromClass(ClassType<'db>), Synthesized(SynthesizedProtocolType<'db>), @@ -260,14 +260,38 @@ impl<'db> Protocol<'db> { } } -/// A "synthesized" protocol type that is dissociated from a class definition in source code. -/// -/// Two synthesized protocol types with the same members will share the same Salsa ID, -/// making them easy to compare for equivalence. A synthesized protocol type is therefore -/// returned by [`ProtocolInstanceType::normalized`] so that two protocols with the same members -/// will be understood as equivalent even in the context of differently ordered unions or intersections. -#[salsa::interned(debug)] -pub(super) struct SynthesizedProtocolType<'db> { - #[return_ref] - pub(super) interface: ProtocolInterface<'db>, +mod synthesized_protocol { + use crate::db::Db; + use crate::types::protocol_class::ProtocolInterface; + + /// A "synthesized" protocol type that is dissociated from a class definition in source code. + /// + /// Two synthesized protocol types with the same members will share the same Salsa ID, + /// making them easy to compare for equivalence. A synthesized protocol type is therefore + /// returned by [`super::ProtocolInstanceType::normalized`] so that two protocols with the same members + /// will be understood as equivalent even in the context of differently ordered unions or intersections. + /// + /// The constructor method of this type maintains the invariant that a synthesized protocol type + /// is always constructed from a *normalized* protocol interface. + #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, salsa::Update, PartialOrd, Ord)] + pub(in crate::types) struct SynthesizedProtocolType<'db>(SynthesizedProtocolTypeInner<'db>); + + impl<'db> SynthesizedProtocolType<'db> { + pub(super) fn new(db: &'db dyn Db, interface: ProtocolInterface<'db>) -> Self { + Self(SynthesizedProtocolTypeInner::new( + db, + interface.normalized(db), + )) + } + + pub(in crate::types) fn interface(self, db: &'db dyn Db) -> &'db ProtocolInterface<'db> { + self.0.interface(db) + } + } + + #[salsa::interned(debug)] + struct SynthesizedProtocolTypeInner<'db> { + #[return_ref] + interface: ProtocolInterface<'db>, + } } diff --git a/crates/red_knot_python_semantic/src/types/protocol_class.rs b/crates/red_knot_python_semantic/src/types/protocol_class.rs index 6706dea53c..743ca2f7a7 100644 --- a/crates/red_knot_python_semantic/src/types/protocol_class.rs +++ b/crates/red_knot_python_semantic/src/types/protocol_class.rs @@ -96,6 +96,15 @@ impl<'db> ProtocolInterface<'db> { pub(super) fn contains_todo(&self, db: &'db dyn Db) -> bool { self.members().any(|member| member.ty.contains_todo(db)) } + + pub(super) fn normalized(self, db: &'db dyn Db) -> Self { + Self( + self.0 + .into_iter() + .map(|(name, data)| (name, data.normalized(db))) + .collect(), + ) + } } #[derive(Debug, PartialEq, Eq, Clone, Hash, salsa::Update)] @@ -104,6 +113,15 @@ struct ProtocolMemberData<'db> { qualifiers: TypeQualifiers, } +impl<'db> ProtocolMemberData<'db> { + fn normalized(self, db: &'db dyn Db) -> Self { + Self { + ty: self.ty.normalized(db), + qualifiers: self.qualifiers, + } + } +} + /// A single member of a protocol interface. #[derive(Debug, PartialEq, Eq)] pub(super) struct ProtocolMember<'a, 'db> {