Skip to main content

miniextendr_macros/
r_preconditions.rs

1//! R-side precondition generation for type checking.
2//!
3//! Generates `stopifnot()` checks in R wrapper functions that run BEFORE the `.Call()` boundary.
4//! This gives users clear, idiomatic R error messages with proper stack traces instead of
5//! Rust panic messages.
6//!
7//! Each assertion checks ONE thing with a precise error message:
8//!
9//! ```r
10//! add <- function(a, b) {
11//!   stopifnot(
12//!     "'a' must be numeric, logical, or raw" = is.numeric(a) || is.logical(a) || is.raw(a),
13//!     "'a' must have length 1" = length(a) == 1L,
14//!     "'b' must be numeric, logical, or raw" = is.numeric(b) || is.logical(b) || is.raw(b),
15//!     "'b' must have length 1" = length(b) == 1L
16//!   )
17//!   .Call(C_add, .call = match.call(), a, b)
18//! }
19//! ```
20
21use std::collections::HashSet;
22
/// One named argument of a generated `stopifnot()` call.
///
/// Rendered as `"message" = condition`, e.g.
/// `"'x' must be numeric" = is.numeric(x)`; R reports `message` when
/// `condition` does not evaluate to `TRUE`.
struct RAssertion {
    /// Error text R shows when the check fails.
    message: String,
    /// R expression expected to evaluate to `TRUE`.
    condition: String,
}

impl RAssertion {
    /// Build an assertion from an error message and an R condition expression.
    fn new(message: impl Into<String>, condition: impl Into<String>) -> Self {
        let message = message.into();
        let condition = condition.into();
        Self { message, condition }
    }

    /// Render as a named `stopifnot()` argument: `"message" = condition`.
    ///
    /// The message is placed verbatim between double quotes (no escaping).
    fn to_stopifnot_arg(&self) -> String {
        let RAssertion { message, condition } = self;
        format!("\"{message}\" = {condition}")
    }

    /// Relax this assertion so that `NULL` also passes.
    ///
    /// The condition becomes `is.null(param) || <condition>`. The message is
    /// reworded by rewriting its first "must be " phrase or, failing that, its
    /// first "must have " phrase; any other message gets " (or NULL)" appended:
    ///
    /// - `'x' must be character` → `'x' must be NULL or character`
    /// - `'x' must have length 1` → `'x' must be NULL or have length 1`
    fn nullable(self, param: &str) -> Self {
        // Checked in order: the first needle present in the message wins.
        const REWRITES: [(&str, &str); 2] = [
            ("must be ", "must be NULL or "),
            ("must have ", "must be NULL or have "),
        ];

        let message = REWRITES
            .iter()
            .find(|(needle, _)| self.message.contains(*needle))
            .map_or_else(
                || format!("{} (or NULL)", self.message),
                |(needle, replacement)| self.message.replacen(*needle, *replacement, 1),
            );
        let condition = format!("is.null({}) || {}", param, self.condition);

        Self { message, condition }
    }
}
67
/// What the R wrapper should verify about one parameter before `.Call()`.
///
/// `RTypeCheck::assertions` expands each variant into concrete `stopifnot()`
/// assertions. The numeric variants deliberately use the broad
/// `is.numeric || is.logical || is.raw` predicate: logical values coerce to
/// numbers freely and raw values are valid for byte-sized integers. Stricter
/// borderline cases (such as raw → i64 in strict mode) are allowed through here
/// and left to the Rust-side converter, which can give more contextual errors.
enum RTypeCheck {
    /// Numeric, length-1 value (`i32`, `f64`, `f32`, `i8`, `i16`, `i64`, `isize`).
    /// Expands to a type assertion and a length assertion.
    ScalarNumeric,
    /// Non-negative numeric, length-1 value (`u16`, `u32`, `u64`, `usize`).
    /// Expands to type, length, and `>= 0` assertions.
    ScalarNonNeg,
    /// Length-1 value of a non-numeric R type. The payload names the
    /// `is.<type>()` predicate (for example `"logical"` or `"character"`).
    /// Expands to a type assertion and a length assertion.
    Scalar(&'static str),
    /// Numeric vector of any length: a single type assertion.
    VectorNumeric,
    /// Non-numeric vector of any length: a single `is.<type>()` assertion.
    /// The payload names the predicate.
    Vector(&'static str),
    /// An inner check that additionally accepts `NULL`: each inner assertion
    /// gets an `is.null(x) ||` prefix and a message that mentions NULL.
    Nullable(Box<RTypeCheck>),
    /// List value (`HashMap`, `BTreeMap`, `NamedList`, `List`, `ListMut`):
    /// a single `is.list(x)` assertion.
    List,
}
97
/// R condition accepting any value that coerces to a number.
///
/// For a parameter `p` this yields `is.numeric(p) || is.logical(p) || is.raw(p)`.
/// Logical values coerce to numbers freely, and raw values are accepted because
/// they carry byte-level data.
fn numeric_type_check(param: &str) -> String {
    ["numeric", "logical", "raw"]
        .iter()
        .map(|kind| format!("is.{kind}({param})"))
        .collect::<Vec<_>>()
        .join(" || ")
}
109
110impl RTypeCheck {
111    /// Produce the individual `stopifnot()` assertions for this type check.
112    ///
113    /// Returns one or more `RAssertion` values, each representing a single
114    /// `"message" = condition` entry in the `stopifnot()` call. The `param`
115    /// argument is the R parameter name to use in messages and conditions.
116    fn assertions(&self, param: &str) -> Vec<RAssertion> {
117        match self {
118            RTypeCheck::ScalarNumeric => vec![
119                RAssertion::new(
120                    format!("'{}' must be numeric, logical, or raw", param),
121                    numeric_type_check(param),
122                ),
123                RAssertion::new(
124                    format!("'{}' must have length 1", param),
125                    format!("length({}) == 1L", param),
126                ),
127            ],
128            RTypeCheck::ScalarNonNeg => vec![
129                RAssertion::new(
130                    format!("'{}' must be numeric, logical, or raw", param),
131                    numeric_type_check(param),
132                ),
133                RAssertion::new(
134                    format!("'{}' must have length 1", param),
135                    format!("length({}) == 1L", param),
136                ),
137                RAssertion::new(
138                    format!("'{}' must be non-negative", param),
139                    // raw is always non-negative; guard with is.raw() to avoid
140                    // "comparison not implemented" error for raw values
141                    format!("is.raw({p}) || {p} >= 0", p = param),
142                ),
143            ],
144            RTypeCheck::Scalar(r_type) => vec![
145                RAssertion::new(
146                    format!("'{}' must be {}", param, r_type),
147                    format!("is.{}({})", r_type, param),
148                ),
149                RAssertion::new(
150                    format!("'{}' must have length 1", param),
151                    format!("length({}) == 1L", param),
152                ),
153            ],
154            RTypeCheck::VectorNumeric => vec![RAssertion::new(
155                format!("'{}' must be numeric, logical, or raw", param),
156                numeric_type_check(param),
157            )],
158            RTypeCheck::Vector(r_type) => vec![RAssertion::new(
159                format!("'{}' must be {}", param, r_type),
160                format!("is.{}({})", r_type, param),
161            )],
162            RTypeCheck::Nullable(inner) => inner
163                .assertions(param)
164                .into_iter()
165                .map(|a| a.nullable(param))
166                .collect(),
167            RTypeCheck::List => vec![RAssertion::new(
168                format!("'{}' must be a list", param),
169                format!("is.list({})", param),
170            )],
171        }
172    }
173}
174
175/// Map a Rust type to its R-side type check, if applicable.
176///
177/// Returns `None` for types that should skip precondition checks (SEXP, Dots, ExternalPtr, etc.).
178fn r_check_for_type(ty: &syn::Type) -> Option<RTypeCheck> {
179    match ty {
180        syn::Type::Path(type_path) => r_check_for_type_path(type_path),
181        syn::Type::Reference(type_ref) => r_check_for_reference(type_ref),
182        _ => None,
183    }
184}
185
186/// Map a `syn::TypePath` to its R-side type check.
187///
188/// Handles the most common case: simple types (`i32`, `String`, `bool`),
189/// generic wrappers (`Vec<T>`, `Option<T>`), map types, and skip types.
190/// Returns `None` for types that cannot be prechecked from R.
191fn r_check_for_type_path(type_path: &syn::TypePath) -> Option<RTypeCheck> {
192    let segment = type_path.path.segments.last()?;
193    let ident = segment.ident.to_string();
194
195    match ident.as_str() {
196        // Numeric scalars (accepts numeric, logical, and raw via R coercion)
197        "i32" | "f64" | "f32" | "i8" | "i16" | "i64" | "isize" => Some(RTypeCheck::ScalarNumeric),
198
199        // Unsigned numeric scalars (non-negative constraint)
200        "u16" | "u32" | "u64" | "usize" => Some(RTypeCheck::ScalarNonNeg),
201
202        // Logical scalar
203        "bool" | "Rbool" | "Rboolean" => Some(RTypeCheck::Scalar("logical")),
204
205        // Character scalar
206        "String" | "char" | "PathBuf" => Some(RTypeCheck::Scalar("character")),
207
208        // Raw scalar
209        "u8" => Some(RTypeCheck::Scalar("raw")),
210
211        // Complex scalar
212        "Rcomplex" => Some(RTypeCheck::Scalar("complex")),
213
214        // Option<T> → Nullable
215        "Option" => {
216            let inner_ty = extract_single_generic_arg(segment)?;
217            r_check_for_type(inner_ty).map(|inner| RTypeCheck::Nullable(Box::new(inner)))
218        }
219
220        // Vec<T> → Vector (depends on element type)
221        "Vec" => {
222            let inner_ty = extract_single_generic_arg(segment)?;
223            r_check_for_vec_element(inner_ty)
224        }
225
226        // Map types and named list → List
227        "HashMap" | "BTreeMap" | "NamedList" => Some(RTypeCheck::List),
228
229        // List (bare) → List
230        "List" | "ListMut" => Some(RTypeCheck::List),
231
232        // Skip types: SEXP, Dots, Missing, ExternalPtr, RLogical, etc.
233        "SEXP" | "Dots" | "Missing" | "ExternalPtr" | "OwnedProtect" => None,
234
235        // Unknown type → skip (let Rust side validate)
236        _ => None,
237    }
238}
239
240/// Map a reference type to its R-side type check.
241///
242/// Handles `&str` and `&Path` (character scalar), `&[T]` (vector based on element type),
243/// and `&Dots` (skipped). Returns `None` for unrecognized reference types.
244fn r_check_for_reference(type_ref: &syn::TypeReference) -> Option<RTypeCheck> {
245    match type_ref.elem.as_ref() {
246        // &str → character scalar
247        syn::Type::Path(tp) => {
248            let seg = tp.path.segments.last()?;
249            match seg.ident.to_string().as_str() {
250                "str" => Some(RTypeCheck::Scalar("character")),
251                "Path" => Some(RTypeCheck::Scalar("character")),
252                "Dots" => None,
253                _ => None,
254            }
255        }
256        // &[T] → vector check based on element type
257        syn::Type::Slice(slice) => r_check_for_vec_element(&slice.elem),
258        _ => None,
259    }
260}
261
262/// Map a `Vec<T>` or `&[T]` element type to the appropriate vector type check.
263///
264/// Numeric elements produce `VectorNumeric`, `bool` produces `Vector("logical")`,
265/// `String` produces `Vector("character")`, etc. Handles nested `Option<T>` for
266/// nullable element types (e.g., `Vec<Option<String>>` becomes character vector).
267fn r_check_for_vec_element(elem_ty: &syn::Type) -> Option<RTypeCheck> {
268    let syn::Type::Path(tp) = elem_ty else {
269        return None;
270    };
271    let seg = tp.path.segments.last()?;
272    let ident = seg.ident.to_string();
273
274    match ident.as_str() {
275        // Numeric vectors (accepts numeric, logical, and raw via R coercion)
276        "i32" | "f64" | "f32" | "i8" | "i16" | "u16" | "u32" | "i64" | "u64" | "isize"
277        | "usize" => Some(RTypeCheck::VectorNumeric),
278
279        // Logical vector
280        "bool" => Some(RTypeCheck::Vector("logical")),
281
282        // Character vector
283        "String" => Some(RTypeCheck::Vector("character")),
284
285        // Raw vector
286        "u8" => Some(RTypeCheck::Vector("raw")),
287
288        // Complex vector
289        "Rcomplex" => Some(RTypeCheck::Vector("complex")),
290
291        // Vec<Option<T>> — e.g., Vec<Option<String>> for nullable strings
292        "Option" => {
293            let inner = extract_single_generic_arg(seg)?;
294            // Vec<Option<String>> → character, Vec<Option<i32>> → numeric, etc.
295            r_check_for_vec_element(inner)
296        }
297
298        _ => None,
299    }
300}
301
302/// Extract the single generic type argument from a path segment.
303///
304/// e.g., `Option<String>` → `String`, `Vec<i32>` → `i32`
305fn extract_single_generic_arg(segment: &syn::PathSegment) -> Option<&syn::Type> {
306    if let syn::PathArguments::AngleBracketed(ref args) = segment.arguments
307        && let Some(syn::GenericArgument::Type(ty)) = args.args.first()
308    {
309        return Some(ty);
310    }
311    None
312}
313
/// A parameter whose Rust type is absent from the static type table.
///
/// Such parameters are only recorded. No R-side check is emitted for them, and
/// type errors are left to the Rust-side conversion and its own messages.
#[allow(dead_code)] // Read in tests
pub struct FallbackParam {
    /// Parameter name as it appears in the R formals (e.g. `_dots` → `.dots`).
    pub r_name: String,
}
323
324/// Output of precondition analysis for a function's parameters.
325///
326/// Contains both the generated R `stopifnot()` code for known types and a list
327/// of parameters with unknown types that were not statically prechecked.
328pub struct PreconditionOutput {
329    /// Lines forming a `stopifnot(...)` call for known types.
330    ///
331    /// Empty if no parameters have known type checks. For a single assertion,
332    /// contains one line (`stopifnot(...)`). For multiple assertions, contains
333    /// `stopifnot(`, indented assertion lines, and `)`.
334    pub static_checks: Vec<String>,
335    /// Parameters with unknown custom types that were not prechecked.
336    #[allow(dead_code)] // Read in tests
337    pub fallback_params: Vec<FallbackParam>,
338}
339
/// Whether `ident` names a type that must never receive a fallback precheck.
///
/// The listed types fall into three groups:
/// - `SEXP` is passed straight through the FFI layer;
/// - `Dots` and `Missing` are consumed by the macro machinery;
/// - `ExternalPtr` and `OwnedProtect` are managed internally.
///
/// Matching is exact and case-sensitive.
fn is_skip_type(ident: &str) -> bool {
    const SKIP_TYPES: [&str; 5] = ["SEXP", "Dots", "Missing", "ExternalPtr", "OwnedProtect"];
    SKIP_TYPES.iter().any(|&skip| skip == ident)
}
351
352/// Returns `true` if a type is unknown to the static type table and should
353/// be recorded as a fallback parameter.
354///
355/// Returns `false` for skip types (SEXP, Dots, etc.) and reference types
356/// (which are handled by the static table or skipped).
357fn needs_fallback(ty: &syn::Type) -> bool {
358    match ty {
359        syn::Type::Path(tp) => {
360            let Some(seg) = tp.path.segments.last() else {
361                return false;
362            };
363            !is_skip_type(&seg.ident.to_string())
364        }
365        // References (&str, &[T], &Dots) are handled by static table or skipped
366        syn::Type::Reference(_) => false,
367        _ => false,
368    }
369}
370
371/// Build precondition checks for a function's parameters.
372///
373/// Returns:
374/// - **`static_checks`**: Lines forming a `stopifnot(...)` call for known types
375/// - **`fallback_params`**: Parameters needing validation (unknown custom types)
376///
377/// Static checks produce R-side `stopifnot()`:
378/// ```r
379/// stopifnot(
380///   "'a' must be numeric, logical, or raw" = is.numeric(a) || is.logical(a) || is.raw(a),
381///   "'a' must have length 1" = length(a) == 1L
382/// )
383/// ```
384///
385/// Skips:
386/// - `self`/`&self`/`&mut self` (receiver args)
387/// - Parameters in `skip_params` (e.g., match_arg params already validated)
388/// - Skip types (SEXP, Dots, ExternalPtr, etc.)
389pub fn build_precondition_checks(
390    inputs: &syn::punctuated::Punctuated<syn::FnArg, syn::Token![,]>,
391    skip_params: &HashSet<String>,
392) -> PreconditionOutput {
393    let mut args = Vec::new();
394    let mut fallback_params = Vec::new();
395
396    for arg in inputs {
397        // Skip receiver (self/&self/&mut self)
398        let syn::FnArg::Typed(pt) = arg else {
399            continue;
400        };
401
402        // Extract parameter name
403        let syn::Pat::Ident(pat_ident) = pt.pat.as_ref() else {
404            continue;
405        };
406
407        // Use the R-normalized name for the check (matches the R formal)
408        let r_name = crate::r_wrapper_builder::normalize_r_arg_ident(&pat_ident.ident).to_string();
409
410        // Skip match_arg params (already validated by match.arg())
411        if skip_params.contains(&r_name) {
412            continue;
413        }
414
415        // Map the Rust type to R assertions (known types)
416        if let Some(check) = r_check_for_type(pt.ty.as_ref()) {
417            for assertion in check.assertions(&r_name) {
418                args.push(assertion.to_stopifnot_arg());
419            }
420        } else if needs_fallback(pt.ty.as_ref()) {
421            // Unknown type → record for potential future validation
422            fallback_params.push(FallbackParam { r_name });
423        }
424    }
425
426    let static_checks = match args.len() {
427        0 => Vec::new(),
428        1 => vec![format!("stopifnot({})", args[0])],
429        _ => {
430            let mut lines = Vec::with_capacity(args.len() + 2);
431            lines.push("stopifnot(".to_string());
432            for (i, arg) in args.iter().enumerate() {
433                let comma = if i < args.len() - 1 { "," } else { "" };
434                lines.push(format!("  {}{}", arg, comma));
435            }
436            lines.push(")".to_string());
437            lines
438        }
439    };
440
441    PreconditionOutput {
442        static_checks,
443        fallback_params,
444    }
445}
446
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse a Rust type from source text.
    fn parse_type(src: &str) -> syn::Type {
        syn::parse_str(src).unwrap()
    }

    /// Classify `ty_src` and expand its check for parameter `param`.
    ///
    /// Panics if the type has no R-side check.
    fn assertions_for(ty_src: &str, param: &str) -> Vec<RAssertion> {
        let ty = parse_type(ty_src);
        match r_check_for_type(&ty) {
            Some(check) => check.assertions(param),
            None => panic!("{} should map to an R check", ty_src),
        }
    }

    /// Build preconditions for the signature `sig_src`, skipping the R
    /// parameter names in `skip`.
    fn output_for(sig_src: &str, skip: &[&str]) -> PreconditionOutput {
        let sig: syn::Signature = syn::parse_str(sig_src).unwrap();
        let skip_params: HashSet<String> = skip.iter().map(|name| name.to_string()).collect();
        build_precondition_checks(&sig.inputs, &skip_params)
    }

    #[test]
    fn scalar_numeric_produces_two_assertions() {
        let asserts = assertions_for("i32", "x");
        let [type_check, length_check] = asserts.as_slice() else {
            panic!("i32 should produce exactly 2 assertions");
        };
        assert_eq!(type_check.message, "'x' must be numeric, logical, or raw");
        assert_eq!(
            type_check.condition,
            "is.numeric(x) || is.logical(x) || is.raw(x)"
        );
        assert_eq!(length_check.message, "'x' must have length 1");
        assert_eq!(length_check.condition, "length(x) == 1L");
    }

    #[test]
    fn all_signed_numeric_types_use_scalar_numeric() {
        for ty_src in ["i32", "f64", "f32", "i8", "i16", "i64", "isize"] {
            let asserts = assertions_for(ty_src, "x");
            assert_eq!(asserts.len(), 2, "{} should produce 2 assertions", ty_src);
            for predicate in ["is.numeric(x)", "is.logical(x)", "is.raw(x)"] {
                assert!(
                    asserts[0].condition.contains(predicate),
                    "{} type check should contain {}",
                    ty_src,
                    predicate
                );
            }
        }
    }

    #[test]
    fn scalar_non_neg_produces_three_assertions() {
        let asserts = assertions_for("u32", "n");
        let messages: Vec<&str> = asserts.iter().map(|a| a.message.as_str()).collect();
        assert_eq!(
            messages,
            [
                "'n' must be numeric, logical, or raw",
                "'n' must have length 1",
                "'n' must be non-negative",
            ]
        );
        assert_eq!(asserts[2].condition, "is.raw(n) || n >= 0");
    }

    #[test]
    fn all_unsigned_types_use_scalar_non_neg() {
        for ty_src in ["u16", "u32", "u64", "usize"] {
            let asserts = assertions_for(ty_src, "x");
            assert_eq!(asserts.len(), 3, "{} should produce 3 assertions", ty_src);
            assert!(
                asserts[2].condition.contains(">= 0"),
                "{} non-neg check",
                ty_src
            );
        }
    }

    #[test]
    fn scalar_logical() {
        let asserts = assertions_for("bool", "x");
        let [type_check, length_check] = asserts.as_slice() else {
            panic!("bool should produce exactly 2 assertions");
        };
        assert_eq!(type_check.message, "'x' must be logical");
        assert_eq!(type_check.condition, "is.logical(x)");
        assert_eq!(length_check.condition, "length(x) == 1L");
    }

    #[test]
    fn scalar_character() {
        for ty_src in ["String", "char", "PathBuf"] {
            let asserts = assertions_for(ty_src, "s");
            assert_eq!(asserts.len(), 2);
            assert_eq!(asserts[0].message, "'s' must be character");
            assert_eq!(asserts[0].condition, "is.character(s)");
        }
    }

    #[test]
    fn ref_str() {
        let asserts = assertions_for("& str", "s");
        assert_eq!(asserts.len(), 2);
        assert_eq!(asserts[0].condition, "is.character(s)");
    }

    #[test]
    fn scalar_raw() {
        let asserts = assertions_for("u8", "x");
        assert_eq!(asserts.len(), 2);
        assert_eq!(asserts[0].message, "'x' must be raw");
        assert_eq!(asserts[0].condition, "is.raw(x)");
    }

    #[test]
    fn vector_numeric_produces_one_assertion() {
        for ty_src in ["Vec<f64>", "Vec<i8>", "Vec<i32>", "Vec<i64>"] {
            let asserts = assertions_for(ty_src, "x");
            assert_eq!(asserts.len(), 1, "{} should produce 1 assertion", ty_src);
            assert_eq!(
                asserts[0].condition,
                "is.numeric(x) || is.logical(x) || is.raw(x)"
            );
        }
    }

    #[test]
    fn vector_character() {
        let asserts = assertions_for("Vec<String>", "x");
        assert_eq!(asserts.len(), 1);
        assert_eq!(asserts[0].condition, "is.character(x)");
    }

    #[test]
    fn vector_optional_string() {
        let asserts = assertions_for("Vec<Option<String>>", "x");
        assert_eq!(asserts.len(), 1);
        assert_eq!(asserts[0].condition, "is.character(x)");
    }

    #[test]
    fn slice_u8() {
        let asserts = assertions_for("& [u8]", "x");
        assert_eq!(asserts.len(), 1);
        assert_eq!(asserts[0].condition, "is.raw(x)");
    }

    #[test]
    fn nullable_wraps_inner_assertions() {
        let asserts = assertions_for("Option<i32>", "x");
        let [type_check, length_check] = asserts.as_slice() else {
            panic!("Option<i32> should produce exactly 2 assertions");
        };
        assert_eq!(
            type_check.message,
            "'x' must be NULL or numeric, logical, or raw"
        );
        assert_eq!(
            type_check.condition,
            "is.null(x) || is.numeric(x) || is.logical(x) || is.raw(x)"
        );
        assert_eq!(length_check.message, "'x' must be NULL or have length 1");
        assert_eq!(length_check.condition, "is.null(x) || length(x) == 1L");
    }

    #[test]
    fn nullable_character() {
        let asserts = assertions_for("Option<String>", "s");
        let [type_check, length_check] = asserts.as_slice() else {
            panic!("Option<String> should produce exactly 2 assertions");
        };
        assert_eq!(type_check.message, "'s' must be NULL or character");
        assert_eq!(type_check.condition, "is.null(s) || is.character(s)");
        assert_eq!(length_check.message, "'s' must be NULL or have length 1");
    }

    #[test]
    fn map_types() {
        for ty_src in ["HashMap<String, i32>", "BTreeMap<String, f64>"] {
            let asserts = assertions_for(ty_src, "x");
            assert_eq!(asserts.len(), 1);
            assert_eq!(asserts[0].condition, "is.list(x)");
        }
    }

    #[test]
    fn skip_types() {
        for ty_src in ["SEXP", "ExternalPtr<MyType>"] {
            assert!(
                r_check_for_type(&parse_type(ty_src)).is_none(),
                "{} should be skipped",
                ty_src
            );
        }
    }

    #[test]
    fn single_param_produces_multi_line() {
        // i32 produces 2 assertions, so the call is rendered across lines.
        let output = output_for("fn f(n: i32)", &[]);
        let checks = &output.static_checks;
        assert_eq!(checks.len(), 4); // stopifnot( + 2 args + )
        assert_eq!(checks[0], "stopifnot(");
        assert!(checks[1].contains("numeric, logical, or raw"));
        assert!(checks[2].contains("length 1"));
        assert_eq!(checks[3], ")");
        assert!(output.fallback_params.is_empty());
    }

    #[test]
    fn vector_param_single_line() {
        // Vec<f64> produces 1 assertion, so the call fits on one line.
        let output = output_for("fn f(x: Vec<f64>)", &[]);
        let [line] = output.static_checks.as_slice() else {
            panic!("expected exactly one stopifnot line");
        };
        assert!(line.starts_with("stopifnot("));
        assert!(line.ends_with(')'));
    }

    #[test]
    fn two_scalar_params_produces_six_lines() {
        let output = output_for("fn f(a: i32, b: f64)", &[]);
        let checks = &output.static_checks;
        // stopifnot( + 4 assertions (2 per param) + )
        assert_eq!(checks.len(), 6);
        assert_eq!(checks[0], "stopifnot(");
        let expected = [("'a'", "numeric"), ("'a'", "length 1"), ("'b'", "numeric"), ("'b'", "length 1")];
        for (line, (name, fragment)) in checks[1..5].iter().zip(expected) {
            assert!(
                line.contains(name) && line.contains(fragment),
                "{:?} should mention {} and {}",
                line,
                name,
                fragment
            );
        }
        assert_eq!(checks[5], ")");
    }

    #[test]
    fn build_checks_skips_match_arg() {
        let output = output_for("fn f(n: i32, mode: String)", &["mode"]);
        // Only n's 2 assertions remain
        let joined = output.static_checks.join("\n");
        assert!(joined.contains("'n'"));
        assert!(!joined.contains("'mode'"));
    }

    #[test]
    fn unknown_type_produces_fallback() {
        let output = output_for("fn f(x: MyCustomType)", &[]);
        assert!(output.static_checks.is_empty());
        let fallback_names: Vec<&str> = output
            .fallback_params
            .iter()
            .map(|param| param.r_name.as_str())
            .collect();
        assert_eq!(fallback_names, ["x"]);
    }

    #[test]
    fn mixed_known_and_unknown_types() {
        let output = output_for("fn f(a: i32, b: MyType, c: String)", &[]);
        // a (i32) and c (String) are known → static checks
        let joined = output.static_checks.join("\n");
        assert!(joined.contains("'a'"));
        assert!(joined.contains("'c'"));
        assert!(!joined.contains("'b'"));
        // b (MyType) is unknown → fallback
        let fallback_names: Vec<&str> = output
            .fallback_params
            .iter()
            .map(|param| param.r_name.as_str())
            .collect();
        assert_eq!(fallback_names, ["b"]);
    }

    #[test]
    fn sexp_not_fallback() {
        let output = output_for("fn f(x: SEXP)", &[]);
        assert!(output.static_checks.is_empty());
        assert!(output.fallback_params.is_empty());
    }
}