Skip to main content

miniextendr_macros/
match_arg_derive.rs

1//! # `#[derive(MatchArg)]` - Enum ↔ R String with `match.arg` Support
2//!
3//! This module implements the `#[derive(MatchArg)]` macro which generates
4//! the `MatchArg` trait implementation for C-style enums, enabling automatic
5//! conversion between Rust enums and R character strings with partial matching.
6//!
7//! ## Usage
8//!
9//! ```ignore
10//! #[derive(Copy, Clone, MatchArg)]
11//! enum Mode {
12//!     Fast,
13//!     Safe,
14//!     Debug,
15//! }
16//!
17//! // Generates impl MatchArg for Mode, TryFromSexp for Mode, IntoR for Mode.
18//! ```
19//!
20//! ## Attributes
21//!
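//! - `#[match_arg(rename = "name")]` (on a variant) - Use `name` as that variant's choice string
//! - `#[match_arg(rename_all = "snake_case")]` (on the enum) - Transform the choice string of every
//!   variant that has no explicit `rename`; supported modes are `snake_case`, `kebab-case`,
//!   `lower`, and `upper`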
24
25use proc_macro2::TokenStream;
26use quote::quote;
27use syn::{Data, DeriveInput, Fields};
28
29use crate::naming::apply_rename_all;
30
/// Settings collected from `#[match_arg(...)]` on either the enum itself or one
/// of its variants.
///
/// Each field stays `None` unless the corresponding key was supplied.
#[derive(Default)]
struct MatchArgAttrs {
    /// Enum-wide case transformation from `#[match_arg(rename_all = "...")]`.
    /// Only consulted for variants that carry no explicit `rename`.
    rename_all: Option<String>,
    /// Explicit choice string for one variant from `#[match_arg(rename = "...")]`.
    rename: Option<String>,
}
40
41/// Parse `#[match_arg(...)]` attributes from a list of `syn::Attribute`.
42///
43/// Extracts `rename` and `rename_all` keys. Validates that `rename_all` uses
44/// one of the supported modes: `snake_case`, `kebab-case`, `lower`, `upper`.
45/// Returns `Err` for unknown attribute keys or unsupported `rename_all` values.
46fn parse_match_arg_attrs(attrs: &[syn::Attribute]) -> syn::Result<MatchArgAttrs> {
47    let mut result = MatchArgAttrs::default();
48
49    for attr in attrs {
50        if attr.path().is_ident("match_arg") {
51            attr.parse_nested_meta(|meta| {
52                if meta.path.is_ident("rename") {
53                    let value: syn::LitStr = meta.value()?.parse()?;
54                    result.rename = Some(value.value());
55                } else if meta.path.is_ident("rename_all") {
56                    let value: syn::LitStr = meta.value()?.parse()?;
57                    let val = value.value();
58                    match val.as_str() {
59                        "snake_case" | "kebab-case" | "lower" | "upper" => {}
60                        _ => {
61                            return Err(meta.error(
62                                "unsupported rename_all value; expected one of: \
63                                 snake_case, kebab-case, lower, upper",
64                            ));
65                        }
66                    }
67                    result.rename_all = Some(val);
68                } else {
69                    return Err(meta
70                        .error("unknown match_arg attribute; expected `rename` or `rename_all`"));
71                }
72                Ok(())
73            })?;
74        }
75    }
76
77    Ok(result)
78}
79
80/// Main entry point for `#[derive(MatchArg)]`.
81///
82/// Generates three trait implementations:
83/// - `impl MatchArg` -- provides `CHOICES` (static string slice), `from_choice`, `to_choice`
84/// - `impl TryFromSexp` -- converts R character scalar to enum variant via `match_arg_from_sexp`
85/// - `impl IntoR` -- converts enum variant to R character scalar via `to_choice().into_sexp()`
86///
87/// `impl IntoR for Vec<Self>` is provided automatically by the blanket
88/// `impl<T: MatchArg> IntoR for Vec<T>` in `miniextendr-api::match_arg`,
89/// so returning `Vec<EnumName>` from a `#[miniextendr]` function works
90/// without any extra code in the user's crate.
91///
92/// Validates:
93/// - Only enums are accepted (not structs or unions)
94/// - Generic enums are rejected
95/// - At least one variant is required
96/// - Only fieldless (C-style) variants are allowed
97/// - No duplicate choice names after renaming
98///
99/// Choice names default to variant identifiers, optionally transformed by
100/// `#[match_arg(rename_all = "...")]` or overridden per-variant with
101/// `#[match_arg(rename = "...")]`.
102pub fn derive_match_arg(input: DeriveInput) -> syn::Result<TokenStream> {
103    let name = &input.ident;
104    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
105
106    // Reject generics for v1
107    if !input.generics.params.is_empty() {
108        return Err(syn::Error::new_spanned(
109            &input.generics,
110            "#[derive(MatchArg)] does not support generic enums",
111        ));
112    }
113
114    // Parse enum-level attributes
115    let attrs = parse_match_arg_attrs(&input.attrs)?;
116
117    // Get enum variants
118    let variants = match &input.data {
119        Data::Enum(data) => &data.variants,
120        Data::Struct(_) => {
121            return Err(syn::Error::new_spanned(
122                &input,
123                "#[derive(MatchArg)] can only be applied to enums",
124            ));
125        }
126        Data::Union(_) => {
127            return Err(syn::Error::new_spanned(
128                &input,
129                "#[derive(MatchArg)] can only be applied to enums",
130            ));
131        }
132    };
133
134    if variants.is_empty() {
135        return Err(syn::Error::new_spanned(
136            &input,
137            "#[derive(MatchArg)] requires at least one variant",
138        ));
139    }
140
141    let mut choice_names = Vec::new();
142    let mut variant_idents = Vec::new();
143
144    for variant in variants {
145        // Only allow unit variants (fieldless)
146        if !matches!(variant.fields, Fields::Unit) {
147            return Err(syn::Error::new_spanned(
148                variant,
149                "#[derive(MatchArg)] only supports fieldless (C-style) enum variants",
150            ));
151        }
152
153        // Parse variant-level attributes
154        let var_attrs = parse_match_arg_attrs(&variant.attrs)?;
155
156        // Determine choice name
157        let choice_name = if let Some(r) = var_attrs.rename {
158            r
159        } else {
160            apply_rename_all(&variant.ident.to_string(), attrs.rename_all.as_deref())
161        };
162
163        choice_names.push(choice_name);
164        variant_idents.push(&variant.ident);
165    }
166
167    // Check for duplicate choice names
168    {
169        let mut seen = std::collections::HashSet::new();
170        for (i, name) in choice_names.iter().enumerate() {
171            if !seen.insert(name.as_str()) {
172                return Err(syn::Error::new_spanned(
173                    &variants.iter().nth(i).unwrap().ident,
174                    format!("duplicate choice name {:?} in #[derive(MatchArg)]", name),
175                ));
176            }
177        }
178    }
179
180    let choice_strs: Vec<&str> = choice_names.iter().map(|s| s.as_str()).collect();
181
182    Ok(quote! {
183        impl #impl_generics ::miniextendr_api::match_arg::MatchArg for #name #ty_generics #where_clause {
184            const CHOICES: &'static [&'static str] = &[#(#choice_strs),*];
185
186            fn from_choice(choice: &str) -> Option<Self> {
187                match choice {
188                    #(#choice_strs => Some(Self::#variant_idents),)*
189                    _ => None,
190                }
191            }
192
193            fn to_choice(self) -> &'static str {
194                match self {
195                    #(Self::#variant_idents => #choice_strs,)*
196                }
197            }
198        }
199
200        impl #impl_generics ::miniextendr_api::TryFromSexp for #name #ty_generics #where_clause {
201            type Error = ::miniextendr_api::SexpError;
202
203            fn try_from_sexp(sexp: ::miniextendr_api::ffi::SEXP) -> Result<Self, Self::Error> {
204                ::miniextendr_api::match_arg_from_sexp(sexp).map_err(Into::into)
205            }
206        }
207
208        impl #impl_generics ::miniextendr_api::IntoR for #name #ty_generics #where_clause {
209            type Error = std::convert::Infallible;
210
211            fn try_into_sexp(self) -> Result<::miniextendr_api::ffi::SEXP, Self::Error> {
212                Ok(self.into_sexp())
213            }
214
215            unsafe fn try_into_sexp_unchecked(self) -> Result<::miniextendr_api::ffi::SEXP, Self::Error> {
216                self.try_into_sexp()
217            }
218
219            fn into_sexp(self) -> ::miniextendr_api::ffi::SEXP {
220                use ::miniextendr_api::match_arg::MatchArg;
221                self.to_choice().into_sexp()
222            }
223        }
224
225
226    })
227}
228
#[cfg(test)]
mod tests {
    //! Expansion and validation tests for `derive_match_arg`.
    //!
    //! Generated code is inspected via `TokenStream::to_string()`, which
    //! separates tokens with spaces (e.g. `:: miniextendr_api :: IntoR`).

    use super::*;

    #[test]
    fn test_simple_derive() {
        let input: DeriveInput = syn::parse_quote! {
            enum Mode {
                Fast,
                Safe,
                Debug,
            }
        };

        let result = derive_match_arg(input).unwrap();
        let code = result.to_string();
        assert!(code.contains("Fast"));
        assert!(code.contains("Safe"));
        assert!(code.contains("Debug"));
        assert!(code.contains("CHOICES"));
        assert!(code.contains("from_choice"));
        assert!(code.contains("to_choice"));
    }

    #[test]
    fn test_rename_all() {
        let input: DeriveInput = syn::parse_quote! {
            #[match_arg(rename_all = "snake_case")]
            enum Mode {
                FastMode,
                SafeMode,
            }
        };

        let result = derive_match_arg(input).unwrap();
        let code = result.to_string();
        assert!(code.contains("fast_mode"));
        assert!(code.contains("safe_mode"));
    }

    #[test]
    fn test_rename_variant() {
        let input: DeriveInput = syn::parse_quote! {
            enum Priority {
                #[match_arg(rename = "lo")]
                Low,
                #[match_arg(rename = "hi")]
                High,
            }
        };

        let result = derive_match_arg(input).unwrap();
        let code = result.to_string();
        assert!(code.contains("\"lo\""));
        assert!(code.contains("\"hi\""));
    }

    #[test]
    fn test_rename_overrides_rename_all() {
        // A per-variant `rename` takes precedence over the enum-level `rename_all`.
        let input: DeriveInput = syn::parse_quote! {
            #[match_arg(rename_all = "snake_case")]
            enum Mode {
                FastMode,
                #[match_arg(rename = "careful")]
                SafeMode,
            }
        };

        let code = derive_match_arg(input).unwrap().to_string();
        assert!(code.contains("\"fast_mode\""));
        assert!(code.contains("\"careful\""));
        assert!(!code.contains("safe_mode"));
    }

    #[test]
    fn test_reject_fields() {
        let input: DeriveInput = syn::parse_quote! {
            enum Bad {
                A(i32),
            }
        };

        let err = derive_match_arg(input).unwrap_err();
        assert!(err.to_string().contains("fieldless"));
    }

    #[test]
    fn test_reject_struct() {
        let input: DeriveInput = syn::parse_quote! {
            struct Bad;
        };

        let err = derive_match_arg(input).unwrap_err();
        assert!(err.to_string().contains("can only be applied to enums"));
    }

    #[test]
    fn test_reject_union() {
        let input: DeriveInput = syn::parse_quote! {
            union Bad {
                a: u32,
            }
        };

        let err = derive_match_arg(input).unwrap_err();
        assert!(err.to_string().contains("can only be applied to enums"));
    }

    #[test]
    fn test_reject_generics() {
        let input: DeriveInput = syn::parse_quote! {
            enum Generic<T> {
                A,
            }
        };

        let err = derive_match_arg(input).unwrap_err();
        assert!(err.to_string().contains("does not support generic enums"));
    }

    #[test]
    fn test_reject_empty() {
        let input: DeriveInput = syn::parse_quote! {
            enum Empty {}
        };

        let err = derive_match_arg(input).unwrap_err();
        assert!(err.to_string().contains("at least one variant"));
    }

    #[test]
    fn test_reject_unknown_attribute() {
        let input: DeriveInput = syn::parse_quote! {
            #[match_arg(bogus = "x")]
            enum Mode {
                A,
            }
        };

        let err = derive_match_arg(input).unwrap_err();
        assert!(err.to_string().contains("unknown match_arg attribute"));
    }

    #[test]
    fn test_reject_unsupported_rename_all() {
        let input: DeriveInput = syn::parse_quote! {
            #[match_arg(rename_all = "SCREAMING_SNAKE_CASE")]
            enum Mode {
                A,
            }
        };

        let err = derive_match_arg(input).unwrap_err();
        assert!(err.to_string().contains("unsupported rename_all value"));
    }

    #[test]
    fn test_into_r_impl_present() {
        // The derive emits IntoR for the scalar EnumName.
        // Vec<EnumName> IntoR is covered by the blanket impl<T: MatchArg> IntoR for Vec<T>
        // in miniextendr-api — it is NOT emitted by the derive.
        let input: DeriveInput = syn::parse_quote! {
            enum Mode {
                Fast,
                Safe,
                Debug,
            }
        };

        let result = derive_match_arg(input).unwrap();
        let code = result.to_string();
        // Scalar IntoR impl must be present
        assert!(code.contains("IntoR for Mode"));
        // Vec<Mode> IntoR must NOT be emitted by the derive (covered by blanket in miniextendr-api)
        assert!(!code.contains("IntoR for :: std :: vec :: Vec < Mode >"));
        assert!(!code.contains("match_arg_vec_into_sexp"));
    }

    #[test]
    fn test_duplicate_choice_names() {
        let input: DeriveInput = syn::parse_quote! {
            enum Dup {
                #[match_arg(rename = "same")]
                A,
                #[match_arg(rename = "same")]
                B,
            }
        };

        let err = derive_match_arg(input).unwrap_err();
        assert!(err.to_string().contains("duplicate choice name"));
    }
}