miniextendr_macros/
match_arg_derive.rs1use proc_macro2::TokenStream;
26use quote::quote;
27use syn::{Data, DeriveInput, Fields};
28
29use crate::naming::apply_rename_all;
30
/// Options gathered from `#[match_arg(...)]` attributes on a single item
/// (the enum itself or one of its variants).
///
/// Both fields start out as `None` (via `Default`) and are filled in by
/// `parse_match_arg_attrs`.
#[derive(Default)]
struct MatchArgAttrs {
    /// `rename_all = "..."`: case conversion applied to variant identifiers that
    /// have no explicit `rename`.
    rename_all: Option<String>,
    /// `rename = "..."`: explicit choice string that replaces the name derived
    /// from the variant identifier.
    rename: Option<String>,
}
40
41fn parse_match_arg_attrs(attrs: &[syn::Attribute]) -> syn::Result<MatchArgAttrs> {
47 let mut result = MatchArgAttrs::default();
48
49 for attr in attrs {
50 if attr.path().is_ident("match_arg") {
51 attr.parse_nested_meta(|meta| {
52 if meta.path.is_ident("rename") {
53 let value: syn::LitStr = meta.value()?.parse()?;
54 result.rename = Some(value.value());
55 } else if meta.path.is_ident("rename_all") {
56 let value: syn::LitStr = meta.value()?.parse()?;
57 let val = value.value();
58 match val.as_str() {
59 "snake_case" | "kebab-case" | "lower" | "upper" => {}
60 _ => {
61 return Err(meta.error(
62 "unsupported rename_all value; expected one of: \
63 snake_case, kebab-case, lower, upper",
64 ));
65 }
66 }
67 result.rename_all = Some(val);
68 } else {
69 return Err(meta
70 .error("unknown match_arg attribute; expected `rename` or `rename_all`"));
71 }
72 Ok(())
73 })?;
74 }
75 }
76
77 Ok(result)
78}
79
80pub fn derive_match_arg(input: DeriveInput) -> syn::Result<TokenStream> {
103 let name = &input.ident;
104 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
105
106 if !input.generics.params.is_empty() {
108 return Err(syn::Error::new_spanned(
109 &input.generics,
110 "#[derive(MatchArg)] does not support generic enums",
111 ));
112 }
113
114 let attrs = parse_match_arg_attrs(&input.attrs)?;
116
117 let variants = match &input.data {
119 Data::Enum(data) => &data.variants,
120 Data::Struct(_) => {
121 return Err(syn::Error::new_spanned(
122 &input,
123 "#[derive(MatchArg)] can only be applied to enums",
124 ));
125 }
126 Data::Union(_) => {
127 return Err(syn::Error::new_spanned(
128 &input,
129 "#[derive(MatchArg)] can only be applied to enums",
130 ));
131 }
132 };
133
134 if variants.is_empty() {
135 return Err(syn::Error::new_spanned(
136 &input,
137 "#[derive(MatchArg)] requires at least one variant",
138 ));
139 }
140
141 let mut choice_names = Vec::new();
142 let mut variant_idents = Vec::new();
143
144 for variant in variants {
145 if !matches!(variant.fields, Fields::Unit) {
147 return Err(syn::Error::new_spanned(
148 variant,
149 "#[derive(MatchArg)] only supports fieldless (C-style) enum variants",
150 ));
151 }
152
153 let var_attrs = parse_match_arg_attrs(&variant.attrs)?;
155
156 let choice_name = if let Some(r) = var_attrs.rename {
158 r
159 } else {
160 apply_rename_all(&variant.ident.to_string(), attrs.rename_all.as_deref())
161 };
162
163 choice_names.push(choice_name);
164 variant_idents.push(&variant.ident);
165 }
166
167 {
169 let mut seen = std::collections::HashSet::new();
170 for (i, name) in choice_names.iter().enumerate() {
171 if !seen.insert(name.as_str()) {
172 return Err(syn::Error::new_spanned(
173 &variants.iter().nth(i).unwrap().ident,
174 format!("duplicate choice name {:?} in #[derive(MatchArg)]", name),
175 ));
176 }
177 }
178 }
179
180 let choice_strs: Vec<&str> = choice_names.iter().map(|s| s.as_str()).collect();
181
182 Ok(quote! {
183 impl #impl_generics ::miniextendr_api::match_arg::MatchArg for #name #ty_generics #where_clause {
184 const CHOICES: &'static [&'static str] = &[#(#choice_strs),*];
185
186 fn from_choice(choice: &str) -> Option<Self> {
187 match choice {
188 #(#choice_strs => Some(Self::#variant_idents),)*
189 _ => None,
190 }
191 }
192
193 fn to_choice(self) -> &'static str {
194 match self {
195 #(Self::#variant_idents => #choice_strs,)*
196 }
197 }
198 }
199
200 impl #impl_generics ::miniextendr_api::TryFromSexp for #name #ty_generics #where_clause {
201 type Error = ::miniextendr_api::SexpError;
202
203 fn try_from_sexp(sexp: ::miniextendr_api::ffi::SEXP) -> Result<Self, Self::Error> {
204 ::miniextendr_api::match_arg_from_sexp(sexp).map_err(Into::into)
205 }
206 }
207
208 impl #impl_generics ::miniextendr_api::IntoR for #name #ty_generics #where_clause {
209 type Error = std::convert::Infallible;
210
211 fn try_into_sexp(self) -> Result<::miniextendr_api::ffi::SEXP, Self::Error> {
212 Ok(self.into_sexp())
213 }
214
215 unsafe fn try_into_sexp_unchecked(self) -> Result<::miniextendr_api::ffi::SEXP, Self::Error> {
216 self.try_into_sexp()
217 }
218
219 fn into_sexp(self) -> ::miniextendr_api::ffi::SEXP {
220 use ::miniextendr_api::match_arg::MatchArg;
221 self.to_choice().into_sexp()
222 }
223 }
224
225
226 })
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// Expands `input`, panicking if the derive rejects it, and returns the token string.
    fn expand(input: DeriveInput) -> String {
        derive_match_arg(input)
            .expect("derive_match_arg should accept this input")
            .to_string()
    }

    /// Expands `input`, panicking if the derive accepts it, and returns the error message.
    fn expand_err(input: DeriveInput) -> String {
        match derive_match_arg(input) {
            Ok(tokens) => panic!("expected an error, got: {}", tokens),
            Err(err) => err.to_string(),
        }
    }

    #[test]
    fn test_simple_derive() {
        let code = expand(syn::parse_quote! {
            enum Mode {
                Fast,
                Safe,
                Debug,
            }
        });
        assert!(code.contains("Fast"));
        assert!(code.contains("Safe"));
        assert!(code.contains("Debug"));
        assert!(code.contains("CHOICES"));
        assert!(code.contains("from_choice"));
        assert!(code.contains("to_choice"));
    }

    #[test]
    fn test_rename_all() {
        let code = expand(syn::parse_quote! {
            #[match_arg(rename_all = "snake_case")]
            enum Mode {
                FastMode,
                SafeMode,
            }
        });
        assert!(code.contains("fast_mode"));
        assert!(code.contains("safe_mode"));
    }

    #[test]
    fn test_rename_variant() {
        let code = expand(syn::parse_quote! {
            enum Priority {
                #[match_arg(rename = "lo")]
                Low,
                #[match_arg(rename = "hi")]
                High,
            }
        });
        assert!(code.contains("\"lo\""));
        assert!(code.contains("\"hi\""));
    }

    #[test]
    fn test_rename_overrides_rename_all() {
        let code = expand(syn::parse_quote! {
            #[match_arg(rename_all = "snake_case")]
            enum Mode {
                #[match_arg(rename = "FAST")]
                FastMode,
                SafeMode,
            }
        });
        assert!(code.contains("\"FAST\""));
        assert!(code.contains("\"safe_mode\""));
        assert!(!code.contains("\"fast_mode\""));
    }

    #[test]
    fn test_reject_fields() {
        let msg = expand_err(syn::parse_quote! {
            enum Bad {
                A(i32),
            }
        });
        assert!(msg.contains("fieldless"), "unexpected error: {}", msg);
    }

    #[test]
    fn test_reject_struct() {
        let msg = expand_err(syn::parse_quote! {
            struct Bad;
        });
        assert!(msg.contains("can only be applied to enums"), "unexpected error: {}", msg);
    }

    #[test]
    fn test_reject_union() {
        let msg = expand_err(syn::parse_quote! {
            union Bad {
                a: u32,
                b: f32,
            }
        });
        assert!(msg.contains("can only be applied to enums"), "unexpected error: {}", msg);
    }

    #[test]
    fn test_reject_generic() {
        let msg = expand_err(syn::parse_quote! {
            enum Bad<T> {
                A,
            }
        });
        assert!(msg.contains("does not support generic enums"), "unexpected error: {}", msg);
    }

    #[test]
    fn test_reject_empty() {
        let msg = expand_err(syn::parse_quote! {
            enum Empty {}
        });
        assert!(msg.contains("at least one variant"), "unexpected error: {}", msg);
    }

    #[test]
    fn test_reject_unknown_attribute() {
        let msg = expand_err(syn::parse_quote! {
            #[match_arg(rename_al = "snake_case")]
            enum Mode {
                Fast,
            }
        });
        assert!(msg.contains("unknown match_arg attribute"), "unexpected error: {}", msg);
    }

    #[test]
    fn test_reject_unsupported_rename_all() {
        let msg = expand_err(syn::parse_quote! {
            #[match_arg(rename_all = "camelCase")]
            enum Mode {
                FastMode,
            }
        });
        assert!(msg.contains("unsupported rename_all value"), "unexpected error: {}", msg);
    }

    #[test]
    fn test_into_r_impl_present() {
        let code = expand(syn::parse_quote! {
            enum Mode {
                Fast,
                Safe,
                Debug,
            }
        });
        assert!(code.contains("IntoR for Mode"));
        // Only the scalar impl is generated; no Vec-specific impl or helper.
        assert!(!code.contains("IntoR for :: std :: vec :: Vec < Mode >"));
        assert!(!code.contains("match_arg_vec_into_sexp"));
    }

    #[test]
    fn test_duplicate_choice_names() {
        let msg = expand_err(syn::parse_quote! {
            enum Dup {
                #[match_arg(rename = "same")]
                A,
                #[match_arg(rename = "same")]
                B,
            }
        });
        assert!(msg.contains("duplicate choice name"), "unexpected error: {}", msg);
    }
}