diff --git a/crates/ruff_macros/src/rule_code_prefix.rs b/crates/ruff_macros/src/rule_code_prefix.rs index aba584e43d..6ac0af1f93 100644 --- a/crates/ruff_macros/src/rule_code_prefix.rs +++ b/crates/ruff_macros/src/rule_code_prefix.rs @@ -1,4 +1,4 @@ -use std::collections::BTreeMap; +use std::collections::{BTreeMap, BTreeSet}; use proc_macro2::Span; use quote::quote; @@ -11,10 +11,10 @@ pub fn expand<'a>( variant_name: impl Fn(&str) -> &'a Ident, ) -> proc_macro2::TokenStream { // Build up a map from prefix to matching RuleCodes. - let mut prefix_to_codes: BTreeMap>> = - BTreeMap::default(); + let mut prefix_to_codes: BTreeMap> = BTreeMap::default(); + let mut attributes: BTreeMap = BTreeMap::default(); - let mut pl_codes = BTreeMap::new(); + let mut pl_codes = BTreeSet::new(); for (variant, attr) in variants { let code_str = variant.to_string(); @@ -28,26 +28,32 @@ pub fn expand<'a>( prefix_to_codes .entry(prefix) .or_default() - .entry(code_str.clone()) - .or_insert_with(|| attr.clone()); + .insert(code_str.clone()); } if code_str.starts_with("PL") { - pl_codes.insert(code_str, attr.clone()); + pl_codes.insert(code_str.clone()); } + attributes.insert(code_str, attr); } prefix_to_codes.insert("PL".to_string(), pl_codes); let prefix_variants = prefix_to_codes.iter().map(|(prefix, codes)| { let prefix = Ident::new(prefix, Span::call_site()); - let attr = if_all_same(codes.values().cloned()).unwrap_or_default(); + let attrs = attributes_for_prefix(codes, &attributes); quote! { - #(#attr)* + #attrs #prefix } }); - let prefix_impl = generate_impls(rule_type, prefix_ident, &prefix_to_codes, variant_name); + let prefix_impl = generate_impls( + rule_type, + prefix_ident, + &prefix_to_codes, + variant_name, + &attributes, + ); quote! { #[derive( @@ -76,21 +82,23 @@ pub fn expand<'a>( fn generate_impls<'a>( rule_type: &Ident, prefix_ident: &Ident, - prefix_to_codes: &BTreeMap>>, + prefix_to_codes: &BTreeMap>, variant_name: impl Fn(&str) -> &'a Ident, + attributes: &BTreeMap, ) -> proc_macro2::TokenStream { let into_iter_match_arms = prefix_to_codes.iter().map(|(prefix_str, codes)| { let prefix = Ident::new(prefix_str, Span::call_site()); - let attr = if_all_same(codes.values().cloned()).unwrap_or_default(); - let codes = codes.iter().map(|(code, attr)| { + let attrs = attributes_for_prefix(codes, attributes); + let codes = codes.iter().map(|code| { let rule_variant = variant_name(code); + let attrs = attributes[code]; quote! { - #(#attr)* + #(#attrs)* #rule_type::#rule_variant } }); quote! { - #(#attr)* + #attrs #prefix_ident::#prefix => vec![#(#codes),*].into_iter(), } }); @@ -110,9 +118,9 @@ fn generate_impls<'a>( 5 => quote! { Specificity::Code5Chars }, _ => panic!("Invalid prefix: {prefix}"), }; - let attr = if_all_same(codes.values().cloned()).unwrap_or_default(); + let attrs = attributes_for_prefix(codes, attributes); quote! { - #(#attr)* + #attrs #prefix_ident::#prefix => #suffix_len, } }); @@ -143,6 +151,16 @@ fn generate_impls<'a>( } } +fn attributes_for_prefix( + codes: &BTreeSet, + attributes: &BTreeMap, +) -> proc_macro2::TokenStream { + match if_all_same(codes.iter().map(|code| attributes[code])) { + Some(attr) => quote!(#(#attr)*), + None => quote!(), + } +} + /// If all values in an iterator are the same, return that value. Otherwise, /// return `None`. fn if_all_same(iter: impl Iterator) -> Option {