[red-knot] Recurse into the types of protocol members when normalizing a protocol's interface (#17808)

## Summary

Currently red-knot does not understand `Foo` and `Bar` here as being
equivalent:

```py
from typing import Protocol

class A: ...
class B: ...
class C: ...

class Foo(Protocol):
    x: A | B | C

class Bar(Protocol):
    x: B | A | C
```

Nor does it understand `A | B | Foo` as being equivalent to `Bar | B |
A`. This PR fixes that.

## Test Plan

new mdtest assertions added that fail on `main`
This commit is contained in:
Alex Waygood 2025-05-03 16:43:37 +01:00 committed by GitHub
parent 52b0470870
commit e6a798b962
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 72 additions and 14 deletions

View File

@ -816,6 +816,22 @@ class B: ...
static_assert(is_equivalent_to(A | HasX | B | HasY, B | AlsoHasY | AlsoHasX | A)) static_assert(is_equivalent_to(A | HasX | B | HasY, B | AlsoHasY | AlsoHasX | A))
``` ```
Protocols are considered equivalent if their members are equivalent, even if those members are
differently ordered unions:
```py
class C: ...
class UnionProto1(Protocol):
x: A | B | C
class UnionProto2(Protocol):
x: C | A | B
static_assert(is_equivalent_to(UnionProto1, UnionProto2))
static_assert(is_equivalent_to(UnionProto1 | A | B, B | UnionProto2 | A))
```
## Intersections of protocols ## Intersections of protocols
An intersection of two protocol types `X` and `Y` is equivalent to a protocol type `Z` that inherits An intersection of two protocol types `X` and `Y` is equivalent to a protocol type `Z` that inherits

View File

@ -5,6 +5,8 @@ use super::{ClassType, KnownClass, SubclassOfType, Type};
use crate::symbol::{Symbol, SymbolAndQualifiers}; use crate::symbol::{Symbol, SymbolAndQualifiers};
use crate::Db; use crate::Db;
pub(super) use synthesized_protocol::SynthesizedProtocolType;
impl<'db> Type<'db> { impl<'db> Type<'db> {
pub(crate) fn instance(db: &'db dyn Db, class: ClassType<'db>) -> Self { pub(crate) fn instance(db: &'db dyn Db, class: ClassType<'db>) -> Self {
if class.class_literal(db).0.is_protocol(db) { if class.class_literal(db).0.is_protocol(db) {
@ -164,7 +166,7 @@ impl<'db> ProtocolInstanceType<'db> {
} }
match self.0 { match self.0 {
Protocol::FromClass(_) => Type::ProtocolInstance(Self(Protocol::Synthesized( Protocol::FromClass(_) => Type::ProtocolInstance(Self(Protocol::Synthesized(
SynthesizedProtocolType::new(db, self.0.interface(db)), SynthesizedProtocolType::new(db, self.0.interface(db).clone()),
))), ))),
Protocol::Synthesized(_) => Type::ProtocolInstance(self), Protocol::Synthesized(_) => Type::ProtocolInstance(self),
} }
@ -237,9 +239,7 @@ impl<'db> ProtocolInstanceType<'db> {
/// An enumeration of the two kinds of protocol types: those that originate from a class /// An enumeration of the two kinds of protocol types: those that originate from a class
/// definition in source code, and those that are synthesized from a set of members. /// definition in source code, and those that are synthesized from a set of members.
#[derive( #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, salsa::Update, PartialOrd, Ord)]
Copy, Clone, Debug, Eq, PartialEq, Hash, salsa::Update, salsa::Supertype, PartialOrd, Ord,
)]
pub(super) enum Protocol<'db> { pub(super) enum Protocol<'db> {
FromClass(ClassType<'db>), FromClass(ClassType<'db>),
Synthesized(SynthesizedProtocolType<'db>), Synthesized(SynthesizedProtocolType<'db>),
@ -260,14 +260,38 @@ impl<'db> Protocol<'db> {
} }
} }
/// A "synthesized" protocol type that is dissociated from a class definition in source code. mod synthesized_protocol {
/// use crate::db::Db;
/// Two synthesized protocol types with the same members will share the same Salsa ID, use crate::types::protocol_class::ProtocolInterface;
/// making them easy to compare for equivalence. A synthesized protocol type is therefore
/// returned by [`ProtocolInstanceType::normalized`] so that two protocols with the same members /// A "synthesized" protocol type that is dissociated from a class definition in source code.
/// will be understood as equivalent even in the context of differently ordered unions or intersections. ///
#[salsa::interned(debug)] /// Two synthesized protocol types with the same members will share the same Salsa ID,
pub(super) struct SynthesizedProtocolType<'db> { /// making them easy to compare for equivalence. A synthesized protocol type is therefore
#[return_ref] /// returned by [`super::ProtocolInstanceType::normalized`] so that two protocols with the same members
pub(super) interface: ProtocolInterface<'db>, /// will be understood as equivalent even in the context of differently ordered unions or intersections.
///
/// The constructor method of this type maintains the invariant that a synthesized protocol type
/// is always constructed from a *normalized* protocol interface.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, salsa::Update, PartialOrd, Ord)]
pub(in crate::types) struct SynthesizedProtocolType<'db>(SynthesizedProtocolTypeInner<'db>);
impl<'db> SynthesizedProtocolType<'db> {
pub(super) fn new(db: &'db dyn Db, interface: ProtocolInterface<'db>) -> Self {
Self(SynthesizedProtocolTypeInner::new(
db,
interface.normalized(db),
))
}
pub(in crate::types) fn interface(self, db: &'db dyn Db) -> &'db ProtocolInterface<'db> {
self.0.interface(db)
}
}
#[salsa::interned(debug)]
struct SynthesizedProtocolTypeInner<'db> {
#[return_ref]
interface: ProtocolInterface<'db>,
}
} }

View File

@ -96,6 +96,15 @@ impl<'db> ProtocolInterface<'db> {
pub(super) fn contains_todo(&self, db: &'db dyn Db) -> bool { pub(super) fn contains_todo(&self, db: &'db dyn Db) -> bool {
self.members().any(|member| member.ty.contains_todo(db)) self.members().any(|member| member.ty.contains_todo(db))
} }
pub(super) fn normalized(self, db: &'db dyn Db) -> Self {
Self(
self.0
.into_iter()
.map(|(name, data)| (name, data.normalized(db)))
.collect(),
)
}
} }
#[derive(Debug, PartialEq, Eq, Clone, Hash, salsa::Update)] #[derive(Debug, PartialEq, Eq, Clone, Hash, salsa::Update)]
@ -104,6 +113,15 @@ struct ProtocolMemberData<'db> {
qualifiers: TypeQualifiers, qualifiers: TypeQualifiers,
} }
impl<'db> ProtocolMemberData<'db> {
fn normalized(self, db: &'db dyn Db) -> Self {
Self {
ty: self.ty.normalized(db),
qualifiers: self.qualifiers,
}
}
}
/// A single member of a protocol interface. /// A single member of a protocol interface.
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]
pub(super) struct ProtocolMember<'a, 'db> { pub(super) struct ProtocolMember<'a, 'db> {