Skip to main content

miniextendr_macros/miniextendr_impl/
vctrs_class.rs

1//! Vctrs class R wrapper generator.
2
3use super::{ParsedImpl, VctrsKind};
4
5/// Generates the complete R wrapper string for a vctrs-compatible S3 class.
6///
7/// This is used when an `impl` block is annotated with `#[miniextendr(vctrs)]`.
8/// Unlike the `#[derive(Vctrs)]` macro (which generates standalone S3 methods from
9/// struct attributes), this generator produces class wrappers from `impl` block methods.
10///
11/// Produces the following R code:
12/// - Constructor: `new_<class>(...)` that calls the Rust `new` constructor, then wraps
13///   the result with `vctrs::new_vctr()`, `vctrs::new_rcrd()`, or `vctrs::new_list_of()`
14///   depending on the `VctrsKind`
15/// - `vec_ptype_abbr.<class>`: compact abbreviation for printing (if `abbr` is specified)
16/// - `vec_ptype2.<class>.<class>`: self-coercion prototype (returns empty typed vector)
17/// - `vec_cast.<class>.<class>`: identity cast (returns `x` unchanged)
18/// - Instance methods: S3 generics + `<generic>.<class>` methods, with support for
19///   vctrs protocol overrides via `#[miniextendr(vctrs_protocol = "...")]` and
20///   double-dispatch class suffixes via `#[miniextendr(class = "...")]`
21/// - Static methods: regular functions named `<class>_<method>(...)`
22///
23/// Roxygen2 documentation and `@importFrom vctrs ...` tags are generated automatically.
24pub fn generate_vctrs_r_wrapper(parsed_impl: &ParsedImpl) -> String {
25    use crate::r_class_formatter::{
26        ClassDocBuilder, MethodDocBuilder, ParsedImplExt, emit_s3_generic_guard,
27        should_export_from_tags,
28    };
29
30    let class_name = parsed_impl.class_name();
31    let type_ident = &parsed_impl.type_ident;
32    let class_doc_tags = &parsed_impl.doc_tags;
33    let vctrs_attrs = &parsed_impl.vctrs_attrs;
34    let should_export =
35        should_export_from_tags(class_doc_tags, parsed_impl.noexport || parsed_impl.internal);
36
37    // Constructor name follows vctrs convention: new_<class>
38    let ctor_name = format!("new_{}", class_name.to_lowercase());
39
40    let mut lines = Vec::new();
41
42    // Constructor with combined class and constructor documentation
43    if let Some(ctx) = parsed_impl.constructor_context() {
44        lines.push(ctx.source_comment(type_ident));
45        let mut ctor_doc_tags = Vec::new();
46        ctor_doc_tags.extend(class_doc_tags.iter().cloned());
47        ctor_doc_tags.extend(ctx.method.doc_tags.iter().cloned());
48
49        lines.extend(
50            ClassDocBuilder::new(&class_name, type_ident, &ctor_doc_tags, "vctrs S3")
51                .with_imports("@importFrom vctrs new_vctr new_rcrd new_list_of vec_ptype2 vec_cast vec_ptype_abbr")
52                .with_export_control(parsed_impl.internal, parsed_impl.noexport)
53                .build(),
54        );
55        // Inject lifecycle imports from methods into class-level roxygen block
56        if let Some(lc_import) = crate::lifecycle::collect_lifecycle_imports(
57            parsed_impl
58                .methods
59                .iter()
60                .filter_map(|m| m.method_attrs.lifecycle.as_ref()),
61        ) {
62            let insert_pos = lines.len().saturating_sub(1);
63            lines.insert(insert_pos, format!("#' {}", lc_import));
64        }
65
66        // Generate constructor body based on vctrs kind
67        lines.push(format!("{} <- function({}) {{", ctor_name, ctx.params));
68        for line in ctx.missing_prelude() {
69            lines.push(format!("  {}", line));
70        }
71        for check in ctx.precondition_checks() {
72            lines.push(format!("  {}", check));
73        }
74        // Inject match.arg validation for match_arg/choices params
75        for line in ctx.match_arg_prelude() {
76            lines.push(format!("  {}", line));
77        }
78        lines.push(format!("  .val <- {}", ctx.static_call()));
79        lines.extend(crate::method_return_builder::condition_check_lines("  "));
80        lines.push("  data <- .val".to_string());
81
82        match vctrs_attrs.kind {
83            VctrsKind::Vctr => {
84                // Build new_vctr call with optional inherit_base_type
85                let inherit_arg = match vctrs_attrs.inherit_base_type {
86                    Some(true) => ", inherit_base_type = TRUE",
87                    Some(false) => ", inherit_base_type = FALSE",
88                    None => "",
89                };
90                lines.push(format!(
91                    "  vctrs::new_vctr(data, class = \"{}\"{})",
92                    class_name, inherit_arg
93                ));
94            }
95            VctrsKind::Rcrd => {
96                // Record type - data should be a list
97                lines.push(format!(
98                    "  vctrs::new_rcrd(data, class = \"{}\")",
99                    class_name
100                ));
101            }
102            VctrsKind::ListOf => {
103                // list_of - needs ptype
104                let ptype_arg = vctrs_attrs
105                    .ptype
106                    .as_ref()
107                    .map(|p| format!(", ptype = {}", p))
108                    .unwrap_or_default();
109                lines.push(format!(
110                    "  vctrs::new_list_of(data, class = \"{}\"{})",
111                    class_name, ptype_arg
112                ));
113            }
114        }
115        lines.push("}".to_string());
116        lines.push(String::new());
117    }
118
119    // vec_ptype_abbr for compact printing (if abbr is specified)
120    if let Some(abbr) = &vctrs_attrs.abbr {
121        lines.push(format!("#' @rdname {}", class_name));
122        lines.push(format!("#' @method vec_ptype_abbr {}", class_name));
123        if should_export {
124            lines.push("#' @export".to_string());
125        }
126        lines.push(format!(
127            "vec_ptype_abbr.{} <- function(x, ...) \"{}\"",
128            class_name, abbr
129        ));
130        lines.push(String::new());
131    }
132
133    // Self-coercion methods (required for vctrs to work properly)
134    // vec_ptype2.<class>.<class> - returns prototype for combining same types
135    lines.push(format!("#' @rdname {}", class_name));
136    lines.push(format!(
137        "#' @method vec_ptype2 {}.{}",
138        class_name, class_name
139    ));
140    lines.push(format!("#' @param x A {} vector.", class_name));
141    lines.push(format!("#' @param y A {} vector.", class_name));
142    lines.push("#' @param ... Additional arguments (unused).".to_string());
143    if should_export {
144        lines.push("#' @export".to_string());
145    }
146    match vctrs_attrs.kind {
147        VctrsKind::Vctr => {
148            let base_type = vctrs_attrs
149                .base
150                .as_ref()
151                .map(|b| format!("{}()", b))
152                .unwrap_or_else(|| "double()".to_string());
153            let inherit_arg = match vctrs_attrs.inherit_base_type {
154                Some(true) => ", inherit_base_type = TRUE",
155                Some(false) => ", inherit_base_type = FALSE",
156                None => "",
157            };
158            lines.push(format!(
159                "vec_ptype2.{c}.{c} <- function(x, y, ...) vctrs::new_vctr({base}, class = \"{c}\"{inherit})",
160                c = class_name,
161                base = base_type,
162                inherit = inherit_arg
163            ));
164        }
165        VctrsKind::Rcrd => {
166            // For records, return empty record with same field structure
167            lines.push(format!(
168                "vec_ptype2.{c}.{c} <- function(x, y, ...) x[0]",
169                c = class_name
170            ));
171        }
172        VctrsKind::ListOf => {
173            let ptype_arg = vctrs_attrs
174                .ptype
175                .as_ref()
176                .map(|p| format!(", ptype = {}", p))
177                .unwrap_or_default();
178            lines.push(format!(
179                "vec_ptype2.{c}.{c} <- function(x, y, ...) vctrs::new_list_of(list(), class = \"{c}\"{ptype})",
180                c = class_name,
181                ptype = ptype_arg
182            ));
183        }
184    }
185    lines.push(String::new());
186
187    // vec_cast.<class>.<class> - identity cast (no-op for same type)
188    lines.push(format!("#' @rdname {}", class_name));
189    lines.push(format!("#' @method vec_cast {}.{}", class_name, class_name));
190    lines.push(format!("#' @param x A {} vector to cast.", class_name));
191    lines.push(format!("#' @param to A {} prototype.", class_name));
192    lines.push("#' @param ... Additional arguments (unused).".to_string());
193    if should_export {
194        lines.push("#' @export".to_string());
195    }
196    lines.push(format!(
197        "vec_cast.{c}.{c} <- function(x, to, ...) x",
198        c = class_name
199    ));
200    lines.push(String::new());
201
202    // Instance methods as S3 generics + methods
203    for ctx in parsed_impl.instance_method_contexts() {
204        lines.push(ctx.source_comment(type_ident));
205        // vctrs protocol override: use the protocol name as the S3 generic
206        let is_protocol = ctx.method.method_attrs.vctrs_protocol.is_some();
207        let generic_name = if let Some(ref proto) = ctx.method.method_attrs.vctrs_protocol {
208            proto.clone()
209        } else {
210            ctx.generic_name()
211        };
212        // Use custom class suffix if provided (for double-dispatch patterns like vec_ptype2.a.b)
213        let method_class_suffix = ctx
214            .class_suffix()
215            .map(|s| s.to_string())
216            .unwrap_or_else(|| class_name.clone());
217        let s3_method_name = format!("{}.{}", generic_name, method_class_suffix);
218        let full_params = ctx.instance_formals(true); // adds x, ..., params
219
220        // Only create the S3 generic if no generic/class override was provided
221        // vctrs protocol methods use existing generics from the vctrs package
222        if !is_protocol && !ctx.has_generic_override() && !ctx.has_class_override() {
223            lines.push(format!("#' @title S3 generic for `{}`", generic_name));
224            lines.push(format!("#' @description S3 generic for `{}`", generic_name));
225            lines.push(format!("#' @rdname {}", class_name));
226            // Use class-qualified name to avoid duplicate alias when multiple
227            // classes define the same S3 generic.
228            lines.push(format!("#' @name {}.{}", generic_name, class_name));
229            lines.push("#' @param x An object".to_string());
230            lines.push("#' @param ... Additional arguments passed to methods".to_string());
231            lines.push(crate::roxygen::method_source_tag(
232                type_ident,
233                &ctx.method.ident,
234            ));
235            if should_export {
236                lines.push("#' @export".to_string());
237            }
238            lines.push(emit_s3_generic_guard(&generic_name));
239            lines.push(String::new());
240        }
241
242        // Then create the S3 method
243        let qualified_name = format!("{}.{}", generic_name, method_class_suffix);
244        let mx_doc = ctx.match_arg_doc_placeholders();
245        let method_doc =
246            MethodDocBuilder::new(&class_name, &generic_name, type_ident, &ctx.method.doc_tags)
247                .with_r_params(&ctx.params)
248                .with_match_arg_doc_placeholders(&mx_doc)
249                .with_r_name(qualified_name);
250        lines.extend(method_doc.build());
251        lines.push(format!(
252            "#' @method {} {}",
253            generic_name, method_class_suffix
254        ));
255        if should_export {
256            lines.push("#' @export".to_string());
257        }
258        lines.push(format!(
259            "{} <- function({}) {{",
260            s3_method_name, full_params
261        ));
262
263        let what = format!("{}.{}", generic_name, class_name);
264        ctx.emit_method_prelude(&mut lines, "  ", &what);
265
266        let call = ctx.instance_call("x");
267        let strategy = crate::ReturnStrategy::for_method(ctx.method);
268        let return_builder = crate::MethodReturnBuilder::new(call)
269            .with_strategy(strategy)
270            .with_class_name(class_name.clone())
271            .with_chain_var("x".to_string());
272        lines.extend(return_builder.build_s3_body());
273
274        lines.push("}".to_string());
275        lines.push(String::new());
276    }
277
278    // Static methods as regular functions, or as vctrs protocol S3 methods when
279    // `#[miniextendr(vctrs(protocol))]` is set (e.g. `vctrs(format)` → `format.<Class>`).
280    for ctx in parsed_impl.static_method_contexts() {
281        lines.push(ctx.source_comment(type_ident));
282
283        let is_protocol = ctx.method.method_attrs.vctrs_protocol.is_some();
284        let fn_name = if let Some(ref proto) = ctx.method.method_attrs.vctrs_protocol {
285            // vctrs protocol override: emit as `<protocol>.<Class>` S3 method
286            format!("{}.{}", proto, class_name)
287        } else {
288            let method_name = ctx.method.r_method_name();
289            format!("{}_{}", class_name.to_lowercase(), method_name)
290        };
291        let r_name = fn_name.clone();
292
293        let mx_doc = ctx.match_arg_doc_placeholders();
294        let method_name = ctx.method.r_method_name();
295        if is_protocol {
296            let proto = ctx.method.method_attrs.vctrs_protocol.as_ref().unwrap();
297            let method_doc =
298                MethodDocBuilder::new(&class_name, &method_name, type_ident, &ctx.method.doc_tags)
299                    .with_r_params(&ctx.params)
300                    .with_match_arg_doc_placeholders(&mx_doc)
301                    .with_r_name(r_name.clone());
302            lines.extend(method_doc.build());
303            lines.push(format!("#' @method {} {}", proto, class_name));
304            if should_export {
305                lines.push("#' @export".to_string());
306            }
307        } else {
308            let method_doc =
309                MethodDocBuilder::new(&class_name, &method_name, type_ident, &ctx.method.doc_tags)
310                    .with_r_params(&ctx.params)
311                    .with_match_arg_doc_placeholders(&mx_doc)
312                    .with_r_name(r_name.clone());
313            lines.extend(method_doc.build());
314        }
315
316        // Protocol methods accept `...` so `format(x, nsmall = 2)` and similar
317        // S3 dispatch calls with extra arguments don't error with "unused argument".
318        // The `...` is silently dropped; the underlying Rust function has a fixed signature.
319        let formals = if is_protocol {
320            format!("{}, ...", ctx.params)
321        } else {
322            ctx.params.to_string()
323        };
324        lines.push(format!("{} <- function({}) {{", fn_name, formals));
325
326        ctx.emit_method_prelude(&mut lines, "  ", &fn_name);
327
328        let strategy = crate::ReturnStrategy::for_method(ctx.method);
329        let return_builder = crate::MethodReturnBuilder::new(ctx.static_call())
330            .with_strategy(strategy)
331            .with_class_name(class_name.clone());
332        lines.extend(return_builder.build_s3_body());
333
334        lines.push("}".to_string());
335        lines.push(String::new());
336    }
337
338    lines.join("\n")
339}