From 411cccb35eafbf91787d36ca6092101328211e2c Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Fri, 4 Jul 2025 06:31:44 -0700 Subject: [PATCH] [ty] detect cycles in Type::is_disjoint_from (#19139) --- .../resources/mdtest/protocols.md | 15 +++++ crates/ty_python_semantic/src/types.rs | 63 +++++++++++-------- crates/ty_python_semantic/src/types/cyclic.rs | 44 ++++++++----- .../ty_python_semantic/src/types/instance.rs | 10 ++- .../src/types/protocol_class.rs | 12 +++- .../src/types/subclass_of.rs | 2 +- crates/ty_python_semantic/src/types/tuple.rs | 35 ++++++++--- 7 files changed, 129 insertions(+), 52 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/protocols.md b/crates/ty_python_semantic/resources/mdtest/protocols.md index f507b963a5..fdd848f3dc 100644 --- a/crates/ty_python_semantic/resources/mdtest/protocols.md +++ b/crates/ty_python_semantic/resources/mdtest/protocols.md @@ -1862,6 +1862,21 @@ class Bar(Protocol): static_assert(is_equivalent_to(Foo, Bar)) ``` +### Disjointness of recursive protocol and recursive final type + +```py +from typing import Protocol +from ty_extensions import is_disjoint_from, static_assert + +class Proto(Protocol): + x: "Proto" + +class Nominal: + x: "Nominal" + +static_assert(not is_disjoint_from(Proto, Nominal)) +``` + ### Regression test: narrowing with self-referential protocols This snippet caused us to panic on an early version of the implementation for protocols. diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 4ebc60b768..0a1ce8112a 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -19,7 +19,7 @@ use ruff_text_size::{Ranged, TextRange}; use type_ordering::union_or_intersection_elements_ordering; pub(crate) use self::builder::{IntersectionBuilder, UnionBuilder}; -pub(crate) use self::cyclic::TypeTransformer; +pub(crate) use self::cyclic::{PairVisitor, TypeTransformer}; pub use self::diagnostic::TypeCheckDiagnostics; pub(crate) use self::diagnostic::register_lints; pub(crate) use self::infer::{ @@ -1637,17 +1637,30 @@ impl<'db> Type<'db> { /// Note: This function aims to have no false positives, but might return /// wrong `false` answers in some cases. pub(crate) fn is_disjoint_from(self, db: &'db dyn Db, other: Type<'db>) -> bool { + let mut visitor = PairVisitor::new(false); + self.is_disjoint_from_impl(db, other, &mut visitor) + } + + pub(crate) fn is_disjoint_from_impl( + self, + db: &'db dyn Db, + other: Type<'db>, + visitor: &mut PairVisitor<'db>, + ) -> bool { fn any_protocol_members_absent_or_disjoint<'db>( db: &'db dyn Db, protocol: ProtocolInstanceType<'db>, other: Type<'db>, + visitor: &mut PairVisitor<'db>, ) -> bool { protocol.interface(db).members(db).any(|member| { other .member(db, member.name()) .place .ignore_possibly_unbound() - .is_none_or(|attribute_type| member.has_disjoint_type_from(db, attribute_type)) + .is_none_or(|attribute_type| { + member.has_disjoint_type_from(db, attribute_type, visitor) + }) }) } @@ -1681,19 +1694,19 @@ impl<'db> Type<'db> { match typevar.bound_or_constraints(db) { None => false, Some(TypeVarBoundOrConstraints::UpperBound(bound)) => { - bound.is_disjoint_from(db, other) + bound.is_disjoint_from_impl(db, other, visitor) } Some(TypeVarBoundOrConstraints::Constraints(constraints)) => constraints .elements(db) .iter() - .all(|constraint| constraint.is_disjoint_from(db, other)), + .all(|constraint| constraint.is_disjoint_from_impl(db, other, visitor)), } } (Type::Union(union), other) | (other, Type::Union(union)) => union .elements(db) .iter() - .all(|e| e.is_disjoint_from(db, other)), + .all(|e| e.is_disjoint_from_impl(db, other, visitor)), // If we have two intersections, we test the positive elements of each one against the other intersection // Negative elements need a positive element on the other side in order to be disjoint. @@ -1702,11 +1715,11 @@ impl<'db> Type<'db> { self_intersection .positive(db) .iter() - .any(|p| p.is_disjoint_from(db, other)) + .any(|p| p.is_disjoint_from_impl(db, other, visitor)) || other_intersection .positive(db) .iter() - .any(|p: &Type<'_>| p.is_disjoint_from(db, self)) + .any(|p: &Type<'_>| p.is_disjoint_from_impl(db, self, visitor)) } (Type::Intersection(intersection), other) @@ -1714,7 +1727,7 @@ impl<'db> Type<'db> { intersection .positive(db) .iter() - .any(|p| p.is_disjoint_from(db, other)) + .any(|p| p.is_disjoint_from_impl(db, other, visitor)) // A & B & Not[C] is disjoint from C || intersection .negative(db) @@ -1828,17 +1841,17 @@ impl<'db> Type<'db> { } (Type::ProtocolInstance(left), Type::ProtocolInstance(right)) => { - left.is_disjoint_from(db, right) + left.is_disjoint_from_impl(db, right, visitor) } (Type::ProtocolInstance(protocol), Type::SpecialForm(special_form)) | (Type::SpecialForm(special_form), Type::ProtocolInstance(protocol)) => { - any_protocol_members_absent_or_disjoint(db, protocol, special_form.instance_fallback(db)) + any_protocol_members_absent_or_disjoint(db, protocol, special_form.instance_fallback(db), visitor) } (Type::ProtocolInstance(protocol), Type::KnownInstance(known_instance)) | (Type::KnownInstance(known_instance), Type::ProtocolInstance(protocol)) => { - any_protocol_members_absent_or_disjoint(db, protocol, known_instance.instance_fallback(db)) + any_protocol_members_absent_or_disjoint(db, protocol, known_instance.instance_fallback(db), visitor) } // The absence of a protocol member on one of these types guarantees @@ -1891,7 +1904,7 @@ impl<'db> Type<'db> { | Type::ModuleLiteral(..) | Type::GenericAlias(..) | Type::IntLiteral(..)), - ) => any_protocol_members_absent_or_disjoint(db, protocol, ty), + ) => any_protocol_members_absent_or_disjoint(db, protocol, ty, visitor), // This is the same as the branch above -- // once guard patterns are stabilised, it could be unified with that branch @@ -1900,7 +1913,7 @@ impl<'db> Type<'db> { | (nominal @ Type::NominalInstance(n), Type::ProtocolInstance(protocol)) if n.class.is_final(db) => { - any_protocol_members_absent_or_disjoint(db, protocol, nominal) + any_protocol_members_absent_or_disjoint(db, protocol, nominal, visitor) } (Type::ProtocolInstance(protocol), other) @@ -1908,7 +1921,7 @@ impl<'db> Type<'db> { protocol.interface(db).members(db).any(|member| { matches!( other.member(db, member.name()).place, - Place::Type(attribute_type, _) if member.has_disjoint_type_from(db, attribute_type) + Place::Type(attribute_type, _) if member.has_disjoint_type_from(db, attribute_type, visitor) ) }) } @@ -1931,18 +1944,18 @@ impl<'db> Type<'db> { } } - (Type::SubclassOf(left), Type::SubclassOf(right)) => left.is_disjoint_from(db, right), + (Type::SubclassOf(left), Type::SubclassOf(right)) => left.is_disjoint_from_impl(db, right), // for `type[Any]`/`type[Unknown]`/`type[Todo]`, we know the type cannot be any larger than `type`, // so although the type is dynamic we can still determine disjointedness in some situations (Type::SubclassOf(subclass_of_ty), other) | (other, Type::SubclassOf(subclass_of_ty)) => match subclass_of_ty.subclass_of() { SubclassOfInner::Dynamic(_) => { - KnownClass::Type.to_instance(db).is_disjoint_from(db, other) + KnownClass::Type.to_instance(db).is_disjoint_from_impl(db, other, visitor) } SubclassOfInner::Class(class) => class .metaclass_instance_type(db) - .is_disjoint_from(db, other), + .is_disjoint_from_impl(db, other, visitor), }, (Type::SpecialForm(special_form), Type::NominalInstance(instance)) @@ -2027,18 +2040,18 @@ impl<'db> Type<'db> { (Type::BoundMethod(_), other) | (other, Type::BoundMethod(_)) => KnownClass::MethodType .to_instance(db) - .is_disjoint_from(db, other), + .is_disjoint_from_impl(db, other, visitor), (Type::MethodWrapper(_), other) | (other, Type::MethodWrapper(_)) => { KnownClass::MethodWrapperType .to_instance(db) - .is_disjoint_from(db, other) + .is_disjoint_from_impl(db, other, visitor) } (Type::WrapperDescriptor(_), other) | (other, Type::WrapperDescriptor(_)) => { KnownClass::WrapperDescriptorType .to_instance(db) - .is_disjoint_from(db, other) + .is_disjoint_from_impl(db, other, visitor) } (Type::Callable(_) | Type::FunctionLiteral(_), Type::Callable(_)) @@ -2100,15 +2113,15 @@ impl<'db> Type<'db> { (Type::ModuleLiteral(..), other @ Type::NominalInstance(..)) | (other @ Type::NominalInstance(..), Type::ModuleLiteral(..)) => { // Modules *can* actually be instances of `ModuleType` subclasses - other.is_disjoint_from(db, KnownClass::ModuleType.to_instance(db)) + other.is_disjoint_from_impl(db, KnownClass::ModuleType.to_instance(db), visitor) } (Type::NominalInstance(left), Type::NominalInstance(right)) => { - left.is_disjoint_from(db, right) + left.is_disjoint_from_impl(db, right) } (Type::Tuple(tuple), Type::Tuple(other_tuple)) => { - tuple.is_disjoint_from(db, other_tuple) + tuple.is_disjoint_from_impl(db, other_tuple, visitor) } (Type::Tuple(tuple), Type::NominalInstance(instance)) @@ -2121,13 +2134,13 @@ impl<'db> Type<'db> { (Type::PropertyInstance(_), other) | (other, Type::PropertyInstance(_)) => { KnownClass::Property .to_instance(db) - .is_disjoint_from(db, other) + .is_disjoint_from_impl(db, other, visitor) } (Type::BoundSuper(_), Type::BoundSuper(_)) => !self.is_equivalent_to(db, other), (Type::BoundSuper(_), other) | (other, Type::BoundSuper(_)) => KnownClass::Super .to_instance(db) - .is_disjoint_from(db, other), + .is_disjoint_from_impl(db, other, visitor), } } diff --git a/crates/ty_python_semantic/src/types/cyclic.rs b/crates/ty_python_semantic/src/types/cyclic.rs index ac6ba624a9..7922176f2a 100644 --- a/crates/ty_python_semantic/src/types/cyclic.rs +++ b/crates/ty_python_semantic/src/types/cyclic.rs @@ -1,23 +1,39 @@ use crate::FxIndexSet; use crate::types::Type; +use std::cmp::Eq; +use std::hash::Hash; -#[derive(Debug, Default)] -pub(crate) struct TypeTransformer<'db> { - seen: FxIndexSet>, +pub(crate) type TypeTransformer<'db> = CycleDetector, Type<'db>>; + +impl Default for TypeTransformer<'_> { + fn default() -> Self { + // TODO: proper recursive type handling + + // This must be Any, not e.g. a todo type, because Any is the normalized form of the + // dynamic type (that is, todo types are normalized to Any). + CycleDetector::new(Type::any()) + } } -impl<'db> TypeTransformer<'db> { - pub(crate) fn visit( - &mut self, - ty: Type<'db>, - func: impl FnOnce(&mut Self) -> Type<'db>, - ) -> Type<'db> { - if !self.seen.insert(ty) { - // TODO: proper recursive type handling +pub(crate) type PairVisitor<'db> = CycleDetector<(Type<'db>, Type<'db>), bool>; - // This must be Any, not e.g. a todo type, because Any is the normalized form of the - // dynamic type (that is, todo types are normalized to Any). - return Type::any(); +#[derive(Debug)] +pub(crate) struct CycleDetector { + seen: FxIndexSet, + fallback: R, +} + +impl CycleDetector { + pub(crate) fn new(fallback: R) -> Self { + CycleDetector { + seen: FxIndexSet::default(), + fallback, + } + } + + pub(crate) fn visit(&mut self, item: T, func: impl FnOnce(&mut Self) -> R) -> R { + if !self.seen.insert(item) { + return self.fallback; } let ret = func(self); self.seen.pop(); diff --git a/crates/ty_python_semantic/src/types/instance.rs b/crates/ty_python_semantic/src/types/instance.rs index 9c7c57ec92..33dff7b9c0 100644 --- a/crates/ty_python_semantic/src/types/instance.rs +++ b/crates/ty_python_semantic/src/types/instance.rs @@ -5,6 +5,7 @@ use std::marker::PhantomData; use super::protocol_class::ProtocolInterface; use super::{ClassType, KnownClass, SubclassOfType, Type, TypeVarVariance}; use crate::place::PlaceAndQualifiers; +use crate::types::cyclic::PairVisitor; use crate::types::protocol_class::walk_protocol_interface; use crate::types::tuple::TupleType; use crate::types::{DynamicType, TypeMapping, TypeRelation, TypeTransformer, TypeVarInstance}; @@ -118,7 +119,7 @@ impl<'db> NominalInstanceType<'db> { self.class.is_equivalent_to(db, other.class) } - pub(super) fn is_disjoint_from(self, db: &'db dyn Db, other: Self) -> bool { + pub(super) fn is_disjoint_from_impl(self, db: &'db dyn Db, other: Self) -> bool { !self.class.could_coexist_in_mro_with(db, other.class) } @@ -277,7 +278,12 @@ impl<'db> ProtocolInstanceType<'db> { /// TODO: a protocol `X` is disjoint from a protocol `Y` if `X` and `Y` /// have a member with the same name but disjoint types #[expect(clippy::unused_self)] - pub(super) fn is_disjoint_from(self, _db: &'db dyn Db, _other: Self) -> bool { + pub(super) fn is_disjoint_from_impl( + self, + _db: &'db dyn Db, + _other: Self, + _visitor: &mut PairVisitor<'db>, + ) -> bool { false } diff --git a/crates/ty_python_semantic/src/types/protocol_class.rs b/crates/ty_python_semantic/src/types/protocol_class.rs index a21593291d..3d86fabe23 100644 --- a/crates/ty_python_semantic/src/types/protocol_class.rs +++ b/crates/ty_python_semantic/src/types/protocol_class.rs @@ -11,6 +11,7 @@ use crate::{ types::{ CallableType, ClassBase, ClassLiteral, KnownFunction, PropertyInstanceType, Signature, Type, TypeMapping, TypeQualifiers, TypeRelation, TypeTransformer, TypeVarInstance, + cyclic::PairVisitor, signatures::{Parameter, Parameters}, }, }; @@ -359,11 +360,18 @@ impl<'a, 'db> ProtocolMember<'a, 'db> { } } - pub(super) fn has_disjoint_type_from(&self, db: &'db dyn Db, other: Type<'db>) -> bool { + pub(super) fn has_disjoint_type_from( + &self, + db: &'db dyn Db, + other: Type<'db>, + visitor: &mut PairVisitor<'db>, + ) -> bool { match &self.kind { // TODO: implement disjointness for property/method members as well as attribute members ProtocolMemberKind::Property(_) | ProtocolMemberKind::Method(_) => false, - ProtocolMemberKind::Other(ty) => ty.is_disjoint_from(db, other), + ProtocolMemberKind::Other(ty) => { + visitor.visit((*ty, other), |v| ty.is_disjoint_from_impl(db, other, v)) + } } } diff --git a/crates/ty_python_semantic/src/types/subclass_of.rs b/crates/ty_python_semantic/src/types/subclass_of.rs index 474bb5f34e..5b12ae252a 100644 --- a/crates/ty_python_semantic/src/types/subclass_of.rs +++ b/crates/ty_python_semantic/src/types/subclass_of.rs @@ -170,7 +170,7 @@ impl<'db> SubclassOfType<'db> { /// Return` true` if `self` is a disjoint type from `other`. /// /// See [`Type::is_disjoint_from`] for more details. - pub(crate) fn is_disjoint_from(self, db: &'db dyn Db, other: Self) -> bool { + pub(crate) fn is_disjoint_from_impl(self, db: &'db dyn Db, other: Self) -> bool { match (self.subclass_of, other.subclass_of) { (SubclassOfInner::Dynamic(_), _) | (_, SubclassOfInner::Dynamic(_)) => false, (SubclassOfInner::Class(self_class), SubclassOfInner::Class(other_class)) => { diff --git a/crates/ty_python_semantic/src/types/tuple.rs b/crates/ty_python_semantic/src/types/tuple.rs index 2d3caf4e86..58dcb0debb 100644 --- a/crates/ty_python_semantic/src/types/tuple.rs +++ b/crates/ty_python_semantic/src/types/tuple.rs @@ -25,7 +25,7 @@ use itertools::{Either, EitherOrBoth, Itertools}; use crate::types::class::{ClassType, KnownClass}; use crate::types::{ Type, TypeMapping, TypeRelation, TypeTransformer, TypeVarInstance, TypeVarVariance, - UnionBuilder, UnionType, + UnionBuilder, UnionType, cyclic::PairVisitor, }; use crate::util::subscript::{Nth, OutOfBoundsError, PyIndex, PySlice, StepSizeZeroError}; use crate::{Db, FxOrderSet}; @@ -227,8 +227,14 @@ impl<'db> TupleType<'db> { self.tuple(db).is_equivalent_to(db, other.tuple(db)) } - pub(crate) fn is_disjoint_from(self, db: &'db dyn Db, other: Self) -> bool { - self.tuple(db).is_disjoint_from(db, other.tuple(db)) + pub(crate) fn is_disjoint_from_impl( + self, + db: &'db dyn Db, + other: Self, + visitor: &mut PairVisitor<'db>, + ) -> bool { + self.tuple(db) + .is_disjoint_from_impl(db, other.tuple(db), visitor) } pub(crate) fn is_single_valued(self, db: &'db dyn Db) -> bool { @@ -1058,7 +1064,12 @@ impl<'db> Tuple> { } } - fn is_disjoint_from(&self, db: &'db dyn Db, other: &Self) -> bool { + fn is_disjoint_from_impl( + &'db self, + db: &'db dyn Db, + other: &'db Self, + visitor: &mut PairVisitor<'db>, + ) -> bool { // Two tuples with an incompatible number of required elements must always be disjoint. let (self_min, self_max) = self.len().size_hint(); let (other_min, other_max) = other.len().size_hint(); @@ -1075,15 +1086,16 @@ impl<'db> Tuple> { db: &'db dyn Db, a: impl IntoIterator>, b: impl IntoIterator>, + visitor: &mut PairVisitor<'db>, ) -> bool { a.into_iter().zip(b).any(|(self_element, other_element)| { - self_element.is_disjoint_from(db, *other_element) + self_element.is_disjoint_from_impl(db, *other_element, visitor) }) } match (self, other) { (Tuple::Fixed(self_tuple), Tuple::Fixed(other_tuple)) => { - if any_disjoint(db, self_tuple.elements(), other_tuple.elements()) { + if any_disjoint(db, self_tuple.elements(), other_tuple.elements(), visitor) { return true; } } @@ -1093,6 +1105,7 @@ impl<'db> Tuple> { db, self_tuple.prefix_elements(), other_tuple.prefix_elements(), + visitor, ) { return true; } @@ -1100,6 +1113,7 @@ impl<'db> Tuple> { db, self_tuple.suffix_elements().rev(), other_tuple.suffix_elements().rev(), + visitor, ) { return true; } @@ -1107,10 +1121,15 @@ impl<'db> Tuple> { (Tuple::Fixed(fixed), Tuple::Variable(variable)) | (Tuple::Variable(variable), Tuple::Fixed(fixed)) => { - if any_disjoint(db, fixed.elements(), variable.prefix_elements()) { + if any_disjoint(db, fixed.elements(), variable.prefix_elements(), visitor) { return true; } - if any_disjoint(db, fixed.elements().rev(), variable.suffix_elements().rev()) { + if any_disjoint( + db, + fixed.elements().rev(), + variable.suffix_elements().rev(), + visitor, + ) { return true; } }