1use proc_macro2::TokenStream;
55use quote::quote;
56use syn::{Data, DeriveInput, Fields, Type};
57
58use crate::naming::apply_rename_all;
59
/// Options gathered from `#[r_factor(...)]` attributes, either on the enum
/// itself or on one of its variants.
///
/// Every field stays `None` when its key was not written.
#[derive(Default)]
struct RFactorAttrs {
    /// `rename = "..."` — explicit level (or outer) name for one variant.
    rename: Option<String>,
    /// `rename_all = "..."` — naming rule forwarded to `apply_rename_all` for
    /// variants that have no explicit `rename`.
    rename_all: Option<String>,
    /// `interaction = ["...", ...]` — inner level names; its presence switches
    /// the derive to interaction-factor mode (single-field tuple variants).
    interaction: Option<Vec<String>>,
    /// `sep = "..."` — text placed between the outer and inner names of an
    /// interaction level; the caller falls back to `"."` when this is `None`.
    sep: Option<String>,
}
76
77fn parse_r_factor_attrs(attrs: &[syn::Attribute]) -> syn::Result<RFactorAttrs> {
82 let mut result = RFactorAttrs::default();
83
84 for attr in attrs {
85 if attr.path().is_ident("r_factor") {
86 attr.parse_nested_meta(|meta| {
87 if meta.path.is_ident("rename") {
88 let value: syn::LitStr = meta.value()?.parse()?;
89 result.rename = Some(value.value());
90 } else if meta.path.is_ident("rename_all") {
91 let value: syn::LitStr = meta.value()?.parse()?;
92 result.rename_all = Some(value.value());
93 } else if meta.path.is_ident("interaction") {
94 let _eq: syn::Token![=] = meta.input.parse()?;
96 let content;
97 syn::bracketed!(content in meta.input);
98 let levels: syn::punctuated::Punctuated<syn::LitStr, syn::Token![,]> =
99 content.parse_terminated(|input| input.parse(), syn::Token![,])?;
100 result.interaction = Some(levels.iter().map(|s| s.value()).collect());
101 } else if meta.path.is_ident("sep") {
102 let value: syn::LitStr = meta.value()?.parse()?;
103 result.sep = Some(value.value());
104 } else {
105 return Err(meta.error("unknown r_factor attribute"));
106 }
107 Ok(())
108 })?;
109 }
110 }
111
112 Ok(result)
113}
114
115pub fn derive_r_factor(input: DeriveInput) -> syn::Result<TokenStream> {
129 let name = &input.ident;
130 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
131
132 let attrs = parse_r_factor_attrs(&input.attrs)?;
134
135 let variants = match &input.data {
137 Data::Enum(data) => &data.variants,
138 Data::Struct(_) => {
139 return Err(syn::Error::new_spanned(
140 &input,
141 "#[derive(RFactor)] can only be applied to enums",
142 ));
143 }
144 Data::Union(_) => {
145 return Err(syn::Error::new_spanned(
146 &input,
147 "#[derive(RFactor)] can only be applied to enums",
148 ));
149 }
150 };
151
152 if let Some(inner_levels) = &attrs.interaction {
154 derive_interaction_factor(
155 name,
156 &impl_generics,
157 &ty_generics,
158 where_clause,
159 variants,
160 inner_levels,
161 attrs.sep.as_deref().unwrap_or("."),
162 attrs.rename_all.as_deref(),
163 )
164 } else {
165 derive_simple_factor(
166 name,
167 &impl_generics,
168 &ty_generics,
169 where_clause,
170 variants,
171 attrs.rename_all.as_deref(),
172 )
173 }
174}
175
176fn derive_simple_factor(
188 name: &syn::Ident,
189 impl_generics: &syn::ImplGenerics,
190 ty_generics: &syn::TypeGenerics,
191 where_clause: Option<&syn::WhereClause>,
192 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
193 rename_all: Option<&str>,
194) -> syn::Result<TokenStream> {
195 let mut level_names = Vec::new();
196 let mut variant_idents = Vec::new();
197
198 for variant in variants {
199 if !matches!(variant.fields, Fields::Unit) {
201 return Err(syn::Error::new_spanned(
202 variant,
203 "#[derive(RFactor)] only supports fieldless (C-style) enum variants \
204 (use #[r_factor(interaction = [...])] for tuple variants)",
205 ));
206 }
207
208 let var_attrs = parse_r_factor_attrs(&variant.attrs)?;
210
211 let level_name = if let Some(r) = var_attrs.rename {
213 r
214 } else {
215 apply_rename_all(&variant.ident.to_string(), rename_all)
216 };
217
218 level_names.push(level_name);
219 variant_idents.push(&variant.ident);
220 }
221
222 let indices: Vec<i32> = (1..=variant_idents.len() as i32).collect();
224 let level_name_strs: Vec<&str> = level_names.iter().map(|s| s.as_str()).collect();
225
226 Ok(quote! {
227 impl #impl_generics ::miniextendr_api::match_arg::MatchArg for #name #ty_generics #where_clause {
228 const CHOICES: &'static [&'static str] = &[#(#level_name_strs),*];
229
230 fn from_choice(choice: &str) -> Option<Self> {
231 match choice {
232 #(#level_name_strs => Some(Self::#variant_idents),)*
233 _ => None,
234 }
235 }
236
237 fn to_choice(self) -> &'static str {
238 match self {
239 #(Self::#variant_idents => #level_name_strs,)*
240 }
241 }
242 }
243
244 impl #impl_generics ::miniextendr_api::RFactor for #name #ty_generics #where_clause {
245 fn to_level_index(self) -> i32 {
246 match self {
247 #(Self::#variant_idents => #indices,)*
248 }
249 }
250
251 fn from_level_index(idx: i32) -> Option<Self> {
252 match idx {
253 #(#indices => Some(Self::#variant_idents),)*
254 _ => None,
255 }
256 }
257 }
258
259 impl #impl_generics ::miniextendr_api::IntoR for #name #ty_generics #where_clause {
260 type Error = std::convert::Infallible;
261
262 fn try_into_sexp(self) -> Result<::miniextendr_api::ffi::SEXP, Self::Error> {
263 Ok(self.into_sexp())
264 }
265
266 unsafe fn try_into_sexp_unchecked(self) -> Result<::miniextendr_api::ffi::SEXP, Self::Error> {
267 self.try_into_sexp()
268 }
269
270 fn into_sexp(self) -> ::miniextendr_api::ffi::SEXP {
271 static LEVELS_CACHE: ::std::sync::OnceLock<::miniextendr_api::ffi::SEXP> =
272 ::std::sync::OnceLock::new();
273 let levels = *LEVELS_CACHE.get_or_init(|| {
274 ::miniextendr_api::build_levels_sexp_cached(
275 <Self as ::miniextendr_api::match_arg::MatchArg>::CHOICES
276 )
277 });
278 ::miniextendr_api::build_factor(&[<Self as ::miniextendr_api::RFactor>::to_level_index(self)], levels)
279 }
280 }
281
282 impl #impl_generics ::miniextendr_api::TryFromSexp for #name #ty_generics #where_clause {
283 type Error = ::miniextendr_api::SexpError;
284
285 fn try_from_sexp(sexp: ::miniextendr_api::ffi::SEXP) -> Result<Self, Self::Error> {
286 ::miniextendr_api::factor_from_sexp(sexp)
287 }
288 }
289 })
290}
291
292#[allow(clippy::too_many_arguments)] fn derive_interaction_factor(
312 name: &syn::Ident,
313 impl_generics: &syn::ImplGenerics,
314 ty_generics: &syn::TypeGenerics,
315 where_clause: Option<&syn::WhereClause>,
316 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
317 inner_levels: &[String],
318 sep: &str,
319 rename_all: Option<&str>,
320) -> syn::Result<TokenStream> {
321 let mut outer_names = Vec::new();
322 let mut variant_idents = Vec::new();
323 let mut inner_type: Option<Type> = None;
324
325 for variant in variants {
326 let field_ty = match &variant.fields {
328 Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
329 fields.unnamed.first().unwrap().ty.clone()
330 }
331 _ => {
332 return Err(syn::Error::new_spanned(
333 variant,
334 "interaction factors require single-field tuple variants: Variant(InnerType)",
335 ));
336 }
337 };
338
339 if let Some(ref existing) = inner_type {
341 if field_ty != *existing {
342 return Err(syn::Error::new_spanned(
343 &variant.fields,
344 "all variants must have the same inner type",
345 ));
346 }
347 } else {
348 inner_type = Some(field_ty);
349 }
350
351 let var_attrs = parse_r_factor_attrs(&variant.attrs)?;
353 let outer_name = if let Some(r) = var_attrs.rename {
354 r
355 } else {
356 apply_rename_all(&variant.ident.to_string(), rename_all)
357 };
358
359 outer_names.push(outer_name);
360 variant_idents.push(&variant.ident);
361 }
362
363 let inner_type = inner_type.ok_or_else(|| {
364 syn::Error::new_spanned(name, "interaction factor must have at least one variant")
365 })?;
366
367 let n_outer = outer_names.len();
368 let n_inner = inner_levels.len();
369
370 let mut combined_levels = Vec::new();
374 for outer_name in &outer_names {
375 for inner_name in inner_levels {
376 let combined = format!("{}{}{}", outer_name, sep, inner_name);
378 combined_levels.push(combined);
379 }
380 }
381 let combined_level_strs: Vec<&str> = combined_levels.iter().map(|s| s.as_str()).collect();
382
383 let n_inner_lit = n_inner as i32;
386 let to_index_arms: Vec<_> = variant_idents
387 .iter()
388 .enumerate()
389 .map(|(outer_idx, var_ident)| {
390 let outer_idx_lit = outer_idx as i32;
391 quote! {
392 Self::#var_ident(inner) => {
393 let inner_idx_0 = <#inner_type as ::miniextendr_api::RFactor>::to_level_index(inner) - 1;
394 #outer_idx_lit * #n_inner_lit + inner_idx_0 + 1
395 }
396 }
397 })
398 .collect();
399
400 let from_index_arms: Vec<_> = (0..n_outer)
404 .map(|outer_idx| {
405 let var_ident = &variant_idents[outer_idx];
406 let start_idx = (outer_idx * n_inner + 1) as i32;
407 let end_idx = ((outer_idx + 1) * n_inner) as i32;
408 quote! {
409 #start_idx..=#end_idx => {
410 let inner_idx_1 = (idx - 1) % #n_inner_lit + 1;
411 <#inner_type as ::miniextendr_api::RFactor>::from_level_index(inner_idx_1)
412 .map(Self::#var_ident)
413 }
414 }
415 })
416 .collect();
417
418 let inner_level_strs: Vec<&str> = inner_levels.iter().map(|s| s.as_str()).collect();
420
421 Ok(quote! {
422 const _: () = {
425 const ACTUAL: &[&str] = <#inner_type as ::miniextendr_api::match_arg::MatchArg>::CHOICES;
426 const EXPECTED: &[&str] = &[#(#inner_level_strs),*];
427
428 assert!(
430 ACTUAL.len() == EXPECTED.len(),
431 "interaction factor: inner type level count mismatch"
432 );
433
434 let mut i = 0;
436 while i < ACTUAL.len() {
437 let actual_bytes = ACTUAL[i].as_bytes();
438 let expected_bytes = EXPECTED[i].as_bytes();
439 assert!(
440 actual_bytes.len() == expected_bytes.len(),
441 "interaction factor: inner type level string length mismatch"
442 );
443 let mut j = 0;
444 while j < actual_bytes.len() {
445 assert!(
446 actual_bytes[j] == expected_bytes[j],
447 "interaction factor: inner type level string content mismatch"
448 );
449 j += 1;
450 }
451 i += 1;
452 }
453 };
454
455 impl #impl_generics ::miniextendr_api::match_arg::MatchArg for #name #ty_generics #where_clause {
456 const CHOICES: &'static [&'static str] = &[#(#combined_level_strs),*];
457
458 fn from_choice(choice: &str) -> Option<Self> {
459 let idx_1 = Self::CHOICES.iter().position(|&l| l == choice).map(|i| i as i32 + 1)?;
460 <Self as ::miniextendr_api::RFactor>::from_level_index(idx_1)
461 }
462
463 fn to_choice(self) -> &'static str {
464 Self::CHOICES[(<Self as ::miniextendr_api::RFactor>::to_level_index(self) - 1) as usize]
465 }
466 }
467
468 impl #impl_generics ::miniextendr_api::RFactor for #name #ty_generics #where_clause {
469 fn to_level_index(self) -> i32 {
470 match self {
471 #(#to_index_arms)*
472 }
473 }
474
475 fn from_level_index(idx: i32) -> Option<Self> {
476 match idx {
477 #(#from_index_arms)*
478 _ => None,
479 }
480 }
481 }
482
483 impl #impl_generics ::miniextendr_api::IntoR for #name #ty_generics #where_clause {
484 type Error = std::convert::Infallible;
485
486 fn try_into_sexp(self) -> Result<::miniextendr_api::ffi::SEXP, Self::Error> {
487 Ok(self.into_sexp())
488 }
489
490 unsafe fn try_into_sexp_unchecked(self) -> Result<::miniextendr_api::ffi::SEXP, Self::Error> {
491 self.try_into_sexp()
492 }
493
494 fn into_sexp(self) -> ::miniextendr_api::ffi::SEXP {
495 static LEVELS_CACHE: ::std::sync::OnceLock<::miniextendr_api::ffi::SEXP> =
496 ::std::sync::OnceLock::new();
497 let levels = *LEVELS_CACHE.get_or_init(|| {
498 ::miniextendr_api::build_levels_sexp_cached(
499 <Self as ::miniextendr_api::match_arg::MatchArg>::CHOICES
500 )
501 });
502 ::miniextendr_api::build_factor(&[<Self as ::miniextendr_api::RFactor>::to_level_index(self)], levels)
503 }
504 }
505
506 impl #impl_generics ::miniextendr_api::TryFromSexp for #name #ty_generics #where_clause {
507 type Error = ::miniextendr_api::SexpError;
508
509 fn try_from_sexp(sexp: ::miniextendr_api::ffi::SEXP) -> Result<Self, Self::Error> {
510 ::miniextendr_api::factor_from_sexp(sexp)
511 }
512 }
513 })
514}
515
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the derive on `input` (panicking on error) and renders the tokens.
    fn expand(input: DeriveInput) -> String {
        derive_r_factor(input).unwrap().to_string()
    }

    /// Asserts that every needle occurs somewhere in the generated code.
    fn assert_contains_all(code: &str, needles: &[&str]) {
        for &needle in needles {
            assert!(
                code.contains(needle),
                "generated code does not contain {:?}",
                needle
            );
        }
    }

    #[test]
    fn test_interaction_levels_generation() {
        let code = expand(syn::parse_quote! {
            #[r_factor(interaction = ["Small", "Large"])]
            enum ColorSize {
                Red(Size),
                Green(Size),
                Blue(Size),
            }
        });

        // Outer-major product of variant names and inner levels, default sep.
        assert_contains_all(
            &code,
            &[
                "Red.Small",
                "Red.Large",
                "Green.Small",
                "Green.Large",
                "Blue.Small",
                "Blue.Large",
            ],
        );

        // The compile-time inner-level consistency check is emitted.
        assert_contains_all(
            &code,
            &[
                "const _ : () =",
                "ACTUAL",
                "EXPECTED",
                "inner type level count mismatch",
            ],
        );
    }

    #[test]
    fn test_interaction_custom_separator() {
        let code = expand(syn::parse_quote! {
            #[r_factor(interaction = ["X", "Y"], sep = "_")]
            enum AB {
                A(Inner),
                B(Inner),
            }
        });

        assert_contains_all(&code, &["A_X", "A_Y", "B_X", "B_Y"]);
    }

    #[test]
    fn test_interaction_with_rename() {
        let code = expand(syn::parse_quote! {
            #[r_factor(interaction = ["S", "L"], rename_all = "lower")]
            enum ColorSize {
                Red(Size),
                Green(Size),
            }
        });

        assert_contains_all(&code, &["red.S", "red.L", "green.S", "green.L"]);
    }
}