diff --git a/crates/ty_python_semantic/resources/mdtest/protocols.md b/crates/ty_python_semantic/resources/mdtest/protocols.md index e68cc538ee..5c7899c7a4 100644 --- a/crates/ty_python_semantic/resources/mdtest/protocols.md +++ b/crates/ty_python_semantic/resources/mdtest/protocols.md @@ -150,6 +150,14 @@ class AlsoInvalid(MyProtocol, OtherProtocol, NotAProtocol, Protocol): ... # revealed: tuple[, , , , typing.Protocol, typing.Generic, ] reveal_type(AlsoInvalid.__mro__) + +class NotAGenericProtocol[T]: ... + +# error: [invalid-protocol] "Protocol class `StillInvalid` cannot inherit from non-protocol class `NotAGenericProtocol`" +class StillInvalid(NotAGenericProtocol[int], Protocol): ... + +# revealed: tuple[, , typing.Protocol, typing.Generic, ] +reveal_type(StillInvalid.__mro__) ``` But two exceptions to this rule are `object` and `Generic`: diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 018eccd71c..2a6ebc63ae 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -1117,13 +1117,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // - Check for inheritance from a `@final` classes // - If the class is a protocol class: check for inheritance from a non-protocol class for (i, base_class) in class.explicit_bases(self.db()).iter().enumerate() { - if let Some((class, solid_base)) = base_class - .to_class_type(self.db()) - .and_then(|class| Some((class, class.nearest_solid_base(self.db())?))) - { - solid_bases.insert(solid_base, i, class.class_literal(self.db()).0); - } - let base_class = match base_class { Type::SpecialForm(SpecialFormType::Generic) => { if let Some(builder) = self @@ -1155,13 +1148,17 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ); continue; } - Type::ClassLiteral(class) => class, - // dynamic/unknown bases are never `@final` + Type::ClassLiteral(class) => ClassType::NonGeneric(*class), + Type::GenericAlias(class) => ClassType::Generic(*class), _ => continue, }; + if let Some(solid_base) = base_class.nearest_solid_base(self.db()) { + solid_bases.insert(solid_base, i, base_class.class_literal(self.db()).0); + } + if is_protocol - && !(base_class.is_protocol(self.db()) + && !(base_class.class_literal(self.db()).0.is_protocol(self.db()) || base_class.is_known(self.db(), KnownClass::Object)) { if let Some(builder) = self