Skip to main content

miniextendr_macros/
factor_derive.rs

1//! # `#[derive(RFactor)]` - Enum ↔ R Factor Support
2//!
3//! This module implements the `#[derive(RFactor)]` macro which generates
4//! the `RFactor` trait implementation for C-style enums, enabling automatic
5//! conversion between Rust enums and R factors.
6//!
7//! ## Usage
8//!
9//! ```ignore
10//! #[derive(Copy, Clone, RFactor)]
11//! enum Color {
12//!     Red,
13//!     Green,
14//!     Blue,
15//! }
16//!
//! // Generates impl MatchArg, RFactor, IntoR, and TryFromSexp for Color.
//! // (Vec/Option conversions are not emitted by this derive itself.)
19//! ```
20//!
21//! ## Attributes
22//!
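//! - `#[r_factor(rename = "name")]` - Rename a variant's level string
//! - `#[r_factor(rename_all = "snake_case")]` - Rename all variants (snake_case, kebab-case, lower, upper)
//! - `#[r_factor(interaction = ["A", "B"])]` - Enum-level: declare the inner type's levels and
//!   generate an interaction factor (see below)
//! - `#[r_factor(sep = "_")]` - Enum-level: separator between outer and inner level names of an
//!   interaction factor (default `"."`)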
25//!
26//! ```ignore
27//! #[derive(Copy, Clone, RFactor)]
28//! #[r_factor(rename_all = "snake_case")]
29//! enum Status {
30//!     InProgress,  // level: "in_progress"
31//!     #[r_factor(rename = "done")]
32//!     Completed,   // level: "done"
33//! }
34//! ```
35//!
36//! ## Interaction Factors
37//!
38//! For enums wrapping another RFactor type (like R's `interaction()`):
39//!
40//! ```ignore
41//! #[derive(Copy, Clone, RFactor)]
42//! enum Supplement { OJ, VC }
43//!
44//! #[derive(Copy, Clone, RFactor)]
45//! #[r_factor(interaction = ["OJ", "VC"])]  // inner type's levels
46//! enum SpeciesSupplement {
47//!     Setosa(Supplement),
48//!     Versicolor(Supplement),
49//!     Virginica(Supplement),
50//! }
51//! // Levels: ["Setosa.OJ", "Setosa.VC", "Versicolor.OJ", ...]
52//! ```
53
54use proc_macro2::TokenStream;
55use quote::quote;
56use syn::{Data, DeriveInput, Fields, Type};
57
58use crate::naming::apply_rename_all;
59
/// Options collected from `#[r_factor(...)]` on an enum or on one of its variants.
///
/// Every field stays `None` unless the corresponding key was written.
#[derive(Default)]
struct RFactorAttrs {
    /// Enum-level: case transformation applied to every variant that has no
    /// explicit `rename`, e.g. `#[r_factor(rename_all = "snake_case")]`.
    rename_all: Option<String>,
    /// Enum-level: level names of the wrapped inner type, e.g.
    /// `#[r_factor(interaction = ["A", "B"])]`. Its presence selects the
    /// interaction-factor code path instead of the simple one.
    interaction: Option<Vec<String>>,
    /// Enum-level: text placed between the outer and inner level names of an
    /// interaction factor, e.g. `#[r_factor(sep = "_")]` (the caller falls back to `"."`).
    sep: Option<String>,
    /// Variant-level: explicit level name, e.g. `#[r_factor(rename = "custom_name")]`.
    rename: Option<String>,
}
76
77/// Parse `#[r_factor(...)]` attributes from a list of `syn::Attribute`.
78///
79/// Extracts `rename`, `rename_all`, `interaction`, and `sep` keys.
80/// Returns `Err` for unknown attribute keys.
81fn parse_r_factor_attrs(attrs: &[syn::Attribute]) -> syn::Result<RFactorAttrs> {
82    let mut result = RFactorAttrs::default();
83
84    for attr in attrs {
85        if attr.path().is_ident("r_factor") {
86            attr.parse_nested_meta(|meta| {
87                if meta.path.is_ident("rename") {
88                    let value: syn::LitStr = meta.value()?.parse()?;
89                    result.rename = Some(value.value());
90                } else if meta.path.is_ident("rename_all") {
91                    let value: syn::LitStr = meta.value()?.parse()?;
92                    result.rename_all = Some(value.value());
93                } else if meta.path.is_ident("interaction") {
94                    // Parse as array: interaction = ["A", "B", "C"]
95                    let _eq: syn::Token![=] = meta.input.parse()?;
96                    let content;
97                    syn::bracketed!(content in meta.input);
98                    let levels: syn::punctuated::Punctuated<syn::LitStr, syn::Token![,]> =
99                        content.parse_terminated(|input| input.parse(), syn::Token![,])?;
100                    result.interaction = Some(levels.iter().map(|s| s.value()).collect());
101                } else if meta.path.is_ident("sep") {
102                    let value: syn::LitStr = meta.value()?.parse()?;
103                    result.sep = Some(value.value());
104                } else {
105                    return Err(meta.error("unknown r_factor attribute"));
106                }
107                Ok(())
108            })?;
109        }
110    }
111
112    Ok(result)
113}
114
115/// Main entry point for `#[derive(RFactor)]`.
116///
117/// Dispatches to either [`derive_simple_factor`] (C-style unit variants) or
118/// [`derive_interaction_factor`] (tuple variants wrapping an inner RFactor type),
119/// based on whether `#[r_factor(interaction = [...])]` is present.
120///
121/// Generates:
122/// - `impl MatchArg` (string choices for `match.arg`)
123/// - `impl RFactor` (1-based level index conversion)
124/// - `impl IntoR` (Rust enum -> R factor SEXP)
125/// - `impl TryFromSexp` (R factor SEXP -> Rust enum)
126///
127/// Returns `Err` for structs, unions, or invalid attribute combinations.
128pub fn derive_r_factor(input: DeriveInput) -> syn::Result<TokenStream> {
129    let name = &input.ident;
130    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
131
132    // Parse enum-level attributes
133    let attrs = parse_r_factor_attrs(&input.attrs)?;
134
135    // Get enum variants
136    let variants = match &input.data {
137        Data::Enum(data) => &data.variants,
138        Data::Struct(_) => {
139            return Err(syn::Error::new_spanned(
140                &input,
141                "#[derive(RFactor)] can only be applied to enums",
142            ));
143        }
144        Data::Union(_) => {
145            return Err(syn::Error::new_spanned(
146                &input,
147                "#[derive(RFactor)] can only be applied to enums",
148            ));
149        }
150    };
151
152    // Branch based on whether this is an interaction factor
153    if let Some(inner_levels) = &attrs.interaction {
154        derive_interaction_factor(
155            name,
156            &impl_generics,
157            &ty_generics,
158            where_clause,
159            variants,
160            inner_levels,
161            attrs.sep.as_deref().unwrap_or("."),
162            attrs.rename_all.as_deref(),
163        )
164    } else {
165        derive_simple_factor(
166            name,
167            &impl_generics,
168            &ty_generics,
169            where_clause,
170            variants,
171            attrs.rename_all.as_deref(),
172        )
173    }
174}
175
176/// Generate `RFactor`, `MatchArg`, `IntoR`, and `TryFromSexp` impls for simple
177/// (unit variant) enums.
178///
179/// Each variant maps to a 1-based level index and a level name string.
180/// Level names are determined by variant ident, optionally transformed by
181/// `rename_all` or overridden by per-variant `#[r_factor(rename = "...")]`.
182///
183/// Uses a `OnceLock`-cached levels SEXP for efficient repeated conversion.
184///
185/// Returns `Err` if any variant has fields (only C-style enums are supported
186/// in the simple path).
187fn derive_simple_factor(
188    name: &syn::Ident,
189    impl_generics: &syn::ImplGenerics,
190    ty_generics: &syn::TypeGenerics,
191    where_clause: Option<&syn::WhereClause>,
192    variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
193    rename_all: Option<&str>,
194) -> syn::Result<TokenStream> {
195    let mut level_names = Vec::new();
196    let mut variant_idents = Vec::new();
197
198    for variant in variants {
199        // Check for fields (only allow unit variants)
200        if !matches!(variant.fields, Fields::Unit) {
201            return Err(syn::Error::new_spanned(
202                variant,
203                "#[derive(RFactor)] only supports fieldless (C-style) enum variants \
204                 (use #[r_factor(interaction = [...])] for tuple variants)",
205            ));
206        }
207
208        // Parse variant-level attributes
209        let var_attrs = parse_r_factor_attrs(&variant.attrs)?;
210
211        // Determine level name
212        let level_name = if let Some(r) = var_attrs.rename {
213            r
214        } else {
215            apply_rename_all(&variant.ident.to_string(), rename_all)
216        };
217
218        level_names.push(level_name);
219        variant_idents.push(&variant.ident);
220    }
221
222    // Generate indices (1-based for R)
223    let indices: Vec<i32> = (1..=variant_idents.len() as i32).collect();
224    let level_name_strs: Vec<&str> = level_names.iter().map(|s| s.as_str()).collect();
225
226    Ok(quote! {
227        impl #impl_generics ::miniextendr_api::match_arg::MatchArg for #name #ty_generics #where_clause {
228            const CHOICES: &'static [&'static str] = &[#(#level_name_strs),*];
229
230            fn from_choice(choice: &str) -> Option<Self> {
231                match choice {
232                    #(#level_name_strs => Some(Self::#variant_idents),)*
233                    _ => None,
234                }
235            }
236
237            fn to_choice(self) -> &'static str {
238                match self {
239                    #(Self::#variant_idents => #level_name_strs,)*
240                }
241            }
242        }
243
244        impl #impl_generics ::miniextendr_api::RFactor for #name #ty_generics #where_clause {
245            fn to_level_index(self) -> i32 {
246                match self {
247                    #(Self::#variant_idents => #indices,)*
248                }
249            }
250
251            fn from_level_index(idx: i32) -> Option<Self> {
252                match idx {
253                    #(#indices => Some(Self::#variant_idents),)*
254                    _ => None,
255                }
256            }
257        }
258
259        impl #impl_generics ::miniextendr_api::IntoR for #name #ty_generics #where_clause {
260            type Error = std::convert::Infallible;
261
262            fn try_into_sexp(self) -> Result<::miniextendr_api::ffi::SEXP, Self::Error> {
263                Ok(self.into_sexp())
264            }
265
266            unsafe fn try_into_sexp_unchecked(self) -> Result<::miniextendr_api::ffi::SEXP, Self::Error> {
267                self.try_into_sexp()
268            }
269
270            fn into_sexp(self) -> ::miniextendr_api::ffi::SEXP {
271                static LEVELS_CACHE: ::std::sync::OnceLock<::miniextendr_api::ffi::SEXP> =
272                    ::std::sync::OnceLock::new();
273                let levels = *LEVELS_CACHE.get_or_init(|| {
274                    ::miniextendr_api::build_levels_sexp_cached(
275                        <Self as ::miniextendr_api::match_arg::MatchArg>::CHOICES
276                    )
277                });
278                ::miniextendr_api::build_factor(&[<Self as ::miniextendr_api::RFactor>::to_level_index(self)], levels)
279            }
280        }
281
282        impl #impl_generics ::miniextendr_api::TryFromSexp for #name #ty_generics #where_clause {
283            type Error = ::miniextendr_api::SexpError;
284
285            fn try_from_sexp(sexp: ::miniextendr_api::ffi::SEXP) -> Result<Self, Self::Error> {
286                ::miniextendr_api::factor_from_sexp(sexp)
287            }
288        }
289    })
290}
291
292/// Generate `RFactor`, `MatchArg`, `IntoR`, and `TryFromSexp` impls for interaction
293/// (tuple variant) enums.
294///
295/// Interaction factors combine an outer enum (the variant) with an inner `RFactor`
296/// type, producing combined level names like `"Outer.Inner"`. The level order is
297/// outer-varies-slowest (matches R's `interaction(..., lex.order = TRUE)`).
298///
299/// Generates a compile-time assertion that the specified `inner_levels` match the
300/// inner type's `MatchArg::CHOICES`, catching mismatches early.
301///
302/// All variants must be single-field tuples wrapping the same inner type.
303///
304/// # Arguments
305///
306/// * `inner_levels` - The expected level strings of the inner type
307///   (from `#[r_factor(interaction = [...])]`)
308/// * `sep` - Separator between outer and inner level names (default `"."`)
309/// * `rename_all` - Optional rename transformation for outer variant names
310#[allow(clippy::too_many_arguments)] // generics plumbing, single call site
311fn derive_interaction_factor(
312    name: &syn::Ident,
313    impl_generics: &syn::ImplGenerics,
314    ty_generics: &syn::TypeGenerics,
315    where_clause: Option<&syn::WhereClause>,
316    variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
317    inner_levels: &[String],
318    sep: &str,
319    rename_all: Option<&str>,
320) -> syn::Result<TokenStream> {
321    let mut outer_names = Vec::new();
322    let mut variant_idents = Vec::new();
323    let mut inner_type: Option<Type> = None;
324
325    for variant in variants {
326        // Require single-field tuple variants
327        let field_ty = match &variant.fields {
328            Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
329                fields.unnamed.first().unwrap().ty.clone()
330            }
331            _ => {
332                return Err(syn::Error::new_spanned(
333                    variant,
334                    "interaction factors require single-field tuple variants: Variant(InnerType)",
335                ));
336            }
337        };
338
339        // All variants must have the same inner type
340        if let Some(ref existing) = inner_type {
341            if field_ty != *existing {
342                return Err(syn::Error::new_spanned(
343                    &variant.fields,
344                    "all variants must have the same inner type",
345                ));
346            }
347        } else {
348            inner_type = Some(field_ty);
349        }
350
351        // Parse variant-level attributes for rename
352        let var_attrs = parse_r_factor_attrs(&variant.attrs)?;
353        let outer_name = if let Some(r) = var_attrs.rename {
354            r
355        } else {
356            apply_rename_all(&variant.ident.to_string(), rename_all)
357        };
358
359        outer_names.push(outer_name);
360        variant_idents.push(&variant.ident);
361    }
362
363    let inner_type = inner_type.ok_or_else(|| {
364        syn::Error::new_spanned(name, "interaction factor must have at least one variant")
365    })?;
366
367    let n_outer = outer_names.len();
368    let n_inner = inner_levels.len();
369
370    // Generate combined levels at compile time using concat!
371    // Order: outer varies slowest (lex.order style)
372    // [Outer1.Inner1, Outer1.Inner2, ..., Outer2.Inner1, ...]
373    let mut combined_levels = Vec::new();
374    for outer_name in &outer_names {
375        for inner_name in inner_levels {
376            // Use concat! for compile-time string concatenation
377            let combined = format!("{}{}{}", outer_name, sep, inner_name);
378            combined_levels.push(combined);
379        }
380    }
381    let combined_level_strs: Vec<&str> = combined_levels.iter().map(|s| s.as_str()).collect();
382
383    // Generate to_level_index match arms
384    // Index = outer_idx_0 * n_inner + inner_idx_0 + 1
385    let n_inner_lit = n_inner as i32;
386    let to_index_arms: Vec<_> = variant_idents
387        .iter()
388        .enumerate()
389        .map(|(outer_idx, var_ident)| {
390            let outer_idx_lit = outer_idx as i32;
391            quote! {
392                Self::#var_ident(inner) => {
393                    let inner_idx_0 = <#inner_type as ::miniextendr_api::RFactor>::to_level_index(inner) - 1;
394                    #outer_idx_lit * #n_inner_lit + inner_idx_0 + 1
395                }
396            }
397        })
398        .collect();
399
400    // Generate from_level_index match arms
401    // outer_idx_0 = (idx - 1) / n_inner
402    // inner_idx_1 = (idx - 1) % n_inner + 1
403    let from_index_arms: Vec<_> = (0..n_outer)
404        .map(|outer_idx| {
405            let var_ident = &variant_idents[outer_idx];
406            let start_idx = (outer_idx * n_inner + 1) as i32;
407            let end_idx = ((outer_idx + 1) * n_inner) as i32;
408            quote! {
409                #start_idx..=#end_idx => {
410                    let inner_idx_1 = (idx - 1) % #n_inner_lit + 1;
411                    <#inner_type as ::miniextendr_api::RFactor>::from_level_index(inner_idx_1)
412                        .map(Self::#var_ident)
413                }
414            }
415        })
416        .collect();
417
418    // Generate inner level strings for the const assertion
419    let inner_level_strs: Vec<&str> = inner_levels.iter().map(|s| s.as_str()).collect();
420
421    Ok(quote! {
422        // Compile-time assertion: verify specified inner levels match the actual inner type's CHOICES.
423        // This catches mismatches between the `interaction = [...]` attribute and the inner type.
424        const _: () = {
425            const ACTUAL: &[&str] = <#inner_type as ::miniextendr_api::match_arg::MatchArg>::CHOICES;
426            const EXPECTED: &[&str] = &[#(#inner_level_strs),*];
427
428            // Check level count
429            assert!(
430                ACTUAL.len() == EXPECTED.len(),
431                "interaction factor: inner type level count mismatch"
432            );
433
434            // Check each level matches (const string comparison)
435            let mut i = 0;
436            while i < ACTUAL.len() {
437                let actual_bytes = ACTUAL[i].as_bytes();
438                let expected_bytes = EXPECTED[i].as_bytes();
439                assert!(
440                    actual_bytes.len() == expected_bytes.len(),
441                    "interaction factor: inner type level string length mismatch"
442                );
443                let mut j = 0;
444                while j < actual_bytes.len() {
445                    assert!(
446                        actual_bytes[j] == expected_bytes[j],
447                        "interaction factor: inner type level string content mismatch"
448                    );
449                    j += 1;
450                }
451                i += 1;
452            }
453        };
454
455        impl #impl_generics ::miniextendr_api::match_arg::MatchArg for #name #ty_generics #where_clause {
456            const CHOICES: &'static [&'static str] = &[#(#combined_level_strs),*];
457
458            fn from_choice(choice: &str) -> Option<Self> {
459                let idx_1 = Self::CHOICES.iter().position(|&l| l == choice).map(|i| i as i32 + 1)?;
460                <Self as ::miniextendr_api::RFactor>::from_level_index(idx_1)
461            }
462
463            fn to_choice(self) -> &'static str {
464                Self::CHOICES[(<Self as ::miniextendr_api::RFactor>::to_level_index(self) - 1) as usize]
465            }
466        }
467
468        impl #impl_generics ::miniextendr_api::RFactor for #name #ty_generics #where_clause {
469            fn to_level_index(self) -> i32 {
470                match self {
471                    #(#to_index_arms)*
472                }
473            }
474
475            fn from_level_index(idx: i32) -> Option<Self> {
476                match idx {
477                    #(#from_index_arms)*
478                    _ => None,
479                }
480            }
481        }
482
483        impl #impl_generics ::miniextendr_api::IntoR for #name #ty_generics #where_clause {
484            type Error = std::convert::Infallible;
485
486            fn try_into_sexp(self) -> Result<::miniextendr_api::ffi::SEXP, Self::Error> {
487                Ok(self.into_sexp())
488            }
489
490            unsafe fn try_into_sexp_unchecked(self) -> Result<::miniextendr_api::ffi::SEXP, Self::Error> {
491                self.try_into_sexp()
492            }
493
494            fn into_sexp(self) -> ::miniextendr_api::ffi::SEXP {
495                static LEVELS_CACHE: ::std::sync::OnceLock<::miniextendr_api::ffi::SEXP> =
496                    ::std::sync::OnceLock::new();
497                let levels = *LEVELS_CACHE.get_or_init(|| {
498                    ::miniextendr_api::build_levels_sexp_cached(
499                        <Self as ::miniextendr_api::match_arg::MatchArg>::CHOICES
500                    )
501                });
502                ::miniextendr_api::build_factor(&[<Self as ::miniextendr_api::RFactor>::to_level_index(self)], levels)
503            }
504        }
505
506        impl #impl_generics ::miniextendr_api::TryFromSexp for #name #ty_generics #where_clause {
507            type Error = ::miniextendr_api::SexpError;
508
509            fn try_from_sexp(sexp: ::miniextendr_api::ffi::SEXP) -> Result<Self, Self::Error> {
510                ::miniextendr_api::factor_from_sexp(sexp)
511            }
512        }
513    })
514}
515
516#[cfg(test)]
517mod tests {
518    use super::*;
519
520    #[test]
521    fn test_interaction_levels_generation() {
522        let input: DeriveInput = syn::parse_quote! {
523            #[r_factor(interaction = ["Small", "Large"])]
524            enum ColorSize {
525                Red(Size),
526                Green(Size),
527                Blue(Size),
528            }
529        };
530
531        let result = derive_r_factor(input).unwrap();
532        let code = result.to_string();
533
534        // Check that combined levels are generated
535        assert!(code.contains("Red.Small"));
536        assert!(code.contains("Red.Large"));
537        assert!(code.contains("Green.Small"));
538        assert!(code.contains("Green.Large"));
539        assert!(code.contains("Blue.Small"));
540        assert!(code.contains("Blue.Large"));
541
542        // Check that const assertion is generated for level validation
543        assert!(code.contains("const _ : () ="));
544        assert!(code.contains("ACTUAL"));
545        assert!(code.contains("EXPECTED"));
546        assert!(code.contains("inner type level count mismatch"));
547    }
548
549    #[test]
550    fn test_interaction_custom_separator() {
551        let input: DeriveInput = syn::parse_quote! {
552            #[r_factor(interaction = ["X", "Y"], sep = "_")]
553            enum AB {
554                A(Inner),
555                B(Inner),
556            }
557        };
558
559        let result = derive_r_factor(input).unwrap();
560        let code = result.to_string();
561
562        // Check custom separator
563        assert!(code.contains("A_X"));
564        assert!(code.contains("A_Y"));
565        assert!(code.contains("B_X"));
566        assert!(code.contains("B_Y"));
567    }
568
569    #[test]
570    fn test_interaction_with_rename() {
571        let input: DeriveInput = syn::parse_quote! {
572            #[r_factor(interaction = ["S", "L"], rename_all = "lower")]
573            enum ColorSize {
574                Red(Size),
575                Green(Size),
576            }
577        };
578
579        let result = derive_r_factor(input).unwrap();
580        let code = result.to_string();
581
582        // Check renamed outer levels
583        assert!(code.contains("red.S"));
584        assert!(code.contains("red.L"));
585        assert!(code.contains("green.S"));
586        assert!(code.contains("green.L"));
587    }
588}