(* tallyingTest.ml *)
open OUnit
open Types

(** [parse_typ s] feeds the characters of [s] to [Parser.pat], types the
    resulting pattern with [Typer.typ] in [Builtin.env], and returns the
    [Types.descr] of the node obtained. *)
let parse_typ s =
  Stream.of_string s
  |> Parser.pat
  |> Typer.typ Builtin.env
  |> Types.descr
;;

(* Short local names for the [Tallying.CS] operations used by the test
   tables in this file. *)
let cup = Tallying.CS.cup
let cap = Tallying.CS.cap
let singleton = Tallying.CS.singleton

(** [mk_pos (alpha, t)] is the [Tallying.CS] singleton holding
    [Tallying.Pos (`Var alpha, parse_typ t)]. *)
let mk_pos (alpha, t) = singleton (Tallying.Pos (`Var alpha, parse_typ t))

(** [mk_neg (t, alpha)] is the [Tallying.CS] singleton holding
    [Tallying.Neg (parse_typ t, `Var alpha)]. *)
let mk_neg (t, alpha) = singleton (Tallying.Neg (parse_typ t, `Var alpha))

(* Cases for [test_norm]. Each triple [(s, t, expected)] gives two type
   expressions (parsed with [parse_typ]) and the constraint set expected from
   [Tallying.norm (diff s t)]. *)
let norm_tests = [
  ( "(`$A -> Bool)", "(`$B -> `$B)",
    cup
      (mk_pos ("B", "Empty"))
      (cap (mk_neg ("`$B", "A")) (mk_neg ("Bool", "B"))) );
  ("`$B", "`$A", mk_neg ("`$B", "A"));
  ("`$B", "Empty", mk_pos ("B", "Empty"));
  ("Int", "`$B", mk_neg ("Int", "B"));
  ("Int", "(`$A | `$B)", mk_neg ("Int \\ `$B", "A"));
  ( "(Int -> Bool)", "(`$A -> `$B)",
    cup
      (mk_pos ("A", "Empty"))
      (cap (mk_pos ("A", "Int")) (mk_neg ("Bool", "B"))) );
]

(* Ordering on [Tallying.CS.m] values, obtained by instantiating
   [Tallying.CS.M.compare] with [Types.compare]. *)
let m_compare = Tallying.CS.M.compare Types.compare

(* List-diff assertions over [Tallying.CS.m] elements, ordered with
   [m_compare] and printed with [Tallying.CS.print_m]. *)
module MList = OUnitDiff.ListSimpleMake (struct
  type t = Tallying.CS.m
  let compare = m_compare
  let pp_printer = Tallying.CS.print_m
  let pp_print_sep = OUnitDiff.pp_comma_separator
end)

(* One OUnit case per triple of [norm_tests]: both type expressions are
   parsed, [Tallying.norm] is applied to [diff s t], and the result is
   compared with the expected set after sorting both element lists with
   [m_compare]. *)
let test_norm =
  "test tallying norm" >:::
    List.map (fun (s,t,expected) ->
      (Printf.sprintf " %s \\ %s" s t) >:: (fun _ ->
        let s = parse_typ s in
        let t = parse_typ t in
        let result = Tallying.norm (diff s t) in
        let elem s = List.sort m_compare (Tallying.CS.S.elements s) in
        MList.assert_equal (elem expected) (elem result)
      )
    ) norm_tests
;;

(* Cases for [test_merge]. Each entry pairs a list of [(s, t)] type
   expressions with the set expected once every pair has been normalised and
   the resulting constraints merged. *)
let merge_tests = [
  [("`$A", "Empty");("`$B", "Empty")], cap (mk_pos ("A", "Empty")) (mk_pos ("B", "Empty"));
  [("`$A", "Int | Bool");("Int","`$B");("`$B", "Empty")], Tallying.CS.unsat;
  [("Bool","`$B"); ("`$B", "`$A"); ("`$A", "Empty")], Tallying.CS.unsat;
  [("Bool","`$B"); ("Int","`$B"); ("`$B","`$A"); ("`$A", "Int | Bool")],
    cap 
      (mk_neg ("`$B","A")) (
        cap 
          (mk_pos ("A", "Int | Bool")) (
            cap
            (mk_neg ("Int | Bool","B"))
            (mk_pos ("B","Int | Bool"))
          )
        );
  [("`$A", "`$B")], mk_pos ("A","`$B");
  [("`$B", "Empty")], mk_pos ("B","Empty");
]

(* One OUnit case per entry of [merge_tests]. The pairs of an entry are
   normalised one by one and combined with [Tallying.CS.cap]; every element
   of the combined set is then passed to [Tallying.merge], elements raising
   [Tallying.UnSatConstr] being skipped, and the union of the results is
   compared with the expected set.
   NOTE(review): the first fold is seeded with [Tallying.CS.S.empty]; confirm
   that [Tallying.CS.cap] treats this seed as neutral rather than absorbing. *)
let test_merge =
  let label pairs =
    pairs
    |> List.map (fun (s, t) -> Printf.sprintf " %s \\ %s" s t)
    |> String.concat ";"
  in
  (* Elements of a set in [m_compare] order, ready for [MList]. *)
  let sorted cs = List.sort m_compare (Tallying.CS.S.elements cs) in
  let add_norm acc (s, t) =
    let s = parse_typ s in
    let t = parse_typ t in
    Tallying.CS.cap acc (Tallying.norm (diff s t))
  in
  let add_merge c acc =
    try cup (Tallying.merge c) acc with Tallying.UnSatConstr -> acc
  in
  let case (l, expected) =
    (label l) >:: (fun _ ->
      let n = List.fold_left add_norm Tallying.CS.S.empty l in
      let result = Tallying.CS.S.fold add_merge n Tallying.CS.S.empty in
      MList.assert_equal (sorted expected) (sorted result))
  in
  "test tallying merge" >::: List.map case merge_tests
;;

(* Comparison obtained by instantiating [Tallying.CS.E.compare] with
   [Types.compare]. *)
let e_compare = Tallying.CS.E.compare Types.compare

(** [to_string pp t] renders [t] with the printer [pp] and returns the text,
    terminated by the newline that the format's final directive emits.

    The text is produced in a private buffer through [Format.asprintf]
    rather than through the shared [Format.str_formatter], so material
    already pending in that global formatter cannot leak into the result and
    the function can be called from inside another printer. *)
let to_string pp t = Format.asprintf "%a@." pp t
;;

(** [compare_constr (v1, t1) (v2, t2)] orders pairs lexicographically: first
    by [Var.compare] on the variables, then by [Types.compare] on the types.

    The former [(v1,t1) == (v2,t2)] shortcut compared two tuples rebuilt on
    the spot with physical equality, which cannot hold for distinct fresh
    allocations, so it has been dropped; the ordering is unchanged. *)
let compare_constr (v1, t1) (v2, t2) =
  match Var.compare v1 v2 with
  | 0 -> Types.compare t1 t2
  | c -> c

(* List-diff assertions over [(variable, type)] pairs, ordered with
   [compare_constr]. The type is printed directly on the target formatter
   with [%a]; going through an intermediate string inserted a line break
   before the closing parenthesis and flushed the global
   [Format.str_formatter] in the middle of another print. *)
module EList = OUnitDiff.ListSimpleMake (struct
  type t = (Var.var * Types.t)
  let compare = compare_constr
  let pp_printer ppf (`Var v, t) =
    Format.fprintf ppf "(%s = %a)" v Print.print t
  let pp_print_sep = OUnitDiff.pp_comma_separator
end)

(* List-diff assertions whose elements are themselves [EList.t] values;
   comparison and printing are delegated to [EList]. *)
module SubList = OUnitDiff.ListSimpleMake (struct
  type t = EList.t
  let compare = EList.compare
  let pp_printer = EList.pp_printer
  let pp_print_sep = OUnitDiff.pp_comma_separator
end)

(** [mk_e ll] maps every [(v, t)] pair of every inner list of [ll] to
    [(`Var v, parse_typ t)], keeping the nesting and the order of [ll]. *)
let mk_e ll =
  let binding (v, t) = (`Var v),(parse_typ t) in
  List.map (List.map binding) ll

(* Cases for [test_tallying]: a list of [(s, t)] type-expression pairs
   together with the expected lists of [(variable, type)] bindings, built
   with [mk_e]. *)
let tallying_tests =
  let constraints = [
    ("((Int | Bool) -> Int)", "(`$A -> `$B)");
    ("(`$A -> Bool)","(`$B -> `$B)");
  ] in
  let expected = mk_e [
    [("A","Empty");("B","Empty")];
    [("A","Int | Bool");("B","Int | Bool")]
  ] in
  [ (constraints, expected) ]

(* One OUnit case per entry of [tallying_tests]: the pairs are parsed,
   handed to [Tallying.tallying], and the result is compared with the
   expected lists. Each inner list is sorted with [compare_constr] before
   the comparison; the outer list is compared in the order returned.
   The suite label used to repeat the one of [test_merge], which made the
   two suites indistinguishable in test paths and reports. *)
let test_tallying =
  let print_test l =
    String.concat ";" (List.map (fun (s,t) -> Printf.sprintf " %s \\ %s" s t) l)
  in
  "test tallying" >:::
    List.map (fun (l,expected) ->
      (print_test l) >:: (fun _ ->
        let l = List.map (fun (s,t) -> (parse_typ s,parse_typ t)) l in
        let result = Tallying.tallying l in
        let elem s = SubList.of_list (List.map (fun l -> EList.of_list (List.sort compare_constr l)) s) in
        SubList.assert_equal (elem expected) (elem result)
      )
    ) tallying_tests
;;

(* Top-level suite grouping [test_norm], [test_merge] and [test_tallying]. *)
let all =
  "all tests" >::: [
    test_norm;
    test_merge;
    test_tallying;
  ]
 
(* Entry point: passes the [all] suite to [OUnit.run_test_tt_main]. *)
let main () = all |> OUnit.run_test_tt_main

(* Run the suite when the program starts; the value returned by the runner
   is discarded, as before. *)
let () = ignore (main ())