miniextendr_macros/miniextendr_trait.rs
1//! # Trait Support for `#[miniextendr]`
2//!
3//! This module handles `#[miniextendr]` applied to trait definitions,
4//! generating the ABI infrastructure for cross-package trait dispatch.
5//!
6//! ## Overview
7//!
8//! When `#[miniextendr]` is applied to a trait, it generates:
9//!
//! 1. **Type tag constant** (`TAG_<TRAITNAME>`, the trait name in upper case) - 128-bit identifier for runtime type checking
11//! 2. **Vtable struct** (`<TraitName>VTable`) - Function pointer table for method dispatch
12//! 3. **View struct** (`<TraitName>View`) - Runtime wrapper combining data pointer and vtable
13//! 4. **Method shims** - `extern "C"` functions that convert SEXP arguments and call methods
14//! 5. **Vtable builder** - `__<trait>_build_vtable::<T>()` for impl blocks
15//!
16//! ## Usage
17//!
18//! ```ignore
19//! #[miniextendr]
20//! pub trait Counter {
21//! fn value(&self) -> i32;
22//! fn increment(&mut self);
23//! fn add(&mut self, n: i32);
24//! }
25//! ```
26//!
27//! Generates (conceptually):
28//!
29//! ```text
30//! // Original trait (passed through)
31//! pub trait Counter {
32//! fn value(&self) -> i32;
33//! fn increment(&mut self);
34//! fn add(&mut self, n: i32);
35//! }
36//!
37//! // Type tag for runtime identification
//! pub const TAG_COUNTER: mx_tag = mx_tag_from_path("Counter");
39//!
40//! // Vtable with one entry per method
41//! #[repr(C)]
42//! pub struct CounterVTable {
43//! pub value: mx_meth,
44//! pub increment: mx_meth,
45//! pub add: mx_meth,
46//! }
47//!
48//! // View combining data pointer and vtable
49//! #[repr(C)]
50//! pub struct CounterView {
51//! pub data: *mut std::ffi::c_void,
52//! pub vtable: *const CounterVTable,
53//! }
54//!
55//! // Shim for each method
56//! unsafe extern "C" fn __counter_value_shim<T: Counter>(
57//! data: *mut c_void, argc: i32, argv: *const SEXP
58//! ) -> SEXP {
59//! // 1. Check arity
60//! // 2. Cast data to &T
61//! // 3. Call method
62//! // 4. Convert result to SEXP
63//! // 5. Catch panics
64//! }
65//!
66//! // Builder to create vtable for a concrete type
67//! pub const fn __counter_build_vtable<T: Counter>() -> CounterVTable {
68//! CounterVTable {
69//! value: __counter_value_shim::<T>,
70//! increment: __counter_increment_shim::<T>,
71//! add: __counter_add_shim::<T>,
72//! }
73//! }
74//! ```
75//!
76//! ## Supported Method Signatures
77//!
78//! Methods must follow these constraints:
79//!
//! - **Receiver**: `&self` or `&mut self` (also accepted as `self: &Self` / `self: &mut Self`) for instance methods, or none for static methods
81//! - **Arguments**: Types that implement `TryFromSexp`
82//! - **Return**: Types that implement `IntoR`, or `()`
//! - **No generics**: Methods cannot declare generic parameters (type, lifetime, or const)
84//! - **No async**: Async methods are not supported
85//! - **Static methods**: Methods without a receiver are allowed and resolved at compile time
86//! (they don't go through the vtable)
87//!
88//! ## Default Methods
89//!
90//! Default method implementations are supported. The vtable builder will
91//! use the default implementation if the concrete type doesn't override it.
92//!
93//! ```ignore
94//! #[miniextendr]
95//! pub trait Counter {
96//! fn value(&self) -> i32;
97//!
98//! // Default implementation - included in vtable
99//! fn is_zero(&self) -> bool {
100//! self.value() == 0
101//! }
102//! }
103//! ```
104//!
105//! ## Error Handling
106//!
107//! Method shims handle errors as follows:
108//!
109//! - **Arity mismatch**: Raises R error ("expected N arguments, got M")
110//! - **Type conversion failure**: Raises R error with the error message
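//! - **Panic**: Caught in the shim by `with_r_unwind_protect_shim` and returned as a tagged error SEXP; the View wrapper re-panics on it (`repanic_if_rust_error`) so the caller's outer `with_r_unwind_protect` converts it to an R error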
112//!
113//! ## Thread Safety
114//!
115//! All generated shims are **main-thread only**. They do not route through
116//! `with_r_thread` because R invokes `.Call` on the main thread.
117
118use proc_macro2::TokenStream;
119use syn::ItemTrait;
120
121/// Expand `#[miniextendr]` applied to a trait definition.
122///
123/// # Arguments
124///
125/// * `attr` - Attribute arguments (currently unused, reserved for future options)
126/// * `item` - The trait definition token stream
127///
128/// # Returns
129///
130/// Expanded token stream containing:
131/// - Original trait definition
132/// - Type tag constant
133/// - Vtable struct
134/// - View struct
135/// - Method shims
136/// - Vtable builder function
137///
138/// # Errors
139///
140/// Returns a compile error if:
141/// - Methods have unsupported signatures
142/// - Methods are async
143pub fn expand_trait(
144 _attr: proc_macro::TokenStream,
145 item: proc_macro::TokenStream,
146) -> proc_macro::TokenStream {
147 let trait_item = syn::parse_macro_input!(item as ItemTrait);
148
149 // Validate trait constraints
150 if let Err(e) = validate_trait(&trait_item) {
151 return e.into_compile_error().into();
152 }
153
154 // Generate the expanded code
155 let expanded = generate_trait_abi(&trait_item);
156
157 expanded.into()
158}
159
160/// Validate that the trait meets requirements for ABI generation.
161///
162/// # Constraints
163///
164/// - All methods must have `&self` or `&mut self` receiver
165/// - Methods cannot be async
166/// - Methods cannot have generic parameters
167/// - Generic type parameters on the trait itself are allowed
168fn validate_trait(trait_item: &ItemTrait) -> syn::Result<()> {
169 let trait_name = &trait_item.ident;
170
171 // Validate each method
172 for item in &trait_item.items {
173 if let syn::TraitItem::Fn(method) = item {
174 validate_method(method, trait_name)?;
175 }
176 }
177
178 Ok(())
179}
180
181/// Validate a single trait method for ABI compatibility.
182///
183/// Rejects async methods, methods with generic type parameters, and methods
184/// that take `self` by value (only `&self` and `&mut self` are allowed).
185/// Static methods (no receiver) are permitted.
186fn validate_method(method: &syn::TraitItemFn, trait_name: &syn::Ident) -> syn::Result<()> {
187 let method_name = &method.sig.ident;
188
189 // Check for async
190 if method.sig.asyncness.is_some() {
191 return Err(syn::Error::new_spanned(
192 method.sig.asyncness,
193 format!(
194 "#[miniextendr] trait `{}::{}` cannot be async",
195 trait_name, method_name
196 ),
197 ));
198 }
199
200 // Check for generics on method
201 if !method.sig.generics.params.is_empty() {
202 return Err(syn::Error::new_spanned(
203 &method.sig.generics,
204 format!(
205 "#[miniextendr] trait method `{}::{}` cannot have generic parameters",
206 trait_name, method_name
207 ),
208 ));
209 }
210
211 // Check receiver - must be &self, &mut self, self: &Self, self: &mut Self, or no receiver
212 // Static methods are allowed but won't be included in the vtable
213 // (they're resolved at compile time via <Type as Trait>::method())
214 let receiver = method.sig.inputs.first();
215 if let Some(syn::FnArg::Receiver(r)) = receiver {
216 // Accept either:
217 // - `&self` / `&mut self` (r.reference is Some)
218 // - `self: &Self` / `self: &mut Self` (r.colon_token is Some with reference type)
219 let is_ref = if r.reference.is_some() {
220 true
221 } else if r.colon_token.is_some() {
222 // Check if the type is a reference type (&Self or &mut Self)
223 matches!(r.ty.as_ref(), syn::Type::Reference(_))
224 } else {
225 false
226 };
227
228 if !is_ref {
229 return Err(syn::Error::new_spanned(
230 r,
231 format!(
232 "#[miniextendr] trait method `{}::{}` receiver must be `&self` or `&mut self`, not `self` by value",
233 trait_name, method_name
234 ),
235 ));
236 }
237 }
238 // If receiver is None or FnArg::Typed (no self), it's a static method - allowed
239
240 Ok(())
241}
242
243/// Generate the ABI infrastructure for a trait.
244///
245/// This is the main code generation function that produces:
246/// - Type tag constant
247/// - Vtable struct
248/// - View struct (skipped for generic traits)
249/// - Method shims (with trait type params threaded through)
250/// - Vtable builder (with trait type params threaded through)
251fn generate_trait_abi(trait_item: &ItemTrait) -> TokenStream {
252 let trait_name = &trait_item.ident;
253 let vis = &trait_item.vis;
254
255 // Generate names for generated items
256 let tag_name = quote::format_ident!("TAG_{}", trait_name.to_string().to_uppercase());
257 let vtable_name = quote::format_ident!("{}VTable", trait_name);
258 let view_name = quote::format_ident!("{}View", trait_name);
259 let build_vtable_fn =
260 quote::format_ident!("__{}_build_vtable", trait_name.to_string().to_lowercase());
261
262 // Collect trait-level generic type parameters
263 let trait_type_params: Vec<&syn::GenericParam> = trait_item.generics.params.iter().collect();
264 let trait_param_idents: Vec<&syn::Ident> = trait_type_params
265 .iter()
266 .filter_map(|p| {
267 if let syn::GenericParam::Type(tp) = p {
268 Some(&tp.ident)
269 } else {
270 None
271 }
272 })
273 .collect();
274 let has_generics = !trait_param_idents.is_empty();
275
276 // Collect associated types
277 let assoc_types: Vec<&syn::Ident> = trait_item
278 .items
279 .iter()
280 .filter_map(|item| {
281 if let syn::TraitItem::Type(t) = item {
282 Some(&t.ident)
283 } else {
284 None
285 }
286 })
287 .collect();
288
289 // Collect trait where clause
290 let trait_where_clause = &trait_item.generics.where_clause;
291
292 // Collect method information
293 // Filter to only include instance methods (with &self or &mut self) that aren't skipped
294 let methods: Vec<_> = {
295 let mut collected = Vec::new();
296 for item in &trait_item.items {
297 if let syn::TraitItem::Fn(method) = item {
298 let info = match extract_method_info(method) {
299 Ok(info) => info,
300 Err(e) => return e.into_compile_error(),
301 };
302 if info.has_self && !info.skip {
303 collected.push(info);
304 }
305 }
306 }
307 collected
308 }
309 .into_iter()
310 .collect();
311
312 // Generate tag path string for hashing
313 // IMPORTANT: For cross-package trait dispatch, the tag must NOT include module_path!()
314 // Different packages defining the same trait signature should get the same tag.
315 // We use just the trait name - in practice, trait names + methods should be unique enough.
316 let tag_path = trait_name.to_string();
317
318 // Generate vtable fields
319 let vtable_fields: Vec<_> = methods
320 .iter()
321 .map(|m| {
322 let name = &m.name;
323 quote::quote! {
324 pub #name: ::miniextendr_api::abi::mx_meth
325 }
326 })
327 .collect();
328
329 // Compute extra bounds needed for shims
330 let extra_bounds =
331 compute_extra_bounds(&methods, trait_name, &assoc_types, &trait_param_idents);
332
333 // Build trait bound for __ImplT
334 let impl_t = quote::format_ident!("__ImplT");
335 let trait_bound = if trait_param_idents.is_empty() {
336 quote::quote! { #trait_name }
337 } else {
338 quote::quote! { #trait_name<#(#trait_param_idents),*> }
339 };
340
341 // Build combined where clause
342 let all_where_predicates = build_where_predicates(trait_where_clause, &extra_bounds);
343 let where_clause = if all_where_predicates.is_empty() {
344 quote::quote! {}
345 } else {
346 quote::quote! { where #(#all_where_predicates),* }
347 };
348
349 // Generate shim functions
350 let shim_fns: Vec<_> = methods
351 .iter()
352 .map(|m| {
353 generate_method_shim(
354 trait_name,
355 m,
356 &extra_bounds,
357 &trait_param_idents,
358 &trait_type_params,
359 trait_where_clause,
360 )
361 })
362 .collect();
363
364 // Generate vtable field initializers (turbofish includes trait type params)
365 let vtable_inits: Vec<_> = methods
366 .iter()
367 .map(|m| {
368 let name = &m.name;
369 let shim_name =
370 quote::format_ident!("__{}_{}_shim", trait_name.to_string().to_lowercase(), name);
371 quote::quote! {
372 #name: #shim_name::<#(#trait_param_idents,)* #impl_t>
373 }
374 })
375 .collect();
376
377 // Generate method wrappers for the View struct
378 // Skip entirely for generic traits (type-erased view can't know type params)
379 let view_methods: Vec<_> = if has_generics {
380 vec![]
381 } else {
382 methods.iter().filter_map(generate_view_method).collect()
383 };
384
385 // Strip #[miniextendr(...)] attrs from trait items before emitting
386 let mut clean_trait = trait_item.clone();
387 for item in &mut clean_trait.items {
388 if let syn::TraitItem::Fn(method) = item {
389 method
390 .attrs
391 .retain(|attr| !attr.path().is_ident("miniextendr"));
392 }
393 }
394
395 let trait_name_str = trait_name.to_string();
396 let source_loc_doc = crate::source_location_doc(trait_name.span());
397
398 let impl_bounds = &extra_bounds.impl_bounds;
399
400 // View struct and its impls (skipped for generic traits)
401 let view_tokens = if has_generics {
402 quote::quote! {}
403 } else {
404 quote::quote! {
405 #[doc = concat!(
406 "Runtime view for objects implementing `",
407 stringify!(#trait_name),
408 "`."
409 )]
410 #[doc = #source_loc_doc]
411 #[doc = concat!("Generated from source file `", file!(), "`.")]
412 ///
413 /// Combines a data pointer with a vtable pointer for method dispatch.
414 /// Use `try_from_sexp` to create a view from an R external pointer.
415 #[repr(C)]
416 #vis struct #view_name {
417 /// Pointer to the concrete object data.
418 pub data: *mut ::std::os::raw::c_void,
419 /// Pointer to the vtable for this trait.
420 pub vtable: *const #vtable_name,
421 }
422
423 // TraitView implementation
424 impl ::miniextendr_api::TraitView for #view_name {
425 const TAG: ::miniextendr_api::abi::mx_tag = #tag_name;
426
427 #[inline]
428 unsafe fn from_raw_parts(
429 data: *mut ::std::os::raw::c_void,
430 vtable: *const ::std::os::raw::c_void,
431 ) -> Self {
432 Self {
433 data,
434 vtable: vtable.cast::<#vtable_name>(),
435 }
436 }
437 }
438
439 // Method wrappers on View
440 impl #view_name {
441 /// Try to create a view from an R SEXP.
442 ///
443 /// Returns `Some(Self)` if the object implements this trait,
444 /// `None` otherwise.
445 ///
446 /// # Safety
447 ///
448 /// - `sexp` must be a valid R external pointer (EXTPTRSXP)
449 /// - Must be called on R's main thread
450 #[inline]
451 pub unsafe fn try_from_sexp(sexp: ::miniextendr_api::ffi::SEXP) -> Option<Self> {
452 <Self as ::miniextendr_api::TraitView>::try_from_sexp(sexp)
453 }
454
455 /// Try to create a view, panicking with error message on failure.
456 ///
457 /// # Safety
458 ///
459 /// Same as `try_from_sexp`.
460 #[inline]
461 pub unsafe fn from_sexp(sexp: ::miniextendr_api::ffi::SEXP) -> Self {
462 Self::try_from_sexp(sexp)
463 .expect(concat!("Object does not implement ", #trait_name_str, " trait"))
464 }
465
466 #(#view_methods)*
467 }
468 }
469 };
470
471 // For generic traits (with type params like <T>), skip shim and builder generation.
472 // These are generated at the impl site with concrete types to avoid recursive trait
473 // resolution overflow (e.g., `Vec<T>: TryFromSexp` triggers infinite recursion through
474 // `impl<T> TryFromSexp for Vec<Vec<T>>`).
475 let shim_and_builder = if has_generics {
476 quote::quote! {}
477 } else {
478 quote::quote! {
479 // Method shims
480 #(#shim_fns)*
481
482 #[doc = concat!(
483 "Build a vtable for a concrete type implementing `",
484 stringify!(#trait_name),
485 "`."
486 )]
487 #[doc = #source_loc_doc]
488 #[doc = concat!("Generated from source file `", file!(), "`.")]
489 #vis const fn #build_vtable_fn<#(#trait_type_params,)* #impl_t: #trait_bound #(+ #impl_bounds)*>() -> #vtable_name
490 #where_clause
491 {
492 #vtable_name {
493 #(#vtable_inits),*
494 }
495 }
496 }
497 };
498
499 // TPIE: Generate macro_rules! for non-generic traits without associated types.
500 // This enables `#[miniextendr] impl Trait for Type {}` (empty body) to auto-expand wrappers.
501 let tpie_macro = if !has_generics && assoc_types.is_empty() {
502 // Collect ALL non-skipped methods (including static) for TPIE metadata
503 let tpie_method_metadata: Vec<TokenStream> = {
504 let mut collected = Vec::new();
505 for item in &trait_item.items {
506 if let syn::TraitItem::Fn(method) = item {
507 let info = match extract_method_info(method) {
508 Ok(info) => info,
509 Err(e) => return e.into_compile_error(),
510 };
511 if !info.skip {
512 let r_name_ident = if let Some(ref rn) = info.r_name {
513 quote::format_ident!("{}", rn)
514 } else {
515 method.sig.ident.clone()
516 };
517 let sig = &method.sig;
518 collected.push(quote::quote! {
519 method { r_name = #r_name_ident; #sig; }
520 });
521 }
522 }
523 }
524 collected
525 };
526
527 let tpie_macro_name = quote::format_ident!("__mx_impl_{}", trait_name);
528 quote::quote! {
529 #[macro_export]
530 #[doc(hidden)]
531 macro_rules! #tpie_macro_name {
532 ($concrete_type:ty, $trait_path:path, $class_system:ident, $no_rd:tt, $internal:tt, $noexport:tt) => {
533 $crate::__mx_trait_impl_expand! {
534 concrete_type = $concrete_type;
535 trait_path = $trait_path;
536 class_system = $class_system;
537 no_rd = $no_rd;
538 internal = $internal;
539 noexport = $noexport;
540 #(#tpie_method_metadata)*
541 }
542 };
543 }
544 }
545 } else {
546 quote::quote! {}
547 };
548
549 quote::quote! {
550 // Pass through the original trait (with #[miniextendr] attrs stripped from items)
551 #clean_trait
552
553 #[doc = concat!(
554 "Type tag for runtime identification of the `",
555 stringify!(#trait_name),
556 "` trait."
557 )]
558 #[doc = #source_loc_doc]
559 #[doc = concat!("Generated from source file `", file!(), "`.")]
560 #vis const #tag_name: ::miniextendr_api::abi::mx_tag =
561 ::miniextendr_api::abi::mx_tag_from_path(#tag_path);
562
563 #[doc = concat!("Vtable for the `", stringify!(#trait_name), "` trait.")]
564 #[doc = #source_loc_doc]
565 #[doc = concat!("Generated from source file `", file!(), "`.")]
566 ///
567 /// Contains one `mx_meth` function pointer per trait method.
568 #[repr(C)]
569 #[doc(hidden)]
570 #vis struct #vtable_name {
571 #(#vtable_fields),*
572 }
573
574 #view_tokens
575
576 #shim_and_builder
577
578 #tpie_macro
579 }
580}
581
582/// Generate a method wrapper for the View struct.
583///
584/// This creates a method on the View that calls through the vtable.
585/// Returns None for methods with `Self` in return types or `&Self` in parameters,
586/// since these can't be meaningfully expressed on the type-erased View.
587fn generate_view_method(method: &MethodInfo) -> Option<TokenStream> {
588 // Skip methods where Self appears in return type or parameters.
589 // In the View context, Self refers to the View struct, not the concrete type,
590 // so these methods can't work through the type-erased vtable dispatch.
591 if method.return_type.as_ref().is_some_and(type_contains_self) {
592 return None;
593 }
594 if method.param_types.iter().any(type_contains_self) {
595 return None;
596 }
597
598 let method_name = &method.name;
599 let param_names = &method.param_names;
600 let param_types = &method.param_types;
601
602 // Generate function parameters
603 let params: Vec<_> = param_names
604 .iter()
605 .zip(param_types.iter())
606 .map(|(name, ty)| {
607 quote::quote! { #name: #ty }
608 })
609 .collect();
610
611 // Generate self receiver
612 let self_param = if method.is_mut {
613 quote::quote! { &mut self }
614 } else {
615 quote::quote! { &self }
616 };
617
618 // Generate argument array for vtable call
619 let argc = param_types.len() as i32;
620 let arg_conversions: Vec<_> = param_names
621 .iter()
622 .map(|name| {
623 quote::quote! {
624 ::miniextendr_api::trait_abi::to_sexp(#name)
625 }
626 })
627 .collect();
628
629 // Generate vtable call
630 let vtable_call = if argc > 0 {
631 quote::quote! {
632 let args: [::miniextendr_api::ffi::SEXP; #argc as usize] = [#(#arg_conversions),*];
633 ((*self.vtable).#method_name)(self.data, #argc, args.as_ptr())
634 }
635 } else {
636 quote::quote! {
637 ((*self.vtable).#method_name)(self.data, 0, ::std::ptr::null())
638 }
639 };
640
641 // Generate return type handling
642 let return_type = &method.return_type;
643 let (return_sig, result_conversion) = if let Some(ret_ty) = return_type {
644 (
645 quote::quote! { -> #ret_ty },
646 quote::quote! {
647 ::miniextendr_api::trait_abi::from_sexp::<#ret_ty>(result)
648 },
649 )
650 } else {
651 (
652 quote::quote! {},
653 quote::quote! {
654 let _ = result;
655 },
656 )
657 };
658
659 Some(quote::quote! {
660 #[doc = concat!("Call `", stringify!(#method_name), "` through the vtable.")]
661 #[inline]
662 pub fn #method_name(#self_param #(, #params)*) #return_sig {
663 unsafe {
664 let result = { #vtable_call };
665 // Approach 1 (issue #345): if the shim returned a tagged error SEXP,
666 // re-panic with the reconstructed RCondition so the consumer's outer
667 // `with_r_unwind_protect` guard can apply rust_* class layering.
668 ::miniextendr_api::trait_abi::repanic_if_rust_error(result);
669 #result_conversion
670 }
671 }
672 })
673}
674
675/// Generate a method shim function for a trait method.
676///
677/// The shim is an `extern "C"` function that:
678/// 1. Checks argument arity
679/// 2. Wraps everything in `with_r_unwind_protect` to prevent unwinding across FFI
680/// 3. Converts SEXP arguments to Rust types
681/// 4. Calls the actual method on the concrete type
682/// 5. Converts the result back to SEXP
683/// 6. On panic, converts to R error via `with_r_unwind_protect`
684///
685/// For generic traits, the shim carries the trait's type parameters plus `__ImplT`.
686fn generate_method_shim(
687 trait_name: &syn::Ident,
688 method: &MethodInfo,
689 extra_bounds: &ExtraBounds,
690 trait_param_idents: &[&syn::Ident],
691 trait_type_params: &[&syn::GenericParam],
692 trait_where_clause: &Option<syn::WhereClause>,
693) -> TokenStream {
694 let method_name = &method.name;
695 let shim_name = quote::format_ident!(
696 "__{}_{}_shim",
697 trait_name.to_string().to_lowercase(),
698 method_name
699 );
700 let impl_t = quote::format_ident!("__ImplT");
701
702 let param_count = method.param_types.len();
703 let expected_argc = param_count as i32;
704
705 // Generate argument extraction
706 // For &Self params, extract ExternalPtr<__ImplT> and borrow from it
707 let arg_extractions: Vec<_> = method
708 .param_names
709 .iter()
710 .zip(method.param_types.iter())
711 .enumerate()
712 .map(|(i, (name, ty))| {
713 let name_str = name.to_string();
714 let (is_self_ref, is_mut) = param_is_self_ref(ty);
715 if is_self_ref {
716 let extptr_name = quote::format_ident!("__extptr_{}", name);
717 if is_mut {
718 quote::quote! {
719 let mut #extptr_name: ::miniextendr_api::ExternalPtr<#impl_t> = unsafe {
720 ::miniextendr_api::trait_abi::extract_arg(argc, argv, #i, #name_str)
721 };
722 let #name: &mut #impl_t = &mut *#extptr_name;
723 }
724 } else {
725 quote::quote! {
726 let #extptr_name: ::miniextendr_api::ExternalPtr<#impl_t> = unsafe {
727 ::miniextendr_api::trait_abi::extract_arg(argc, argv, #i, #name_str)
728 };
729 let #name: &#impl_t = &*#extptr_name;
730 }
731 }
732 } else {
733 quote::quote! {
734 let #name: #ty = unsafe {
735 ::miniextendr_api::trait_abi::extract_arg(argc, argv, #i, #name_str)
736 };
737 }
738 }
739 })
740 .collect();
741
742 // Generate method call (uses __ImplT)
743 let param_names = &method.param_names;
744 let method_call = if method.is_mut {
745 quote::quote! {
746 let self_ref = unsafe { &mut *data.cast::<#impl_t>() };
747 self_ref.#method_name(#(#param_names),*)
748 }
749 } else {
750 quote::quote! {
751 let self_ref = unsafe { &*data.cast::<#impl_t>().cast_const() };
752 self_ref.#method_name(#(#param_names),*)
753 }
754 };
755
756 // Generate result conversion
757 let result_conversion = if method.return_type.is_some() {
758 quote::quote! {
759 unsafe { ::miniextendr_api::trait_abi::to_sexp(result) }
760 }
761 } else {
762 quote::quote! {
763 let _ = result;
764 unsafe { ::miniextendr_api::trait_abi::nil() }
765 }
766 };
767
768 // Build trait bound for __ImplT
769 let trait_bound = if trait_param_idents.is_empty() {
770 quote::quote! { #trait_name }
771 } else {
772 quote::quote! { #trait_name<#(#trait_param_idents),*> }
773 };
774
775 let impl_bounds = &extra_bounds.impl_bounds;
776 let all_where_predicates = build_where_predicates(trait_where_clause, extra_bounds);
777 let where_clause = if all_where_predicates.is_empty() {
778 quote::quote! {}
779 } else {
780 quote::quote! { where #(#all_where_predicates),* }
781 };
782
783 let method_name_str = format!("{}::{}", trait_name, method_name);
784
785 quote::quote! {
786 #[doc = concat!(
787 "Method shim for `",
788 stringify!(#trait_name),
789 "::",
790 stringify!(#method_name),
791 "`."
792 )]
793 ///
794 /// Converts SEXP arguments, calls the method, and returns SEXP result.
795 /// Both Rust panics and R longjmps are caught via `with_r_unwind_protect`.
796 #[doc(hidden)]
797 unsafe extern "C" fn #shim_name<#(#trait_type_params,)* #impl_t: #trait_bound #(+ #impl_bounds)*>(
798 data: *mut ::std::os::raw::c_void,
799 argc: i32,
800 argv: *const ::miniextendr_api::ffi::SEXP,
801 ) -> ::miniextendr_api::ffi::SEXP
802 #where_clause
803 {
804 // Check arity (before unwind protect - uses r_stop which doesn't return)
805 unsafe {
806 ::miniextendr_api::trait_abi::check_arity(argc, #expected_argc, #method_name_str);
807 }
808
809 // Wrap in with_r_unwind_protect_shim: catches Rust panics and returns
810 // a tagged error SEXP instead of calling Rf_errorcall directly. The
811 // tagged SEXP is returned to the View method wrapper which re-panics
812 // via repanic_if_rust_error, allowing the consumer's outer
813 // `with_r_unwind_protect` guard to produce rust_* class layering
814 // (issue #345). R-origin longjmps still propagate via R_ContinueUnwind.
815 ::miniextendr_api::unwind_protect::with_r_unwind_protect_shim(|| {
816 // Extract arguments
817 #(#arg_extractions)*
818
819 // Call method
820 let result = { #method_call };
821
822 // Convert result
823 #result_conversion
824 })
825 }
826 }
827}
828
829/// Information extracted from a trait method for code generation.
830///
831/// Collects everything needed to generate vtable shims, view methods,
832/// and extra trait bounds for a single method in a `#[miniextendr]` trait.
833#[derive(Debug)]
834struct MethodInfo {
835 /// Method name (Rust identifier).
836 name: syn::Ident,
837 /// Whether the method has a self receiver (instance method).
838 /// False for static/associated methods.
839 has_self: bool,
840 /// Whether receiver is `&mut self` (vs `&self`). Only meaningful if `has_self` is true.
841 is_mut: bool,
842 /// Parameter types (excluding the self receiver).
843 param_types: Vec<syn::Type>,
844 /// Parameter names (excluding the self receiver). Uses `arg{i}` for unnamed patterns.
845 param_names: Vec<syn::Ident>,
846 /// Return type. `None` when the method returns `()` (unit type or no return annotation).
847 return_type: Option<syn::Type>,
848 /// Whether method is marked `#[miniextendr(skip)]`, excluding it from codegen.
849 skip: bool,
850 /// Override the R-facing method name (from `#[miniextendr(r_name = "...")]`).
851 /// When set, R wrappers and TPIE metadata use this name instead of the Rust ident.
852 r_name: Option<String>,
853}
854
855// region: Self-type detection helpers
856
857/// Check if a type syntactically contains `Self`.
858///
859/// Used to detect when a method returns `Self` (or `Option<Self>`, `Vec<Self>`, etc.)
860/// so the generated shim can add `IntoR` bounds.
861fn type_contains_self(ty: &syn::Type) -> bool {
862 match ty {
863 syn::Type::Path(tp) => {
864 for seg in &tp.path.segments {
865 if seg.ident == "Self" {
866 return true;
867 }
868 if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
869 for arg in &args.args {
870 if let syn::GenericArgument::Type(inner) = arg
871 && type_contains_self(inner)
872 {
873 return true;
874 }
875 }
876 }
877 }
878 false
879 }
880 syn::Type::Reference(r) => type_contains_self(&r.elem),
881 syn::Type::Tuple(t) => t.elems.iter().any(type_contains_self),
882 syn::Type::Slice(s) => type_contains_self(&s.elem),
883 syn::Type::Array(a) => type_contains_self(&a.elem),
884 syn::Type::Paren(p) => type_contains_self(&p.elem),
885 _ => false,
886 }
887}
888
889/// Check if a parameter type is `&Self` or `&mut Self`.
890///
891/// Returns `(is_self_ref, is_mut)`. When true, the generated shim extracts
892/// an `ExternalPtr<T>` from the SEXP and borrows from it instead of trying
893/// to extract `&T` directly (which doesn't implement `TryFromSexp`).
894fn param_is_self_ref(ty: &syn::Type) -> (bool, bool) {
895 if let syn::Type::Reference(r) = ty
896 && let syn::Type::Path(tp) = r.elem.as_ref()
897 && tp.path.is_ident("Self")
898 {
899 return (true, r.mutability.is_some());
900 }
901 (false, false)
902}
903
904/// Check if a type syntactically contains `Self::AssocType` for a given associated type name.
905///
906/// Recursively walks the type tree looking for a 2-segment path where the first
907/// segment is `Self` and the second matches `assoc_name` (e.g., `Self::Item`).
908/// Used to determine whether extra `where` bounds are needed for associated types.
909fn type_contains_self_assoc(ty: &syn::Type, assoc_name: &syn::Ident) -> bool {
910 match ty {
911 syn::Type::Path(tp) => {
912 if tp.path.segments.len() == 2
913 && tp.path.segments[0].ident == "Self"
914 && tp.path.segments[1].ident == *assoc_name
915 {
916 return true;
917 }
918 for seg in &tp.path.segments {
919 if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
920 for arg in &args.args {
921 if let syn::GenericArgument::Type(inner) = arg
922 && type_contains_self_assoc(inner, assoc_name)
923 {
924 return true;
925 }
926 }
927 }
928 }
929 false
930 }
931 syn::Type::Reference(r) => type_contains_self_assoc(&r.elem, assoc_name),
932 syn::Type::Tuple(t) => t
933 .elems
934 .iter()
935 .any(|e| type_contains_self_assoc(e, assoc_name)),
936 syn::Type::Slice(s) => type_contains_self_assoc(&s.elem, assoc_name),
937 syn::Type::Array(a) => type_contains_self_assoc(&a.elem, assoc_name),
938 syn::Type::Paren(p) => type_contains_self_assoc(&p.elem, assoc_name),
939 _ => false,
940 }
941}
942
943/// Rewrite `Self` and `Self::AssocType` in a type tree to use `__ImplT`.
944///
945/// Transforms:
946/// - `Self` → `__ImplT`
947/// - `Self::Item` → `<__ImplT as TraitName>::Item`
948/// - Recursively processes generic arguments (e.g., `Option<Self::Item>` →
949/// `Option<<__ImplT as TraitName>::Item>`)
950fn rewrite_self_in_type(
951 ty: &syn::Type,
952 trait_name: &syn::Ident,
953 assoc_types: &[&syn::Ident],
954) -> syn::Type {
955 match ty {
956 syn::Type::Path(tp) => {
957 // Check for Self::AssocType (2-segment path: Self::Item)
958 if tp.path.segments.len() == 2
959 && tp.path.segments[0].ident == "Self"
960 && assoc_types.iter().any(|a| *a == &tp.path.segments[1].ident)
961 {
962 let assoc = &tp.path.segments[1].ident;
963 let impl_t = quote::format_ident!("__ImplT");
964 return syn::parse_quote!(<#impl_t as #trait_name>::#assoc);
965 }
966 // Check for bare Self
967 if tp.path.is_ident("Self") {
968 let impl_t = quote::format_ident!("__ImplT");
969 return syn::parse_quote!(#impl_t);
970 }
971 // Recursively process generic args
972 let mut new_tp = tp.clone();
973 for seg in &mut new_tp.path.segments {
974 if let syn::PathArguments::AngleBracketed(args) = &mut seg.arguments {
975 for arg in &mut args.args {
976 if let syn::GenericArgument::Type(inner) = arg {
977 *inner = rewrite_self_in_type(inner, trait_name, assoc_types);
978 }
979 }
980 }
981 }
982 syn::Type::Path(new_tp)
983 }
984 syn::Type::Reference(r) => {
985 let mut new_r = r.clone();
986 new_r.elem = Box::new(rewrite_self_in_type(&r.elem, trait_name, assoc_types));
987 syn::Type::Reference(new_r)
988 }
989 syn::Type::Tuple(t) => {
990 let mut new_t = t.clone();
991 for elem in &mut new_t.elems {
992 *elem = rewrite_self_in_type(elem, trait_name, assoc_types);
993 }
994 syn::Type::Tuple(new_t)
995 }
996 syn::Type::Slice(s) => {
997 let mut new_s = s.clone();
998 new_s.elem = Box::new(rewrite_self_in_type(&s.elem, trait_name, assoc_types));
999 syn::Type::Slice(new_s)
1000 }
1001 syn::Type::Array(a) => {
1002 let mut new_a = a.clone();
1003 new_a.elem = Box::new(rewrite_self_in_type(&a.elem, trait_name, assoc_types));
1004 syn::Type::Array(new_a)
1005 }
1006 syn::Type::Paren(p) => {
1007 let mut new_p = p.clone();
1008 new_p.elem = Box::new(rewrite_self_in_type(&p.elem, trait_name, assoc_types));
1009 syn::Type::Paren(new_p)
1010 }
1011 _ => ty.clone(),
1012 }
1013}
1014
1015/// Check if a type syntactically contains a specific identifier.
1016///
1017/// Used to detect trait type parameters (like `T`) in method signatures so that
1018/// appropriate `TryFromSexp` or `IntoR` bounds can be added. Recursively walks
1019/// through path segments, generic arguments, references, tuples, slices, and arrays.
1020fn type_contains_ident(ty: &syn::Type, ident: &syn::Ident) -> bool {
1021 match ty {
1022 syn::Type::Path(tp) => {
1023 if tp.path.segments.len() == 1 && tp.path.segments[0].ident == *ident {
1024 return true;
1025 }
1026 for seg in &tp.path.segments {
1027 if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
1028 for arg in &args.args {
1029 if let syn::GenericArgument::Type(inner) = arg
1030 && type_contains_ident(inner, ident)
1031 {
1032 return true;
1033 }
1034 }
1035 }
1036 }
1037 false
1038 }
1039 syn::Type::Reference(r) => type_contains_ident(&r.elem, ident),
1040 syn::Type::Tuple(t) => t.elems.iter().any(|e| type_contains_ident(e, ident)),
1041 syn::Type::Slice(s) => type_contains_ident(&s.elem, ident),
1042 syn::Type::Array(a) => type_contains_ident(&a.elem, ident),
1043 syn::Type::Paren(p) => type_contains_ident(&p.elem, ident),
1044 _ => false,
1045 }
1046}
1047
/// Extra trait bounds inferred from method signatures.
///
/// For generic traits and methods that reference `Self` or associated types,
/// the generated shim and vtable builder functions need additional bounds
/// beyond `__ImplT: TraitName`. This struct collects those bounds. It is built by
/// `compute_extra_bounds`, and its `where_predicates` are merged with the trait's
/// own `where` clause by `build_where_predicates`.
struct ExtraBounds {
    /// Bound expressions that apply to `__ImplT` itself, stored *without* the
    /// `__ImplT:` prefix. Examples: `::miniextendr_api::IntoR` when a method returns
    /// `Self`, or `::miniextendr_api::TypedExternal + Send + 'static` when a method
    /// takes a `&Self` parameter.
    impl_bounds: Vec<TokenStream>,
    /// Complete `where`-clause predicates, each including its left-hand side. Examples:
    /// `Option<<__ImplT as Trait>::Item>: IntoR`, `Vec<T>: TryFromSexp`, or
    /// `<Vec<T> as TryFromSexp>::Error: Display`.
    where_predicates: Vec<TokenStream>,
}
1061
1062/// Compute extra bounds needed for the shim and build_vtable functions.
1063///
1064/// - Methods returning `Self` → `__ImplT: IntoR`
1065/// - Methods with `&Self` params → `__ImplT: TypedExternal + Send + 'static`
1066/// - Methods returning types with `Self::AssocType` or trait type params →
1067/// full rewritten return type `: IntoR` (e.g., `Option<<__ImplT as RIterator>::Item>: IntoR`)
1068/// - Methods with trait type params in params →
1069/// full param type `: TryFromSexp` (e.g., `Vec<T>: TryFromSexp`)
1070fn compute_extra_bounds(
1071 methods: &[MethodInfo],
1072 trait_name: &syn::Ident,
1073 assoc_types: &[&syn::Ident],
1074 trait_param_idents: &[&syn::Ident],
1075) -> ExtraBounds {
1076 let mut impl_bounds = Vec::new();
1077 let mut where_predicates = Vec::new();
1078
1079 let mut needs_into_r = false;
1080 let mut needs_typed_external = false;
1081
1082 // Track full types needing bounds (deduplicated by token string)
1083 let mut return_type_bound_keys: std::collections::BTreeMap<String, syn::Type> =
1084 Default::default();
1085 let mut param_type_bound_keys: std::collections::BTreeMap<String, syn::Type> =
1086 Default::default();
1087
1088 for method in methods {
1089 // Bare Self in returns → __ImplT: IntoR (as impl bound)
1090 if method.return_type.as_ref().is_some_and(type_contains_self) {
1091 needs_into_r = true;
1092 }
1093 // &Self in params → __ImplT: TypedExternal + 'static
1094 if method.param_types.iter().any(|ty| param_is_self_ref(ty).0) {
1095 needs_typed_external = true;
1096 }
1097
1098 // Full return type bounds for Self::AssocType and/or trait type params.
1099 // Instead of bare `<__ImplT as Trait>::Item: IntoR`, we add the FULL rewritten
1100 // return type: `Option<<__ImplT as Trait>::Item>: IntoR`, `Vec<T>: IntoR`, etc.
1101 // This is required because IntoR impls are concrete (no blanket `Option<T: IntoR>: IntoR`).
1102 if let Some(ref ret_ty) = method.return_type {
1103 let has_assoc = assoc_types
1104 .iter()
1105 .any(|a| type_contains_self_assoc(ret_ty, a));
1106 let has_param = trait_param_idents
1107 .iter()
1108 .any(|p| type_contains_ident(ret_ty, p));
1109 if has_assoc || has_param {
1110 let rewritten = rewrite_self_in_type(ret_ty, trait_name, assoc_types);
1111 let key = quote::quote!(#rewritten).to_string();
1112 return_type_bound_keys.entry(key).or_insert(rewritten);
1113 }
1114 }
1115
1116 // Full param type bounds for trait type params.
1117 // Instead of bare `T: TryFromSexp`, we add `Vec<T>: TryFromSexp` etc.
1118 for param_ty in &method.param_types {
1119 if !param_is_self_ref(param_ty).0 {
1120 let has_param = trait_param_idents
1121 .iter()
1122 .any(|p| type_contains_ident(param_ty, p));
1123 if has_param {
1124 let key = quote::quote!(#param_ty).to_string();
1125 param_type_bound_keys.entry(key).or_insert(param_ty.clone());
1126 }
1127 }
1128 }
1129 }
1130
1131 if needs_into_r {
1132 impl_bounds.push(quote::quote! { ::miniextendr_api::IntoR });
1133 }
1134 if needs_typed_external {
1135 impl_bounds.push(quote::quote! { ::miniextendr_api::TypedExternal + Send + 'static });
1136 }
1137
1138 // Add full return type bounds: RewrittenType: IntoR
1139 for ty in return_type_bound_keys.values() {
1140 where_predicates.push(quote::quote! {
1141 #ty: ::miniextendr_api::IntoR
1142 });
1143 }
1144
1145 // Add full param type bounds: ParamType: TryFromSexp, Error: Display
1146 for ty in param_type_bound_keys.values() {
1147 where_predicates.push(quote::quote! {
1148 #ty: ::miniextendr_api::TryFromSexp
1149 });
1150 where_predicates.push(quote::quote! {
1151 <#ty as ::miniextendr_api::TryFromSexp>::Error: ::std::fmt::Display
1152 });
1153 }
1154
1155 ExtraBounds {
1156 impl_bounds,
1157 where_predicates,
1158 }
1159}
1160
1161/// Build combined where predicates from the trait's own where clause and computed extra bounds.
1162///
1163/// Merges the original trait-level where clause predicates with the extra
1164/// bounds computed from method signatures (e.g., `IntoR` for return types,
1165/// `TryFromSexp` for parameters containing trait type params).
1166///
1167/// Returns a flat list of predicates suitable for use in a `where` clause.
1168fn build_where_predicates(
1169 trait_where_clause: &Option<syn::WhereClause>,
1170 extra_bounds: &ExtraBounds,
1171) -> Vec<TokenStream> {
1172 let mut all = Vec::new();
1173 if let Some(wc) = trait_where_clause {
1174 for pred in &wc.predicates {
1175 all.push(quote::quote! { #pred });
1176 }
1177 }
1178 all.extend(extra_bounds.where_predicates.iter().cloned());
1179 all
1180}
1181
1182/// Extract method information from a trait method definition.
1183///
1184/// Parses the method signature to determine receiver type, parameter names/types,
1185/// return type, and any `#[miniextendr(...)]` attributes like `skip` and `r_name`.
1186/// Parameters with non-ident patterns are assigned synthetic names (`arg0`, `arg1`, etc.).
1187fn extract_method_info(method: &syn::TraitItemFn) -> syn::Result<MethodInfo> {
1188 let name = method.sig.ident.clone();
1189
1190 // Check for #[miniextendr(skip)] and #[miniextendr(r_name = "...")]
1191 let mut skip = false;
1192 let mut r_name: Option<String> = None;
1193 for attr in &method.attrs {
1194 if !attr.path().is_ident("miniextendr") {
1195 continue;
1196 }
1197 attr.parse_nested_meta(|meta| {
1198 if meta.path.is_ident("skip") {
1199 skip = true;
1200 } else if meta.path.is_ident("r_name") {
1201 let value: syn::LitStr = meta.value()?.parse()?;
1202 r_name = Some(value.value());
1203 } else {
1204 return Err(meta.error(
1205 "unknown #[miniextendr] option on trait method; expected `skip` or `r_name`",
1206 ));
1207 }
1208 Ok(())
1209 })?;
1210 }
1211
1212 // Check for receiver
1213 let (has_self, is_mut) = method.sig.inputs.first().map_or((false, false), |arg| {
1214 if let syn::FnArg::Receiver(r) = arg {
1215 (true, r.mutability.is_some())
1216 } else {
1217 (false, false)
1218 }
1219 });
1220
1221 // Extract parameters (skip self if present)
1222 let skip_count = if has_self { 1 } else { 0 };
1223 let mut param_types = Vec::new();
1224 let mut param_names = Vec::new();
1225 for (i, arg) in method.sig.inputs.iter().skip(skip_count).enumerate() {
1226 if let syn::FnArg::Typed(pat_type) = arg {
1227 param_types.push((*pat_type.ty).clone());
1228 if let syn::Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
1229 param_names.push(pat_ident.ident.clone());
1230 } else {
1231 param_names.push(quote::format_ident!("arg{}", i));
1232 }
1233 }
1234 }
1235
1236 // Extract return type
1237 let return_type = match &method.sig.output {
1238 syn::ReturnType::Default => None,
1239 syn::ReturnType::Type(_, ty) => {
1240 // Check if it's unit type ()
1241 if matches!(ty.as_ref(), syn::Type::Tuple(t) if t.elems.is_empty()) {
1242 None
1243 } else {
1244 Some((**ty).clone())
1245 }
1246 }
1247 };
1248
1249 Ok(MethodInfo {
1250 name,
1251 has_self,
1252 is_mut,
1253 param_types,
1254 param_names,
1255 return_type,
1256 skip,
1257 r_name,
1258 })
1259}
1260
1261#[cfg(test)]
1262mod tests;
1263// endregion