From 39c21d7c6cc8682fa791d93ac23aeac0b32fbca6 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Fri, 7 Nov 2025 16:26:30 -0500 Subject: [PATCH] [ty] Generalize some infrastructure around type visitors (#21323) We have lots of `TypeVisitor`s that end up having very similar `visit_type` implementations. This PR consolidates some of the code for these so that there's less repetition and duplication. --- crates/ty_python_semantic/src/types/class.rs | 15 ++----- .../ty_python_semantic/src/types/generics.rs | 17 ++------ .../ty_python_semantic/src/types/visitor.rs | 42 +++++++++++++------ 3 files changed, 37 insertions(+), 37 deletions(-) diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index de0b4c180c..862ed4b974 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -30,7 +30,7 @@ use crate::types::member::{Member, class_member}; use crate::types::signatures::{CallableSignature, Parameter, Parameters, Signature}; use crate::types::tuple::{TupleSpec, TupleType}; use crate::types::typed_dict::typed_dict_params_from_class_def; -use crate::types::visitor::{NonAtomicType, TypeKind, TypeVisitor, walk_non_atomic_type}; +use crate::types::visitor::{TypeCollector, TypeVisitor, walk_type_with_recursion_guard}; use crate::types::{ ApplyTypeMappingVisitor, Binding, BoundSuperType, CallableType, DataclassFlags, DataclassParams, DeprecatedInstance, FindLegacyTypeVarsVisitor, HasRelationToVisitor, @@ -1437,7 +1437,7 @@ impl<'db> ClassLiteral<'db> { #[derive(Default)] struct CollectTypeVars<'db> { typevars: RefCell>>, - seen_types: RefCell>>, + recursion_guard: TypeCollector<'db>, } impl<'db> TypeVisitor<'db> for CollectTypeVars<'db> { @@ -1454,16 +1454,7 @@ impl<'db> ClassLiteral<'db> { } fn visit_type(&self, db: &'db dyn Db, ty: Type<'db>) { - match TypeKind::from(ty) { - TypeKind::Atomic => {} - TypeKind::NonAtomic(non_atomic_type) => { - if !self.seen_types.borrow_mut().insert(non_atomic_type) { - // If we have already seen this type, we can skip it. - return; - } - walk_non_atomic_type(db, non_atomic_type, self); - } - } + walk_type_with_recursion_guard(db, ty, self, &self.recursion_guard); } } diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 992e664401..555ab47f01 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -14,7 +14,7 @@ use crate::types::constraints::ConstraintSet; use crate::types::instance::{Protocol, ProtocolInstanceType}; use crate::types::signatures::{Parameter, Parameters, Signature}; use crate::types::tuple::{TupleSpec, TupleType, walk_tuple_type}; -use crate::types::visitor::{NonAtomicType, TypeKind, TypeVisitor, walk_non_atomic_type}; +use crate::types::visitor::{TypeCollector, TypeVisitor, walk_type_with_recursion_guard}; use crate::types::{ ApplyTypeMappingVisitor, BoundTypeVarIdentity, BoundTypeVarInstance, ClassLiteral, FindLegacyTypeVarsVisitor, HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor, @@ -22,7 +22,7 @@ use crate::types::{ TypeMapping, TypeRelation, TypeVarBoundOrConstraints, TypeVarIdentity, TypeVarInstance, TypeVarKind, TypeVarVariance, UnionType, declaration_type, walk_bound_type_var_type, }; -use crate::{Db, FxIndexSet, FxOrderMap, FxOrderSet}; +use crate::{Db, FxOrderMap, FxOrderSet}; /// Returns an iterator of any generic context introduced by the given scope or any enclosing /// scope. @@ -288,7 +288,7 @@ impl<'db> GenericContext<'db> { #[derive(Default)] struct CollectTypeVars<'db> { typevars: RefCell>>, - seen_types: RefCell>>, + recursion_guard: TypeCollector<'db>, } impl<'db> TypeVisitor<'db> for CollectTypeVars<'db> { @@ -308,16 +308,7 @@ impl<'db> GenericContext<'db> { } fn visit_type(&self, db: &'db dyn Db, ty: Type<'db>) { - match TypeKind::from(ty) { - TypeKind::Atomic => {} - TypeKind::NonAtomic(non_atomic_type) => { - if !self.seen_types.borrow_mut().insert(non_atomic_type) { - // If we have already seen this type, we can skip it. - return; - } - walk_non_atomic_type(db, non_atomic_type, self); - } - } + walk_type_with_recursion_guard(db, ty, self, &self.recursion_guard); } } diff --git a/crates/ty_python_semantic/src/types/visitor.rs b/crates/ty_python_semantic/src/types/visitor.rs index d58bf046f1..dd1ddfdfe5 100644 --- a/crates/ty_python_semantic/src/types/visitor.rs +++ b/crates/ty_python_semantic/src/types/visitor.rs @@ -242,6 +242,33 @@ pub(super) fn walk_non_atomic_type<'db, V: TypeVisitor<'db> + ?Sized>( } } +pub(crate) fn walk_type_with_recursion_guard<'db>( + db: &'db dyn Db, + ty: Type<'db>, + visitor: &impl TypeVisitor<'db>, + recursion_guard: &TypeCollector<'db>, +) { + match TypeKind::from(ty) { + TypeKind::Atomic => {} + TypeKind::NonAtomic(non_atomic_type) => { + if recursion_guard.type_was_already_seen(ty) { + // If we have already seen this type, we can skip it. + return; + } + walk_non_atomic_type(db, non_atomic_type, visitor); + } + } +} + +#[derive(Default, Debug)] +pub(crate) struct TypeCollector<'db>(RefCell>>); + +impl<'db> TypeCollector<'db> { + pub(crate) fn type_was_already_seen(&self, ty: Type<'db>) -> bool { + !self.0.borrow_mut().insert(ty) + } +} + /// Return `true` if `ty`, or any of the types contained in `ty`, match the closure passed in. /// /// The function guards against infinite recursion @@ -258,7 +285,7 @@ pub(super) fn any_over_type<'db>( ) -> bool { struct AnyOverTypeVisitor<'db, 'a> { query: &'a dyn Fn(Type<'db>) -> bool, - seen_types: RefCell>>, + recursion_guard: TypeCollector<'db>, found_matching_type: Cell, should_visit_lazy_type_attributes: bool, } @@ -278,22 +305,13 @@ pub(super) fn any_over_type<'db>( if found { return; } - match TypeKind::from(ty) { - TypeKind::Atomic => {} - TypeKind::NonAtomic(non_atomic_type) => { - if !self.seen_types.borrow_mut().insert(non_atomic_type) { - // If we have already seen this type, we can skip it. - return; - } - walk_non_atomic_type(db, non_atomic_type, self); - } - } + walk_type_with_recursion_guard(db, ty, self, &self.recursion_guard); } } let visitor = AnyOverTypeVisitor { query, - seen_types: RefCell::new(FxIndexSet::default()), + recursion_guard: TypeCollector::default(), found_matching_type: Cell::new(false), should_visit_lazy_type_attributes, };