1use std::collections::HashSet;
22
/// One named argument of an R `stopifnot()` call.
///
/// Rendered by [`RAssertion::to_stopifnot_arg`] as `"message" = condition`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct RAssertion {
    /// Text used as the argument name (R reports it when the check fails).
    message: String,
    /// R source expression that has to hold.
    condition: String,
}

impl RAssertion {
    /// Builds an assertion from a failure message and an R condition expression.
    fn new(message: impl Into<String>, condition: impl Into<String>) -> Self {
        Self {
            message: message.into(),
            condition: condition.into(),
        }
    }

    /// Renders the assertion as `"message" = condition`.
    ///
    /// The message is placed inside a double-quoted R string literal, so any
    /// `"` or `\` it contains is backslash-escaped to keep the generated R
    /// code syntactically valid.
    fn to_stopifnot_arg(&self) -> String {
        format!(
            "\"{}\" = {}",
            Self::escape_for_r_string(&self.message),
            self.condition
        )
    }

    /// Escapes `"` and `\` so `text` can sit inside a double-quoted R string.
    fn escape_for_r_string(text: &str) -> String {
        let mut escaped = String::with_capacity(text.len());
        for ch in text.chars() {
            if matches!(ch, '"' | '\\') {
                escaped.push('\\');
            }
            escaped.push(ch);
        }
        escaped
    }

    /// Returns a variant of this assertion that also accepts `NULL` for `param`.
    ///
    /// The condition is prefixed with `is.null(param) || ` and the message is
    /// reworded ("must be …" becomes "must be NULL or …", "must have …"
    /// becomes "must be NULL or have …", anything else gets " (or NULL)"
    /// appended).
    ///
    /// The operation is idempotent: an assertion whose condition already
    /// starts with the `is.null(param) || ` guard is returned unchanged, so
    /// nested optional types do not produce "NULL or NULL or …".
    fn nullable(self, param: &str) -> Self {
        let null_guard = format!("is.null({}) || ", param);
        if self.condition.starts_with(&null_guard) {
            return self;
        }

        let message = if self.message.contains("must be ") {
            self.message.replacen("must be ", "must be NULL or ", 1)
        } else if self.message.contains("must have ") {
            self.message
                .replacen("must have ", "must be NULL or have ", 1)
        } else {
            format!("{} (or NULL)", self.message)
        };

        Self {
            message,
            condition: format!("{}{}", null_guard, self.condition),
        }
    }
}
67
/// The kind of R-side argument check generated for a Rust parameter type.
#[derive(Debug, Clone, PartialEq, Eq)]
enum RTypeCheck {
    /// Length-1 value that is numeric, logical, or raw.
    ScalarNumeric,
    /// Like `ScalarNumeric`, additionally required to be non-negative
    /// (raw values are exempt from the comparison).
    ScalarNonNeg,
    /// Length-1 value that satisfies R's `is.<type>()` for the given type name.
    Scalar(&'static str),
    /// Value of any length that is numeric, logical, or raw.
    VectorNumeric,
    /// Value of any length that satisfies R's `is.<type>()` for the given type name.
    Vector(&'static str),
    /// The inner check, or `NULL`.
    Nullable(Box<RTypeCheck>),
    /// Value that satisfies R's `is.list()`.
    List,
}
97
/// R expression that is true when `param` is numeric, logical, or raw.
fn numeric_type_check(param: &str) -> String {
    let accepted = ["is.numeric", "is.logical", "is.raw"];
    accepted
        .iter()
        .map(|predicate| format!("{}({})", predicate, param))
        .collect::<Vec<_>>()
        .join(" || ")
}
109
110impl RTypeCheck {
111 fn assertions(&self, param: &str) -> Vec<RAssertion> {
117 match self {
118 RTypeCheck::ScalarNumeric => vec![
119 RAssertion::new(
120 format!("'{}' must be numeric, logical, or raw", param),
121 numeric_type_check(param),
122 ),
123 RAssertion::new(
124 format!("'{}' must have length 1", param),
125 format!("length({}) == 1L", param),
126 ),
127 ],
128 RTypeCheck::ScalarNonNeg => vec![
129 RAssertion::new(
130 format!("'{}' must be numeric, logical, or raw", param),
131 numeric_type_check(param),
132 ),
133 RAssertion::new(
134 format!("'{}' must have length 1", param),
135 format!("length({}) == 1L", param),
136 ),
137 RAssertion::new(
138 format!("'{}' must be non-negative", param),
139 format!("is.raw({p}) || {p} >= 0", p = param),
142 ),
143 ],
144 RTypeCheck::Scalar(r_type) => vec![
145 RAssertion::new(
146 format!("'{}' must be {}", param, r_type),
147 format!("is.{}({})", r_type, param),
148 ),
149 RAssertion::new(
150 format!("'{}' must have length 1", param),
151 format!("length({}) == 1L", param),
152 ),
153 ],
154 RTypeCheck::VectorNumeric => vec![RAssertion::new(
155 format!("'{}' must be numeric, logical, or raw", param),
156 numeric_type_check(param),
157 )],
158 RTypeCheck::Vector(r_type) => vec![RAssertion::new(
159 format!("'{}' must be {}", param, r_type),
160 format!("is.{}({})", r_type, param),
161 )],
162 RTypeCheck::Nullable(inner) => inner
163 .assertions(param)
164 .into_iter()
165 .map(|a| a.nullable(param))
166 .collect(),
167 RTypeCheck::List => vec![RAssertion::new(
168 format!("'{}' must be a list", param),
169 format!("is.list({})", param),
170 )],
171 }
172 }
173}
174
175fn r_check_for_type(ty: &syn::Type) -> Option<RTypeCheck> {
179 match ty {
180 syn::Type::Path(type_path) => r_check_for_type_path(type_path),
181 syn::Type::Reference(type_ref) => r_check_for_reference(type_ref),
182 _ => None,
183 }
184}
185
186fn r_check_for_type_path(type_path: &syn::TypePath) -> Option<RTypeCheck> {
192 let segment = type_path.path.segments.last()?;
193 let ident = segment.ident.to_string();
194
195 match ident.as_str() {
196 "i32" | "f64" | "f32" | "i8" | "i16" | "i64" | "isize" => Some(RTypeCheck::ScalarNumeric),
198
199 "u16" | "u32" | "u64" | "usize" => Some(RTypeCheck::ScalarNonNeg),
201
202 "bool" | "Rbool" | "Rboolean" => Some(RTypeCheck::Scalar("logical")),
204
205 "String" | "char" | "PathBuf" => Some(RTypeCheck::Scalar("character")),
207
208 "u8" => Some(RTypeCheck::Scalar("raw")),
210
211 "Rcomplex" => Some(RTypeCheck::Scalar("complex")),
213
214 "Option" => {
216 let inner_ty = extract_single_generic_arg(segment)?;
217 r_check_for_type(inner_ty).map(|inner| RTypeCheck::Nullable(Box::new(inner)))
218 }
219
220 "Vec" => {
222 let inner_ty = extract_single_generic_arg(segment)?;
223 r_check_for_vec_element(inner_ty)
224 }
225
226 "HashMap" | "BTreeMap" | "NamedList" => Some(RTypeCheck::List),
228
229 "List" | "ListMut" => Some(RTypeCheck::List),
231
232 "SEXP" | "Dots" | "Missing" | "ExternalPtr" | "OwnedProtect" => None,
234
235 _ => None,
237 }
238}
239
240fn r_check_for_reference(type_ref: &syn::TypeReference) -> Option<RTypeCheck> {
245 match type_ref.elem.as_ref() {
246 syn::Type::Path(tp) => {
248 let seg = tp.path.segments.last()?;
249 match seg.ident.to_string().as_str() {
250 "str" => Some(RTypeCheck::Scalar("character")),
251 "Path" => Some(RTypeCheck::Scalar("character")),
252 "Dots" => None,
253 _ => None,
254 }
255 }
256 syn::Type::Slice(slice) => r_check_for_vec_element(&slice.elem),
258 _ => None,
259 }
260}
261
262fn r_check_for_vec_element(elem_ty: &syn::Type) -> Option<RTypeCheck> {
268 let syn::Type::Path(tp) = elem_ty else {
269 return None;
270 };
271 let seg = tp.path.segments.last()?;
272 let ident = seg.ident.to_string();
273
274 match ident.as_str() {
275 "i32" | "f64" | "f32" | "i8" | "i16" | "u16" | "u32" | "i64" | "u64" | "isize"
277 | "usize" => Some(RTypeCheck::VectorNumeric),
278
279 "bool" => Some(RTypeCheck::Vector("logical")),
281
282 "String" => Some(RTypeCheck::Vector("character")),
284
285 "u8" => Some(RTypeCheck::Vector("raw")),
287
288 "Rcomplex" => Some(RTypeCheck::Vector("complex")),
290
291 "Option" => {
293 let inner = extract_single_generic_arg(seg)?;
294 r_check_for_vec_element(inner)
296 }
297
298 _ => None,
299 }
300}
301
302fn extract_single_generic_arg(segment: &syn::PathSegment) -> Option<&syn::Type> {
306 if let syn::PathArguments::AngleBracketed(ref args) = segment.arguments
307 && let Some(syn::GenericArgument::Type(ty)) = args.args.first()
308 {
309 return Some(ty);
310 }
311 None
312}
313
/// A parameter for which no static R check could be generated.
// NOTE(review): the reason for `allow(dead_code)` is not visible in this
// file; presumably not every consumer reads this type yet. Confirm it is
// still needed.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FallbackParam {
    /// Name of the parameter as it appears on the R side.
    pub r_name: String,
}

/// Result of [`build_precondition_checks`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PreconditionOutput {
    /// Lines of R code forming the generated `stopifnot(...)` call; empty
    /// when no parameter produced a check.
    pub static_checks: Vec<String>,
    /// Parameters whose type had no static check but was not a skip type.
    // NOTE(review): see the note on `FallbackParam` regarding `allow(dead_code)`.
    #[allow(dead_code)]
    pub fallback_params: Vec<FallbackParam>,
}
339
/// Whether `ident` names one of the types that never receive a check or a
/// fallback entry.
fn is_skip_type(ident: &str) -> bool {
    const SKIPPED: [&str; 5] = ["SEXP", "Dots", "Missing", "ExternalPtr", "OwnedProtect"];
    SKIPPED.contains(&ident)
}
351
352fn needs_fallback(ty: &syn::Type) -> bool {
358 match ty {
359 syn::Type::Path(tp) => {
360 let Some(seg) = tp.path.segments.last() else {
361 return false;
362 };
363 !is_skip_type(&seg.ident.to_string())
364 }
365 syn::Type::Reference(_) => false,
367 _ => false,
368 }
369}
370
371pub fn build_precondition_checks(
390 inputs: &syn::punctuated::Punctuated<syn::FnArg, syn::Token![,]>,
391 skip_params: &HashSet<String>,
392) -> PreconditionOutput {
393 let mut args = Vec::new();
394 let mut fallback_params = Vec::new();
395
396 for arg in inputs {
397 let syn::FnArg::Typed(pt) = arg else {
399 continue;
400 };
401
402 let syn::Pat::Ident(pat_ident) = pt.pat.as_ref() else {
404 continue;
405 };
406
407 let r_name = crate::r_wrapper_builder::normalize_r_arg_ident(&pat_ident.ident).to_string();
409
410 if skip_params.contains(&r_name) {
412 continue;
413 }
414
415 if let Some(check) = r_check_for_type(pt.ty.as_ref()) {
417 for assertion in check.assertions(&r_name) {
418 args.push(assertion.to_stopifnot_arg());
419 }
420 } else if needs_fallback(pt.ty.as_ref()) {
421 fallback_params.push(FallbackParam { r_name });
423 }
424 }
425
426 let static_checks = match args.len() {
427 0 => Vec::new(),
428 1 => vec![format!("stopifnot({})", args[0])],
429 _ => {
430 let mut lines = Vec::with_capacity(args.len() + 2);
431 lines.push("stopifnot(".to_string());
432 for (i, arg) in args.iter().enumerate() {
433 let comma = if i < args.len() - 1 { "," } else { "" };
434 lines.push(format!(" {}{}", arg, comma));
435 }
436 lines.push(")".to_string());
437 lines
438 }
439 };
440
441 PreconditionOutput {
442 static_checks,
443 fallback_params,
444 }
445}
446
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a Rust type from source text; panics if it does not parse.
    fn parse_type(src: &str) -> syn::Type {
        syn::parse_str(src).unwrap()
    }

    /// Assertions generated for `param` when declared with type `ty_str`.
    fn assertions_for(ty_str: &str, param: &str) -> Vec<RAssertion> {
        let check = r_check_for_type(&parse_type(ty_str));
        check.unwrap().assertions(param)
    }

    /// Runs `build_precondition_checks` on a signature, skipping the given names.
    fn checks_for(signature: &str, skipped: &[&str]) -> PreconditionOutput {
        let sig: syn::Signature = syn::parse_str(signature).unwrap();
        let skip: HashSet<String> = skipped.iter().map(|name| name.to_string()).collect();
        build_precondition_checks(&sig.inputs, &skip)
    }

    /// Messages of all assertions, in order.
    fn messages(asserts: &[RAssertion]) -> Vec<&str> {
        asserts.iter().map(|a| a.message.as_str()).collect()
    }

    /// Conditions of all assertions, in order.
    fn conditions(asserts: &[RAssertion]) -> Vec<&str> {
        asserts.iter().map(|a| a.condition.as_str()).collect()
    }

    #[test]
    fn scalar_numeric_produces_two_assertions() {
        let asserts = assertions_for("i32", "x");
        assert_eq!(
            messages(&asserts),
            [
                "'x' must be numeric, logical, or raw",
                "'x' must have length 1"
            ]
        );
        assert_eq!(
            conditions(&asserts),
            [
                "is.numeric(x) || is.logical(x) || is.raw(x)",
                "length(x) == 1L"
            ]
        );
    }

    #[test]
    fn all_signed_numeric_types_use_scalar_numeric() {
        let expectations = [
            ("is.numeric(x)", "type check"),
            ("is.logical(x)", "accepts logical"),
            ("is.raw(x)", "accepts raw"),
        ];
        for ty_str in ["i32", "f64", "f32", "i8", "i16", "i64", "isize"] {
            let asserts = assertions_for(ty_str, "x");
            assert_eq!(asserts.len(), 2, "{} should produce 2 assertions", ty_str);
            for (needle, what) in expectations {
                assert!(
                    asserts[0].condition.contains(needle),
                    "{} {}",
                    ty_str,
                    what
                );
            }
        }
    }

    #[test]
    fn scalar_non_neg_produces_three_assertions() {
        let asserts = assertions_for("u32", "n");
        assert_eq!(
            messages(&asserts),
            [
                "'n' must be numeric, logical, or raw",
                "'n' must have length 1",
                "'n' must be non-negative"
            ]
        );
        assert_eq!(asserts[2].condition, "is.raw(n) || n >= 0");
    }

    #[test]
    fn all_unsigned_types_use_scalar_non_neg() {
        for ty_str in ["u16", "u32", "u64", "usize"] {
            let asserts = assertions_for(ty_str, "x");
            assert_eq!(asserts.len(), 3, "{} should produce 3 assertions", ty_str);
            let non_neg = &asserts[2];
            assert!(
                non_neg.condition.contains(">= 0"),
                "{} non-neg check",
                ty_str
            );
        }
    }

    #[test]
    fn scalar_logical() {
        let asserts = assertions_for("bool", "x");
        assert_eq!(asserts.len(), 2);
        let (type_check, len_check) = (&asserts[0], &asserts[1]);
        assert_eq!(type_check.message, "'x' must be logical");
        assert_eq!(type_check.condition, "is.logical(x)");
        assert_eq!(len_check.condition, "length(x) == 1L");
    }

    #[test]
    fn scalar_character() {
        for ty_str in ["String", "char", "PathBuf"] {
            let asserts = assertions_for(ty_str, "s");
            assert_eq!(asserts.len(), 2);
            let type_check = &asserts[0];
            assert_eq!(type_check.message, "'s' must be character");
            assert_eq!(type_check.condition, "is.character(s)");
        }
    }

    #[test]
    fn ref_str() {
        let asserts = assertions_for("& str", "s");
        assert_eq!(asserts.len(), 2);
        assert_eq!(asserts[0].condition, "is.character(s)");
    }

    #[test]
    fn scalar_raw() {
        let asserts = assertions_for("u8", "x");
        assert_eq!(asserts.len(), 2);
        let type_check = &asserts[0];
        assert_eq!(type_check.message, "'x' must be raw");
        assert_eq!(type_check.condition, "is.raw(x)");
    }

    #[test]
    fn vector_numeric_produces_one_assertion() {
        for ty_str in ["Vec<f64>", "Vec<i8>", "Vec<i32>", "Vec<i64>"] {
            let asserts = assertions_for(ty_str, "x");
            assert_eq!(asserts.len(), 1, "{} should produce 1 assertion", ty_str);
            assert_eq!(
                asserts[0].condition,
                "is.numeric(x) || is.logical(x) || is.raw(x)"
            );
        }
    }

    #[test]
    fn vector_character() {
        let asserts = assertions_for("Vec<String>", "x");
        assert_eq!(conditions(&asserts), ["is.character(x)"]);
    }

    #[test]
    fn vector_optional_string() {
        let asserts = assertions_for("Vec<Option<String>>", "x");
        assert_eq!(conditions(&asserts), ["is.character(x)"]);
    }

    #[test]
    fn slice_u8() {
        let asserts = assertions_for("& [u8]", "x");
        assert_eq!(conditions(&asserts), ["is.raw(x)"]);
    }

    #[test]
    fn nullable_wraps_inner_assertions() {
        let asserts = assertions_for("Option<i32>", "x");
        assert_eq!(
            messages(&asserts),
            [
                "'x' must be NULL or numeric, logical, or raw",
                "'x' must be NULL or have length 1"
            ]
        );
        assert_eq!(
            conditions(&asserts),
            [
                "is.null(x) || is.numeric(x) || is.logical(x) || is.raw(x)",
                "is.null(x) || length(x) == 1L"
            ]
        );
    }

    #[test]
    fn nullable_character() {
        let asserts = assertions_for("Option<String>", "s");
        assert_eq!(
            messages(&asserts),
            [
                "'s' must be NULL or character",
                "'s' must be NULL or have length 1"
            ]
        );
        assert_eq!(asserts[0].condition, "is.null(s) || is.character(s)");
    }

    #[test]
    fn map_types() {
        for ty_str in ["HashMap<String, i32>", "BTreeMap<String, f64>"] {
            let asserts = assertions_for(ty_str, "x");
            assert_eq!(conditions(&asserts), ["is.list(x)"]);
        }
    }

    #[test]
    fn skip_types() {
        for ty_str in ["SEXP", "ExternalPtr<MyType>"] {
            let check = r_check_for_type(&parse_type(ty_str));
            assert!(check.is_none(), "{} should be skipped", ty_str);
        }
    }

    #[test]
    fn single_param_produces_multi_line() {
        let output = checks_for("fn f(n: i32)", &[]);
        let checks = &output.static_checks;
        // Opening line, two assertion lines, closing line.
        assert_eq!(checks.len(), 4);
        assert_eq!(checks[0], "stopifnot(");
        assert!(checks[1].contains("numeric, logical, or raw"));
        assert!(checks[2].contains("length 1"));
        assert_eq!(checks[3], ")");
        assert!(output.fallback_params.is_empty());
    }

    #[test]
    fn vector_param_single_line() {
        let output = checks_for("fn f(x: Vec<f64>)", &[]);
        let checks = &output.static_checks;
        assert_eq!(checks.len(), 1);
        let line = &checks[0];
        assert!(line.starts_with("stopifnot("));
        assert!(line.ends_with(')'));
    }

    #[test]
    fn two_scalar_params_produces_six_lines() {
        let output = checks_for("fn f(a: i32, b: f64)", &[]);
        let checks = &output.static_checks;
        assert_eq!(checks.len(), 6);
        assert_eq!(checks[0], "stopifnot(");
        let body = [
            ("'a'", "numeric"),
            ("'a'", "length 1"),
            ("'b'", "numeric"),
            ("'b'", "length 1"),
        ];
        for (offset, (param, what)) in body.into_iter().enumerate() {
            let line = &checks[offset + 1];
            assert!(line.contains(param) && line.contains(what));
        }
        assert_eq!(checks[5], ")");
    }

    #[test]
    fn build_checks_skips_match_arg() {
        let output = checks_for("fn f(n: i32, mode: String)", &["mode"]);
        let joined = output.static_checks.join("\n");
        assert!(joined.contains("'n'"));
        assert!(!joined.contains("'mode'"));
    }

    #[test]
    fn unknown_type_produces_fallback() {
        let output = checks_for("fn f(x: MyCustomType)", &[]);
        assert!(output.static_checks.is_empty());
        assert_eq!(output.fallback_params.len(), 1);
        assert_eq!(output.fallback_params[0].r_name, "x");
    }

    #[test]
    fn mixed_known_and_unknown_types() {
        let output = checks_for("fn f(a: i32, b: MyType, c: String)", &[]);
        let joined = output.static_checks.join("\n");
        for present in ["'a'", "'c'"] {
            assert!(joined.contains(present));
        }
        assert!(!joined.contains("'b'"));
        assert_eq!(output.fallback_params.len(), 1);
        assert_eq!(output.fallback_params[0].r_name, "b");
    }

    #[test]
    fn sexp_not_fallback() {
        let output = checks_for("fn f(x: SEXP)", &[]);
        assert!(output.static_checks.is_empty());
        assert!(output.fallback_params.is_empty());
    }
}