Skip to main content

miniextendr_macros/
miniextendr_trait.rs

1//! # Trait Support for `#[miniextendr]`
2//!
3//! This module handles `#[miniextendr]` applied to trait definitions,
4//! generating the ABI infrastructure for cross-package trait dispatch.
5//!
6//! ## Overview
7//!
8//! When `#[miniextendr]` is applied to a trait, it generates:
9//!
//! 1. **Type tag constant** (`TAG_<TRAITNAME>`, i.e. the trait name upper-cased) - 128-bit identifier for runtime type checking
11//! 2. **Vtable struct** (`<TraitName>VTable`) - Function pointer table for method dispatch
12//! 3. **View struct** (`<TraitName>View`) - Runtime wrapper combining data pointer and vtable
13//! 4. **Method shims** - `extern "C"` functions that convert SEXP arguments and call methods
14//! 5. **Vtable builder** - `__<trait>_build_vtable::<T>()` for impl blocks
15//!
16//! ## Usage
17//!
18//! ```ignore
19//! #[miniextendr]
20//! pub trait Counter {
21//!     fn value(&self) -> i32;
22//!     fn increment(&mut self);
23//!     fn add(&mut self, n: i32);
24//! }
25//! ```
26//!
27//! Generates (conceptually):
28//!
29//! ```text
30//! // Original trait (passed through)
31//! pub trait Counter {
32//!     fn value(&self) -> i32;
33//!     fn increment(&mut self);
34//!     fn add(&mut self, n: i32);
35//! }
36//!
37//! // Type tag for runtime identification
38//! pub const TAG_COUNTER: mx_tag = mx_tag::new(0x..., 0x...);
39//!
40//! // Vtable with one entry per method
41//! #[repr(C)]
42//! pub struct CounterVTable {
43//!     pub value: mx_meth,
44//!     pub increment: mx_meth,
45//!     pub add: mx_meth,
46//! }
47//!
48//! // View combining data pointer and vtable
49//! #[repr(C)]
50//! pub struct CounterView {
51//!     pub data: *mut std::ffi::c_void,
52//!     pub vtable: *const CounterVTable,
53//! }
54//!
55//! // Shim for each method
56//! unsafe extern "C" fn __counter_value_shim<T: Counter>(
57//!     data: *mut c_void, argc: i32, argv: *const SEXP
58//! ) -> SEXP {
59//!     // 1. Check arity
60//!     // 2. Cast data to &T
61//!     // 3. Call method
62//!     // 4. Convert result to SEXP
63//!     // 5. Catch panics
64//! }
65//!
66//! // Builder to create vtable for a concrete type
67//! pub const fn __counter_build_vtable<T: Counter>() -> CounterVTable {
68//!     CounterVTable {
69//!         value: __counter_value_shim::<T>,
70//!         increment: __counter_increment_shim::<T>,
71//!         add: __counter_add_shim::<T>,
72//!     }
73//! }
74//! ```
75//!
76//! ## Supported Method Signatures
77//!
78//! Methods must follow these constraints:
79//!
80//! - **Receiver**: `&self` or `&mut self` for instance methods, or none for static methods
81//! - **Arguments**: Types that implement `TryFromSexp`
82//! - **Return**: Types that implement `IntoR`, or `()`
83//! - **No generics**: Methods cannot have generic type parameters
84//! - **No async**: Async methods are not supported
85//! - **Static methods**: Methods without a receiver are allowed and resolved at compile time
86//!   (they don't go through the vtable)
87//!
88//! ## Default Methods
89//!
90//! Default method implementations are supported. The vtable builder will
91//! use the default implementation if the concrete type doesn't override it.
92//!
93//! ```ignore
94//! #[miniextendr]
95//! pub trait Counter {
96//!     fn value(&self) -> i32;
97//!
98//!     // Default implementation - included in vtable
99//!     fn is_zero(&self) -> bool {
100//!         self.value() == 0
101//!     }
102//! }
103//! ```
104//!
105//! ## Error Handling
106//!
107//! Method shims handle errors as follows:
108//!
109//! - **Arity mismatch**: Raises R error ("expected N arguments, got M")
110//! - **Type conversion failure**: Raises R error with the error message
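//! - **Panic**: Caught in the shim by `with_r_unwind_protect_shim` and returned as a
//!   tagged error SEXP; the View wrapper re-panics via `repanic_if_rust_error` so the
//!   consumer's outer `with_r_unwind_protect` guard raises it as an R error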
112//!
113//! ## Thread Safety
114//!
115//! All generated shims are **main-thread only**. They do not route through
116//! `with_r_thread` because R invokes `.Call` on the main thread.
117
118use proc_macro2::TokenStream;
119use syn::ItemTrait;
120
121/// Expand `#[miniextendr]` applied to a trait definition.
122///
123/// # Arguments
124///
125/// * `attr` - Attribute arguments (currently unused, reserved for future options)
126/// * `item` - The trait definition token stream
127///
128/// # Returns
129///
130/// Expanded token stream containing:
131/// - Original trait definition
132/// - Type tag constant
133/// - Vtable struct
134/// - View struct
135/// - Method shims
136/// - Vtable builder function
137///
138/// # Errors
139///
140/// Returns a compile error if:
141/// - Methods have unsupported signatures
142/// - Methods are async
143pub fn expand_trait(
144    _attr: proc_macro::TokenStream,
145    item: proc_macro::TokenStream,
146) -> proc_macro::TokenStream {
147    let trait_item = syn::parse_macro_input!(item as ItemTrait);
148
149    // Validate trait constraints
150    if let Err(e) = validate_trait(&trait_item) {
151        return e.into_compile_error().into();
152    }
153
154    // Generate the expanded code
155    let expanded = generate_trait_abi(&trait_item);
156
157    expanded.into()
158}
159
160/// Validate that the trait meets requirements for ABI generation.
161///
162/// # Constraints
163///
164/// - All methods must have `&self` or `&mut self` receiver
165/// - Methods cannot be async
166/// - Methods cannot have generic parameters
167/// - Generic type parameters on the trait itself are allowed
168fn validate_trait(trait_item: &ItemTrait) -> syn::Result<()> {
169    let trait_name = &trait_item.ident;
170
171    // Validate each method
172    for item in &trait_item.items {
173        if let syn::TraitItem::Fn(method) = item {
174            validate_method(method, trait_name)?;
175        }
176    }
177
178    Ok(())
179}
180
181/// Validate a single trait method for ABI compatibility.
182///
183/// Rejects async methods, methods with generic type parameters, and methods
184/// that take `self` by value (only `&self` and `&mut self` are allowed).
185/// Static methods (no receiver) are permitted.
186fn validate_method(method: &syn::TraitItemFn, trait_name: &syn::Ident) -> syn::Result<()> {
187    let method_name = &method.sig.ident;
188
189    // Check for async
190    if method.sig.asyncness.is_some() {
191        return Err(syn::Error::new_spanned(
192            method.sig.asyncness,
193            format!(
194                "#[miniextendr] trait `{}::{}` cannot be async",
195                trait_name, method_name
196            ),
197        ));
198    }
199
200    // Check for generics on method
201    if !method.sig.generics.params.is_empty() {
202        return Err(syn::Error::new_spanned(
203            &method.sig.generics,
204            format!(
205                "#[miniextendr] trait method `{}::{}` cannot have generic parameters",
206                trait_name, method_name
207            ),
208        ));
209    }
210
211    // Check receiver - must be &self, &mut self, self: &Self, self: &mut Self, or no receiver
212    // Static methods are allowed but won't be included in the vtable
213    // (they're resolved at compile time via <Type as Trait>::method())
214    let receiver = method.sig.inputs.first();
215    if let Some(syn::FnArg::Receiver(r)) = receiver {
216        // Accept either:
217        // - `&self` / `&mut self` (r.reference is Some)
218        // - `self: &Self` / `self: &mut Self` (r.colon_token is Some with reference type)
219        let is_ref = if r.reference.is_some() {
220            true
221        } else if r.colon_token.is_some() {
222            // Check if the type is a reference type (&Self or &mut Self)
223            matches!(r.ty.as_ref(), syn::Type::Reference(_))
224        } else {
225            false
226        };
227
228        if !is_ref {
229            return Err(syn::Error::new_spanned(
230                r,
231                format!(
232                    "#[miniextendr] trait method `{}::{}` receiver must be `&self` or `&mut self`, not `self` by value",
233                    trait_name, method_name
234                ),
235            ));
236        }
237    }
238    // If receiver is None or FnArg::Typed (no self), it's a static method - allowed
239
240    Ok(())
241}
242
243/// Generate the ABI infrastructure for a trait.
244///
245/// This is the main code generation function that produces:
246/// - Type tag constant
247/// - Vtable struct
248/// - View struct (skipped for generic traits)
249/// - Method shims (with trait type params threaded through)
250/// - Vtable builder (with trait type params threaded through)
251fn generate_trait_abi(trait_item: &ItemTrait) -> TokenStream {
252    let trait_name = &trait_item.ident;
253    let vis = &trait_item.vis;
254
255    // Generate names for generated items
256    let tag_name = quote::format_ident!("TAG_{}", trait_name.to_string().to_uppercase());
257    let vtable_name = quote::format_ident!("{}VTable", trait_name);
258    let view_name = quote::format_ident!("{}View", trait_name);
259    let build_vtable_fn =
260        quote::format_ident!("__{}_build_vtable", trait_name.to_string().to_lowercase());
261
262    // Collect trait-level generic type parameters
263    let trait_type_params: Vec<&syn::GenericParam> = trait_item.generics.params.iter().collect();
264    let trait_param_idents: Vec<&syn::Ident> = trait_type_params
265        .iter()
266        .filter_map(|p| {
267            if let syn::GenericParam::Type(tp) = p {
268                Some(&tp.ident)
269            } else {
270                None
271            }
272        })
273        .collect();
274    let has_generics = !trait_param_idents.is_empty();
275
276    // Collect associated types
277    let assoc_types: Vec<&syn::Ident> = trait_item
278        .items
279        .iter()
280        .filter_map(|item| {
281            if let syn::TraitItem::Type(t) = item {
282                Some(&t.ident)
283            } else {
284                None
285            }
286        })
287        .collect();
288
289    // Collect trait where clause
290    let trait_where_clause = &trait_item.generics.where_clause;
291
292    // Collect method information
293    // Filter to only include instance methods (with &self or &mut self) that aren't skipped
294    let methods: Vec<_> = {
295        let mut collected = Vec::new();
296        for item in &trait_item.items {
297            if let syn::TraitItem::Fn(method) = item {
298                let info = match extract_method_info(method) {
299                    Ok(info) => info,
300                    Err(e) => return e.into_compile_error(),
301                };
302                if info.has_self && !info.skip {
303                    collected.push(info);
304                }
305            }
306        }
307        collected
308    }
309    .into_iter()
310    .collect();
311
312    // Generate tag path string for hashing
313    // IMPORTANT: For cross-package trait dispatch, the tag must NOT include module_path!()
314    // Different packages defining the same trait signature should get the same tag.
315    // We use just the trait name - in practice, trait names + methods should be unique enough.
316    let tag_path = trait_name.to_string();
317
318    // Generate vtable fields
319    let vtable_fields: Vec<_> = methods
320        .iter()
321        .map(|m| {
322            let name = &m.name;
323            quote::quote! {
324                pub #name: ::miniextendr_api::abi::mx_meth
325            }
326        })
327        .collect();
328
329    // Compute extra bounds needed for shims
330    let extra_bounds =
331        compute_extra_bounds(&methods, trait_name, &assoc_types, &trait_param_idents);
332
333    // Build trait bound for __ImplT
334    let impl_t = quote::format_ident!("__ImplT");
335    let trait_bound = if trait_param_idents.is_empty() {
336        quote::quote! { #trait_name }
337    } else {
338        quote::quote! { #trait_name<#(#trait_param_idents),*> }
339    };
340
341    // Build combined where clause
342    let all_where_predicates = build_where_predicates(trait_where_clause, &extra_bounds);
343    let where_clause = if all_where_predicates.is_empty() {
344        quote::quote! {}
345    } else {
346        quote::quote! { where #(#all_where_predicates),* }
347    };
348
349    // Generate shim functions
350    let shim_fns: Vec<_> = methods
351        .iter()
352        .map(|m| {
353            generate_method_shim(
354                trait_name,
355                m,
356                &extra_bounds,
357                &trait_param_idents,
358                &trait_type_params,
359                trait_where_clause,
360            )
361        })
362        .collect();
363
364    // Generate vtable field initializers (turbofish includes trait type params)
365    let vtable_inits: Vec<_> = methods
366        .iter()
367        .map(|m| {
368            let name = &m.name;
369            let shim_name =
370                quote::format_ident!("__{}_{}_shim", trait_name.to_string().to_lowercase(), name);
371            quote::quote! {
372                #name: #shim_name::<#(#trait_param_idents,)* #impl_t>
373            }
374        })
375        .collect();
376
377    // Generate method wrappers for the View struct
378    // Skip entirely for generic traits (type-erased view can't know type params)
379    let view_methods: Vec<_> = if has_generics {
380        vec![]
381    } else {
382        methods.iter().filter_map(generate_view_method).collect()
383    };
384
385    // Strip #[miniextendr(...)] attrs from trait items before emitting
386    let mut clean_trait = trait_item.clone();
387    for item in &mut clean_trait.items {
388        if let syn::TraitItem::Fn(method) = item {
389            method
390                .attrs
391                .retain(|attr| !attr.path().is_ident("miniextendr"));
392        }
393    }
394
395    let trait_name_str = trait_name.to_string();
396    let source_loc_doc = crate::source_location_doc(trait_name.span());
397
398    let impl_bounds = &extra_bounds.impl_bounds;
399
400    // View struct and its impls (skipped for generic traits)
401    let view_tokens = if has_generics {
402        quote::quote! {}
403    } else {
404        quote::quote! {
405            #[doc = concat!(
406                "Runtime view for objects implementing `",
407                stringify!(#trait_name),
408                "`."
409            )]
410            #[doc = #source_loc_doc]
411            #[doc = concat!("Generated from source file `", file!(), "`.")]
412            ///
413            /// Combines a data pointer with a vtable pointer for method dispatch.
414            /// Use `try_from_sexp` to create a view from an R external pointer.
415            #[repr(C)]
416            #vis struct #view_name {
417                /// Pointer to the concrete object data.
418                pub data: *mut ::std::os::raw::c_void,
419                /// Pointer to the vtable for this trait.
420                pub vtable: *const #vtable_name,
421            }
422
423            // TraitView implementation
424            impl ::miniextendr_api::TraitView for #view_name {
425                const TAG: ::miniextendr_api::abi::mx_tag = #tag_name;
426
427                #[inline]
428                unsafe fn from_raw_parts(
429                    data: *mut ::std::os::raw::c_void,
430                    vtable: *const ::std::os::raw::c_void,
431                ) -> Self {
432                    Self {
433                        data,
434                        vtable: vtable.cast::<#vtable_name>(),
435                    }
436                }
437            }
438
439            // Method wrappers on View
440            impl #view_name {
441                /// Try to create a view from an R SEXP.
442                ///
443                /// Returns `Some(Self)` if the object implements this trait,
444                /// `None` otherwise.
445                ///
446                /// # Safety
447                ///
448                /// - `sexp` must be a valid R external pointer (EXTPTRSXP)
449                /// - Must be called on R's main thread
450                #[inline]
451                pub unsafe fn try_from_sexp(sexp: ::miniextendr_api::ffi::SEXP) -> Option<Self> {
452                    <Self as ::miniextendr_api::TraitView>::try_from_sexp(sexp)
453                }
454
455                /// Try to create a view, panicking with error message on failure.
456                ///
457                /// # Safety
458                ///
459                /// Same as `try_from_sexp`.
460                #[inline]
461                pub unsafe fn from_sexp(sexp: ::miniextendr_api::ffi::SEXP) -> Self {
462                    Self::try_from_sexp(sexp)
463                        .expect(concat!("Object does not implement ", #trait_name_str, " trait"))
464                }
465
466                #(#view_methods)*
467            }
468        }
469    };
470
471    // For generic traits (with type params like <T>), skip shim and builder generation.
472    // These are generated at the impl site with concrete types to avoid recursive trait
473    // resolution overflow (e.g., `Vec<T>: TryFromSexp` triggers infinite recursion through
474    // `impl<T> TryFromSexp for Vec<Vec<T>>`).
475    let shim_and_builder = if has_generics {
476        quote::quote! {}
477    } else {
478        quote::quote! {
479            // Method shims
480            #(#shim_fns)*
481
482            #[doc = concat!(
483                "Build a vtable for a concrete type implementing `",
484                stringify!(#trait_name),
485                "`."
486            )]
487            #[doc = #source_loc_doc]
488            #[doc = concat!("Generated from source file `", file!(), "`.")]
489            #vis const fn #build_vtable_fn<#(#trait_type_params,)* #impl_t: #trait_bound #(+ #impl_bounds)*>() -> #vtable_name
490            #where_clause
491            {
492                #vtable_name {
493                    #(#vtable_inits),*
494                }
495            }
496        }
497    };
498
499    // TPIE: Generate macro_rules! for non-generic traits without associated types.
500    // This enables `#[miniextendr] impl Trait for Type {}` (empty body) to auto-expand wrappers.
501    let tpie_macro = if !has_generics && assoc_types.is_empty() {
502        // Collect ALL non-skipped methods (including static) for TPIE metadata
503        let tpie_method_metadata: Vec<TokenStream> = {
504            let mut collected = Vec::new();
505            for item in &trait_item.items {
506                if let syn::TraitItem::Fn(method) = item {
507                    let info = match extract_method_info(method) {
508                        Ok(info) => info,
509                        Err(e) => return e.into_compile_error(),
510                    };
511                    if !info.skip {
512                        let r_name_ident = if let Some(ref rn) = info.r_name {
513                            quote::format_ident!("{}", rn)
514                        } else {
515                            method.sig.ident.clone()
516                        };
517                        let sig = &method.sig;
518                        collected.push(quote::quote! {
519                            method { r_name = #r_name_ident; #sig; }
520                        });
521                    }
522                }
523            }
524            collected
525        };
526
527        let tpie_macro_name = quote::format_ident!("__mx_impl_{}", trait_name);
528        quote::quote! {
529            #[macro_export]
530            #[doc(hidden)]
531            macro_rules! #tpie_macro_name {
532                ($concrete_type:ty, $trait_path:path, $class_system:ident, $no_rd:tt, $internal:tt, $noexport:tt) => {
533                    $crate::__mx_trait_impl_expand! {
534                        concrete_type = $concrete_type;
535                        trait_path = $trait_path;
536                        class_system = $class_system;
537                        no_rd = $no_rd;
538                        internal = $internal;
539                        noexport = $noexport;
540                        #(#tpie_method_metadata)*
541                    }
542                };
543            }
544        }
545    } else {
546        quote::quote! {}
547    };
548
549    quote::quote! {
550        // Pass through the original trait (with #[miniextendr] attrs stripped from items)
551        #clean_trait
552
553        #[doc = concat!(
554            "Type tag for runtime identification of the `",
555            stringify!(#trait_name),
556            "` trait."
557        )]
558        #[doc = #source_loc_doc]
559        #[doc = concat!("Generated from source file `", file!(), "`.")]
560        #vis const #tag_name: ::miniextendr_api::abi::mx_tag =
561            ::miniextendr_api::abi::mx_tag_from_path(#tag_path);
562
563        #[doc = concat!("Vtable for the `", stringify!(#trait_name), "` trait.")]
564        #[doc = #source_loc_doc]
565        #[doc = concat!("Generated from source file `", file!(), "`.")]
566        ///
567        /// Contains one `mx_meth` function pointer per trait method.
568        #[repr(C)]
569        #[doc(hidden)]
570        #vis struct #vtable_name {
571            #(#vtable_fields),*
572        }
573
574        #view_tokens
575
576        #shim_and_builder
577
578        #tpie_macro
579    }
580}
581
582/// Generate a method wrapper for the View struct.
583///
584/// This creates a method on the View that calls through the vtable.
585/// Returns None for methods with `Self` in return types or `&Self` in parameters,
586/// since these can't be meaningfully expressed on the type-erased View.
587fn generate_view_method(method: &MethodInfo) -> Option<TokenStream> {
588    // Skip methods where Self appears in return type or parameters.
589    // In the View context, Self refers to the View struct, not the concrete type,
590    // so these methods can't work through the type-erased vtable dispatch.
591    if method.return_type.as_ref().is_some_and(type_contains_self) {
592        return None;
593    }
594    if method.param_types.iter().any(type_contains_self) {
595        return None;
596    }
597
598    let method_name = &method.name;
599    let param_names = &method.param_names;
600    let param_types = &method.param_types;
601
602    // Generate function parameters
603    let params: Vec<_> = param_names
604        .iter()
605        .zip(param_types.iter())
606        .map(|(name, ty)| {
607            quote::quote! { #name: #ty }
608        })
609        .collect();
610
611    // Generate self receiver
612    let self_param = if method.is_mut {
613        quote::quote! { &mut self }
614    } else {
615        quote::quote! { &self }
616    };
617
618    // Generate argument array for vtable call
619    let argc = param_types.len() as i32;
620    let arg_conversions: Vec<_> = param_names
621        .iter()
622        .map(|name| {
623            quote::quote! {
624                ::miniextendr_api::trait_abi::to_sexp(#name)
625            }
626        })
627        .collect();
628
629    // Generate vtable call
630    let vtable_call = if argc > 0 {
631        quote::quote! {
632            let args: [::miniextendr_api::ffi::SEXP; #argc as usize] = [#(#arg_conversions),*];
633            ((*self.vtable).#method_name)(self.data, #argc, args.as_ptr())
634        }
635    } else {
636        quote::quote! {
637            ((*self.vtable).#method_name)(self.data, 0, ::std::ptr::null())
638        }
639    };
640
641    // Generate return type handling
642    let return_type = &method.return_type;
643    let (return_sig, result_conversion) = if let Some(ret_ty) = return_type {
644        (
645            quote::quote! { -> #ret_ty },
646            quote::quote! {
647                ::miniextendr_api::trait_abi::from_sexp::<#ret_ty>(result)
648            },
649        )
650    } else {
651        (
652            quote::quote! {},
653            quote::quote! {
654                let _ = result;
655            },
656        )
657    };
658
659    Some(quote::quote! {
660        #[doc = concat!("Call `", stringify!(#method_name), "` through the vtable.")]
661        #[inline]
662        pub fn #method_name(#self_param #(, #params)*) #return_sig {
663            unsafe {
664                let result = { #vtable_call };
665                // Approach 1 (issue #345): if the shim returned a tagged error SEXP,
666                // re-panic with the reconstructed RCondition so the consumer's outer
667                // `with_r_unwind_protect` guard can apply rust_* class layering.
668                ::miniextendr_api::trait_abi::repanic_if_rust_error(result);
669                #result_conversion
670            }
671        }
672    })
673}
674
675/// Generate a method shim function for a trait method.
676///
677/// The shim is an `extern "C"` function that:
678/// 1. Checks argument arity
679/// 2. Wraps everything in `with_r_unwind_protect` to prevent unwinding across FFI
680/// 3. Converts SEXP arguments to Rust types
681/// 4. Calls the actual method on the concrete type
682/// 5. Converts the result back to SEXP
683/// 6. On panic, converts to R error via `with_r_unwind_protect`
684///
685/// For generic traits, the shim carries the trait's type parameters plus `__ImplT`.
686fn generate_method_shim(
687    trait_name: &syn::Ident,
688    method: &MethodInfo,
689    extra_bounds: &ExtraBounds,
690    trait_param_idents: &[&syn::Ident],
691    trait_type_params: &[&syn::GenericParam],
692    trait_where_clause: &Option<syn::WhereClause>,
693) -> TokenStream {
694    let method_name = &method.name;
695    let shim_name = quote::format_ident!(
696        "__{}_{}_shim",
697        trait_name.to_string().to_lowercase(),
698        method_name
699    );
700    let impl_t = quote::format_ident!("__ImplT");
701
702    let param_count = method.param_types.len();
703    let expected_argc = param_count as i32;
704
705    // Generate argument extraction
706    // For &Self params, extract ExternalPtr<__ImplT> and borrow from it
707    let arg_extractions: Vec<_> = method
708        .param_names
709        .iter()
710        .zip(method.param_types.iter())
711        .enumerate()
712        .map(|(i, (name, ty))| {
713            let name_str = name.to_string();
714            let (is_self_ref, is_mut) = param_is_self_ref(ty);
715            if is_self_ref {
716                let extptr_name = quote::format_ident!("__extptr_{}", name);
717                if is_mut {
718                    quote::quote! {
719                        let mut #extptr_name: ::miniextendr_api::ExternalPtr<#impl_t> = unsafe {
720                            ::miniextendr_api::trait_abi::extract_arg(argc, argv, #i, #name_str)
721                        };
722                        let #name: &mut #impl_t = &mut *#extptr_name;
723                    }
724                } else {
725                    quote::quote! {
726                        let #extptr_name: ::miniextendr_api::ExternalPtr<#impl_t> = unsafe {
727                            ::miniextendr_api::trait_abi::extract_arg(argc, argv, #i, #name_str)
728                        };
729                        let #name: &#impl_t = &*#extptr_name;
730                    }
731                }
732            } else {
733                quote::quote! {
734                    let #name: #ty = unsafe {
735                        ::miniextendr_api::trait_abi::extract_arg(argc, argv, #i, #name_str)
736                    };
737                }
738            }
739        })
740        .collect();
741
742    // Generate method call (uses __ImplT)
743    let param_names = &method.param_names;
744    let method_call = if method.is_mut {
745        quote::quote! {
746            let self_ref = unsafe { &mut *data.cast::<#impl_t>() };
747            self_ref.#method_name(#(#param_names),*)
748        }
749    } else {
750        quote::quote! {
751            let self_ref = unsafe { &*data.cast::<#impl_t>().cast_const() };
752            self_ref.#method_name(#(#param_names),*)
753        }
754    };
755
756    // Generate result conversion
757    let result_conversion = if method.return_type.is_some() {
758        quote::quote! {
759            unsafe { ::miniextendr_api::trait_abi::to_sexp(result) }
760        }
761    } else {
762        quote::quote! {
763            let _ = result;
764            unsafe { ::miniextendr_api::trait_abi::nil() }
765        }
766    };
767
768    // Build trait bound for __ImplT
769    let trait_bound = if trait_param_idents.is_empty() {
770        quote::quote! { #trait_name }
771    } else {
772        quote::quote! { #trait_name<#(#trait_param_idents),*> }
773    };
774
775    let impl_bounds = &extra_bounds.impl_bounds;
776    let all_where_predicates = build_where_predicates(trait_where_clause, extra_bounds);
777    let where_clause = if all_where_predicates.is_empty() {
778        quote::quote! {}
779    } else {
780        quote::quote! { where #(#all_where_predicates),* }
781    };
782
783    let method_name_str = format!("{}::{}", trait_name, method_name);
784
785    quote::quote! {
786        #[doc = concat!(
787            "Method shim for `",
788            stringify!(#trait_name),
789            "::",
790            stringify!(#method_name),
791            "`."
792        )]
793        ///
794        /// Converts SEXP arguments, calls the method, and returns SEXP result.
795        /// Both Rust panics and R longjmps are caught via `with_r_unwind_protect`.
796        #[doc(hidden)]
797        unsafe extern "C" fn #shim_name<#(#trait_type_params,)* #impl_t: #trait_bound #(+ #impl_bounds)*>(
798            data: *mut ::std::os::raw::c_void,
799            argc: i32,
800            argv: *const ::miniextendr_api::ffi::SEXP,
801        ) -> ::miniextendr_api::ffi::SEXP
802        #where_clause
803        {
804            // Check arity (before unwind protect - uses r_stop which doesn't return)
805            unsafe {
806                ::miniextendr_api::trait_abi::check_arity(argc, #expected_argc, #method_name_str);
807            }
808
809            // Wrap in with_r_unwind_protect_shim: catches Rust panics and returns
810            // a tagged error SEXP instead of calling Rf_errorcall directly. The
811            // tagged SEXP is returned to the View method wrapper which re-panics
812            // via repanic_if_rust_error, allowing the consumer's outer
813            // `with_r_unwind_protect` guard to produce rust_* class layering
814            // (issue #345). R-origin longjmps still propagate via R_ContinueUnwind.
815            ::miniextendr_api::unwind_protect::with_r_unwind_protect_shim(|| {
816                // Extract arguments
817                #(#arg_extractions)*
818
819                // Call method
820                let result = { #method_call };
821
822                // Convert result
823                #result_conversion
824            })
825        }
826    }
827}
828
/// Information extracted from a trait method for code generation.
///
/// Collects everything needed to generate vtable shims, view methods,
/// and extra trait bounds for a single method in a `#[miniextendr]` trait.
/// Built by `extract_method_info`; `param_types` and `param_names` are
/// index-aligned (exactly one entry in each per non-receiver parameter).
#[derive(Debug)]
struct MethodInfo {
    /// Method name (Rust identifier), cloned from the method signature.
    name: syn::Ident,
    /// Whether the method has a self receiver (instance method).
    /// False for static/associated methods.
    has_self: bool,
    /// Whether receiver is `&mut self` (vs `&self`). Only meaningful if `has_self` is true;
    /// always `false` for static methods.
    is_mut: bool,
    /// Parameter types (excluding the self receiver), in declaration order.
    param_types: Vec<syn::Type>,
    /// Parameter names (excluding the self receiver). Uses `arg{i}` for unnamed patterns,
    /// where `i` is the parameter's position after the receiver.
    param_names: Vec<syn::Ident>,
    /// Return type. `None` when the method returns `()` (an explicit `-> ()` or no
    /// return annotation at all).
    return_type: Option<syn::Type>,
    /// Whether method is marked `#[miniextendr(skip)]`, excluding it from codegen.
    skip: bool,
    /// Override the R-facing method name (from `#[miniextendr(r_name = "...")]`).
    /// When set, R wrappers and TPIE metadata use this name instead of the Rust ident.
    r_name: Option<String>,
}
854
855// region: Self-type detection helpers
856
857/// Check if a type syntactically contains `Self`.
858///
859/// Used to detect when a method returns `Self` (or `Option<Self>`, `Vec<Self>`, etc.)
860/// so the generated shim can add `IntoR` bounds.
861fn type_contains_self(ty: &syn::Type) -> bool {
862    match ty {
863        syn::Type::Path(tp) => {
864            for seg in &tp.path.segments {
865                if seg.ident == "Self" {
866                    return true;
867                }
868                if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
869                    for arg in &args.args {
870                        if let syn::GenericArgument::Type(inner) = arg
871                            && type_contains_self(inner)
872                        {
873                            return true;
874                        }
875                    }
876                }
877            }
878            false
879        }
880        syn::Type::Reference(r) => type_contains_self(&r.elem),
881        syn::Type::Tuple(t) => t.elems.iter().any(type_contains_self),
882        syn::Type::Slice(s) => type_contains_self(&s.elem),
883        syn::Type::Array(a) => type_contains_self(&a.elem),
884        syn::Type::Paren(p) => type_contains_self(&p.elem),
885        _ => false,
886    }
887}
888
889/// Check if a parameter type is `&Self` or `&mut Self`.
890///
891/// Returns `(is_self_ref, is_mut)`. When true, the generated shim extracts
892/// an `ExternalPtr<T>` from the SEXP and borrows from it instead of trying
893/// to extract `&T` directly (which doesn't implement `TryFromSexp`).
894fn param_is_self_ref(ty: &syn::Type) -> (bool, bool) {
895    if let syn::Type::Reference(r) = ty
896        && let syn::Type::Path(tp) = r.elem.as_ref()
897        && tp.path.is_ident("Self")
898    {
899        return (true, r.mutability.is_some());
900    }
901    (false, false)
902}
903
904/// Check if a type syntactically contains `Self::AssocType` for a given associated type name.
905///
906/// Recursively walks the type tree looking for a 2-segment path where the first
907/// segment is `Self` and the second matches `assoc_name` (e.g., `Self::Item`).
908/// Used to determine whether extra `where` bounds are needed for associated types.
909fn type_contains_self_assoc(ty: &syn::Type, assoc_name: &syn::Ident) -> bool {
910    match ty {
911        syn::Type::Path(tp) => {
912            if tp.path.segments.len() == 2
913                && tp.path.segments[0].ident == "Self"
914                && tp.path.segments[1].ident == *assoc_name
915            {
916                return true;
917            }
918            for seg in &tp.path.segments {
919                if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
920                    for arg in &args.args {
921                        if let syn::GenericArgument::Type(inner) = arg
922                            && type_contains_self_assoc(inner, assoc_name)
923                        {
924                            return true;
925                        }
926                    }
927                }
928            }
929            false
930        }
931        syn::Type::Reference(r) => type_contains_self_assoc(&r.elem, assoc_name),
932        syn::Type::Tuple(t) => t
933            .elems
934            .iter()
935            .any(|e| type_contains_self_assoc(e, assoc_name)),
936        syn::Type::Slice(s) => type_contains_self_assoc(&s.elem, assoc_name),
937        syn::Type::Array(a) => type_contains_self_assoc(&a.elem, assoc_name),
938        syn::Type::Paren(p) => type_contains_self_assoc(&p.elem, assoc_name),
939        _ => false,
940    }
941}
942
943/// Rewrite `Self` and `Self::AssocType` in a type tree to use `__ImplT`.
944///
945/// Transforms:
946/// - `Self` → `__ImplT`
947/// - `Self::Item` → `<__ImplT as TraitName>::Item`
948/// - Recursively processes generic arguments (e.g., `Option<Self::Item>` →
949///   `Option<<__ImplT as TraitName>::Item>`)
950fn rewrite_self_in_type(
951    ty: &syn::Type,
952    trait_name: &syn::Ident,
953    assoc_types: &[&syn::Ident],
954) -> syn::Type {
955    match ty {
956        syn::Type::Path(tp) => {
957            // Check for Self::AssocType (2-segment path: Self::Item)
958            if tp.path.segments.len() == 2
959                && tp.path.segments[0].ident == "Self"
960                && assoc_types.iter().any(|a| *a == &tp.path.segments[1].ident)
961            {
962                let assoc = &tp.path.segments[1].ident;
963                let impl_t = quote::format_ident!("__ImplT");
964                return syn::parse_quote!(<#impl_t as #trait_name>::#assoc);
965            }
966            // Check for bare Self
967            if tp.path.is_ident("Self") {
968                let impl_t = quote::format_ident!("__ImplT");
969                return syn::parse_quote!(#impl_t);
970            }
971            // Recursively process generic args
972            let mut new_tp = tp.clone();
973            for seg in &mut new_tp.path.segments {
974                if let syn::PathArguments::AngleBracketed(args) = &mut seg.arguments {
975                    for arg in &mut args.args {
976                        if let syn::GenericArgument::Type(inner) = arg {
977                            *inner = rewrite_self_in_type(inner, trait_name, assoc_types);
978                        }
979                    }
980                }
981            }
982            syn::Type::Path(new_tp)
983        }
984        syn::Type::Reference(r) => {
985            let mut new_r = r.clone();
986            new_r.elem = Box::new(rewrite_self_in_type(&r.elem, trait_name, assoc_types));
987            syn::Type::Reference(new_r)
988        }
989        syn::Type::Tuple(t) => {
990            let mut new_t = t.clone();
991            for elem in &mut new_t.elems {
992                *elem = rewrite_self_in_type(elem, trait_name, assoc_types);
993            }
994            syn::Type::Tuple(new_t)
995        }
996        syn::Type::Slice(s) => {
997            let mut new_s = s.clone();
998            new_s.elem = Box::new(rewrite_self_in_type(&s.elem, trait_name, assoc_types));
999            syn::Type::Slice(new_s)
1000        }
1001        syn::Type::Array(a) => {
1002            let mut new_a = a.clone();
1003            new_a.elem = Box::new(rewrite_self_in_type(&a.elem, trait_name, assoc_types));
1004            syn::Type::Array(new_a)
1005        }
1006        syn::Type::Paren(p) => {
1007            let mut new_p = p.clone();
1008            new_p.elem = Box::new(rewrite_self_in_type(&p.elem, trait_name, assoc_types));
1009            syn::Type::Paren(new_p)
1010        }
1011        _ => ty.clone(),
1012    }
1013}
1014
1015/// Check if a type syntactically contains a specific identifier.
1016///
1017/// Used to detect trait type parameters (like `T`) in method signatures so that
1018/// appropriate `TryFromSexp` or `IntoR` bounds can be added. Recursively walks
1019/// through path segments, generic arguments, references, tuples, slices, and arrays.
1020fn type_contains_ident(ty: &syn::Type, ident: &syn::Ident) -> bool {
1021    match ty {
1022        syn::Type::Path(tp) => {
1023            if tp.path.segments.len() == 1 && tp.path.segments[0].ident == *ident {
1024                return true;
1025            }
1026            for seg in &tp.path.segments {
1027                if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
1028                    for arg in &args.args {
1029                        if let syn::GenericArgument::Type(inner) = arg
1030                            && type_contains_ident(inner, ident)
1031                        {
1032                            return true;
1033                        }
1034                    }
1035                }
1036            }
1037            false
1038        }
1039        syn::Type::Reference(r) => type_contains_ident(&r.elem, ident),
1040        syn::Type::Tuple(t) => t.elems.iter().any(|e| type_contains_ident(e, ident)),
1041        syn::Type::Slice(s) => type_contains_ident(&s.elem, ident),
1042        syn::Type::Array(a) => type_contains_ident(&a.elem, ident),
1043        syn::Type::Paren(p) => type_contains_ident(&p.elem, ident),
1044        _ => false,
1045    }
1046}
1047
/// Extra trait bounds inferred from method signatures.
///
/// For generic traits and methods that reference `Self` or associated types,
/// the generated shim and vtable builder functions need additional bounds
/// beyond `__ImplT: TraitName`. This struct collects those bounds; it is
/// produced by `compute_extra_bounds`.
struct ExtraBounds {
    /// Bounds added directly to `__ImplT` (e.g., `IntoR` when methods return `Self`,
    /// or `TypedExternal + Send + 'static` when methods take `&Self` parameters).
    /// Each entry holds only the bound itself, with no `__ImplT:` prefix.
    /// NOTE(review): how entries are joined onto `__ImplT` is decided at the use site;
    /// confirm there.
    impl_bounds: Vec<TokenStream>,
    /// Complete where-clause predicates (`Type: Bound`) for complex types (e.g.,
    /// `Option<<__ImplT as Trait>::Item>: IntoR` or `Vec<T>: TryFromSexp`).
    where_predicates: Vec<TokenStream>,
}
1061
1062/// Compute extra bounds needed for the shim and build_vtable functions.
1063///
1064/// - Methods returning `Self` → `__ImplT: IntoR`
1065/// - Methods with `&Self` params → `__ImplT: TypedExternal + Send + 'static`
1066/// - Methods returning types with `Self::AssocType` or trait type params →
1067///   full rewritten return type `: IntoR` (e.g., `Option<<__ImplT as RIterator>::Item>: IntoR`)
1068/// - Methods with trait type params in params →
1069///   full param type `: TryFromSexp` (e.g., `Vec<T>: TryFromSexp`)
1070fn compute_extra_bounds(
1071    methods: &[MethodInfo],
1072    trait_name: &syn::Ident,
1073    assoc_types: &[&syn::Ident],
1074    trait_param_idents: &[&syn::Ident],
1075) -> ExtraBounds {
1076    let mut impl_bounds = Vec::new();
1077    let mut where_predicates = Vec::new();
1078
1079    let mut needs_into_r = false;
1080    let mut needs_typed_external = false;
1081
1082    // Track full types needing bounds (deduplicated by token string)
1083    let mut return_type_bound_keys: std::collections::BTreeMap<String, syn::Type> =
1084        Default::default();
1085    let mut param_type_bound_keys: std::collections::BTreeMap<String, syn::Type> =
1086        Default::default();
1087
1088    for method in methods {
1089        // Bare Self in returns → __ImplT: IntoR (as impl bound)
1090        if method.return_type.as_ref().is_some_and(type_contains_self) {
1091            needs_into_r = true;
1092        }
1093        // &Self in params → __ImplT: TypedExternal + 'static
1094        if method.param_types.iter().any(|ty| param_is_self_ref(ty).0) {
1095            needs_typed_external = true;
1096        }
1097
1098        // Full return type bounds for Self::AssocType and/or trait type params.
1099        // Instead of bare `<__ImplT as Trait>::Item: IntoR`, we add the FULL rewritten
1100        // return type: `Option<<__ImplT as Trait>::Item>: IntoR`, `Vec<T>: IntoR`, etc.
1101        // This is required because IntoR impls are concrete (no blanket `Option<T: IntoR>: IntoR`).
1102        if let Some(ref ret_ty) = method.return_type {
1103            let has_assoc = assoc_types
1104                .iter()
1105                .any(|a| type_contains_self_assoc(ret_ty, a));
1106            let has_param = trait_param_idents
1107                .iter()
1108                .any(|p| type_contains_ident(ret_ty, p));
1109            if has_assoc || has_param {
1110                let rewritten = rewrite_self_in_type(ret_ty, trait_name, assoc_types);
1111                let key = quote::quote!(#rewritten).to_string();
1112                return_type_bound_keys.entry(key).or_insert(rewritten);
1113            }
1114        }
1115
1116        // Full param type bounds for trait type params.
1117        // Instead of bare `T: TryFromSexp`, we add `Vec<T>: TryFromSexp` etc.
1118        for param_ty in &method.param_types {
1119            if !param_is_self_ref(param_ty).0 {
1120                let has_param = trait_param_idents
1121                    .iter()
1122                    .any(|p| type_contains_ident(param_ty, p));
1123                if has_param {
1124                    let key = quote::quote!(#param_ty).to_string();
1125                    param_type_bound_keys.entry(key).or_insert(param_ty.clone());
1126                }
1127            }
1128        }
1129    }
1130
1131    if needs_into_r {
1132        impl_bounds.push(quote::quote! { ::miniextendr_api::IntoR });
1133    }
1134    if needs_typed_external {
1135        impl_bounds.push(quote::quote! { ::miniextendr_api::TypedExternal + Send + 'static });
1136    }
1137
1138    // Add full return type bounds: RewrittenType: IntoR
1139    for ty in return_type_bound_keys.values() {
1140        where_predicates.push(quote::quote! {
1141            #ty: ::miniextendr_api::IntoR
1142        });
1143    }
1144
1145    // Add full param type bounds: ParamType: TryFromSexp, Error: Display
1146    for ty in param_type_bound_keys.values() {
1147        where_predicates.push(quote::quote! {
1148            #ty: ::miniextendr_api::TryFromSexp
1149        });
1150        where_predicates.push(quote::quote! {
1151            <#ty as ::miniextendr_api::TryFromSexp>::Error: ::std::fmt::Display
1152        });
1153    }
1154
1155    ExtraBounds {
1156        impl_bounds,
1157        where_predicates,
1158    }
1159}
1160
1161/// Build combined where predicates from the trait's own where clause and computed extra bounds.
1162///
1163/// Merges the original trait-level where clause predicates with the extra
1164/// bounds computed from method signatures (e.g., `IntoR` for return types,
1165/// `TryFromSexp` for parameters containing trait type params).
1166///
1167/// Returns a flat list of predicates suitable for use in a `where` clause.
1168fn build_where_predicates(
1169    trait_where_clause: &Option<syn::WhereClause>,
1170    extra_bounds: &ExtraBounds,
1171) -> Vec<TokenStream> {
1172    let mut all = Vec::new();
1173    if let Some(wc) = trait_where_clause {
1174        for pred in &wc.predicates {
1175            all.push(quote::quote! { #pred });
1176        }
1177    }
1178    all.extend(extra_bounds.where_predicates.iter().cloned());
1179    all
1180}
1181
1182/// Extract method information from a trait method definition.
1183///
1184/// Parses the method signature to determine receiver type, parameter names/types,
1185/// return type, and any `#[miniextendr(...)]` attributes like `skip` and `r_name`.
1186/// Parameters with non-ident patterns are assigned synthetic names (`arg0`, `arg1`, etc.).
1187fn extract_method_info(method: &syn::TraitItemFn) -> syn::Result<MethodInfo> {
1188    let name = method.sig.ident.clone();
1189
1190    // Check for #[miniextendr(skip)] and #[miniextendr(r_name = "...")]
1191    let mut skip = false;
1192    let mut r_name: Option<String> = None;
1193    for attr in &method.attrs {
1194        if !attr.path().is_ident("miniextendr") {
1195            continue;
1196        }
1197        attr.parse_nested_meta(|meta| {
1198            if meta.path.is_ident("skip") {
1199                skip = true;
1200            } else if meta.path.is_ident("r_name") {
1201                let value: syn::LitStr = meta.value()?.parse()?;
1202                r_name = Some(value.value());
1203            } else {
1204                return Err(meta.error(
1205                    "unknown #[miniextendr] option on trait method; expected `skip` or `r_name`",
1206                ));
1207            }
1208            Ok(())
1209        })?;
1210    }
1211
1212    // Check for receiver
1213    let (has_self, is_mut) = method.sig.inputs.first().map_or((false, false), |arg| {
1214        if let syn::FnArg::Receiver(r) = arg {
1215            (true, r.mutability.is_some())
1216        } else {
1217            (false, false)
1218        }
1219    });
1220
1221    // Extract parameters (skip self if present)
1222    let skip_count = if has_self { 1 } else { 0 };
1223    let mut param_types = Vec::new();
1224    let mut param_names = Vec::new();
1225    for (i, arg) in method.sig.inputs.iter().skip(skip_count).enumerate() {
1226        if let syn::FnArg::Typed(pat_type) = arg {
1227            param_types.push((*pat_type.ty).clone());
1228            if let syn::Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
1229                param_names.push(pat_ident.ident.clone());
1230            } else {
1231                param_names.push(quote::format_ident!("arg{}", i));
1232            }
1233        }
1234    }
1235
1236    // Extract return type
1237    let return_type = match &method.sig.output {
1238        syn::ReturnType::Default => None,
1239        syn::ReturnType::Type(_, ty) => {
1240            // Check if it's unit type ()
1241            if matches!(ty.as_ref(), syn::Type::Tuple(t) if t.elems.is_empty()) {
1242                None
1243            } else {
1244                Some((**ty).clone())
1245            }
1246        }
1247    };
1248
1249    Ok(MethodInfo {
1250        name,
1251        has_self,
1252        is_mut,
1253        param_types,
1254        param_names,
1255        return_type,
1256        skip,
1257        r_name,
1258    })
1259}
1260
// Unit tests for this module, loaded from a separate `tests` module file
// (declared without a body) and compiled only under `cfg(test)`.
#[cfg(test)]
mod tests;
1263// endregion