Commit 11c9e701 authored by Pietro Abate's avatar Pietro Abate
Browse files

Add final unit test for the Tallying Algorithm

- More api changes
parent e66b9dc7
......@@ -56,11 +56,6 @@ let test_norm =
) norm_tests
;;
let mk_e l =
List.fold_left (fun acc (v,t) ->
Tallying.CS.E.add (`Var v) (parse_typ t) acc
) Tallying.CS.E.empty l
let merge_tests = [
[("`$A", "Empty");("`$B", "Empty")], cap (mk_pos ("A", "Empty")) (mk_pos ("B", "Empty"));
[("`$A", "Int | Bool");("Int","`$B");("`$B", "Empty")], Tallying.CS.unsat;
......@@ -79,15 +74,6 @@ let merge_tests = [
[("`$B", "Empty")], mk_pos ("B","Empty");
]
let e_compare = Tallying.CS.E.compare Types.compare
module EList = OUnitDiff.ListSimpleMake (struct
type t = Tallying.CS.e
let compare = e_compare
let pp_printer = Tallying.CS.print_e
let pp_print_sep = OUnitDiff.pp_comma_separator
end)
let test_merge =
let print_test l =
String.concat ";" (List.map (fun (s,t) -> Printf.sprintf " %s \\ %s" s t) l)
......@@ -110,10 +96,64 @@ let test_merge =
) merge_tests
;;
let e_compare = Tallying.CS.E.compare Types.compare
let to_string pp t =
Format.fprintf Format.str_formatter "%a@." pp t;
Format.flush_str_formatter ()
;;
let compare_constr (v1,t1) (v2,t2) =
if (v1,t1) == (v2,t2) then 0
else let c = Var.compare v1 v2 in if c <> 0 then c
else Types.compare t1 t2
module EList = OUnitDiff.ListSimpleMake (struct
type t = (Var.var * Types.t)
let compare = compare_constr
let pp_printer ppf (`Var v,t) = Format.fprintf ppf "(%s = %s)" v (to_string Print.print t)
let pp_print_sep = OUnitDiff.pp_comma_separator
end)
module SubList = OUnitDiff.ListSimpleMake (struct
type t = EList.t
let compare a b = EList.compare a b
let pp_printer ppf l = EList.pp_printer ppf l
let pp_print_sep = OUnitDiff.pp_comma_separator
end)
let mk_e ll =
List.map (fun l ->
List.map (fun (v,t) -> (`Var v),(parse_typ t)) l
) ll
let tallying_tests = [
[("((Int | Bool) -> Int)", "(`$A -> `$B)"); ("(`$A -> Bool)","(`$B -> `$B)")], mk_e [
[("A","Empty");("B","Empty")];
[("A","Int | Bool");("B","Int | Bool")]
]
]
let test_tallying =
let print_test l =
String.concat ";" (List.map (fun (s,t) -> Printf.sprintf " %s \\ %s" s t) l)
in
"test tallying merge" >:::
List.map (fun (l,expected) ->
(print_test l) >:: (fun _ ->
let l = List.map (fun (s,t) -> (parse_typ s,parse_typ t)) l in
let result = Tallying.tallying l in
let elem s = SubList.of_list (List.map (fun l -> EList.of_list (List.sort compare_constr l)) s) in
SubList.assert_equal (elem expected) (elem result)
)
) tallying_tests
;;
let all =
"all tests" >::: [
test_norm;
test_merge;
test_tallying;
]
let main () =
......
......@@ -520,8 +520,8 @@ let get_record r =
in
List.map line (Rec.get r)
(* substitute all occurrences of v in t by s *)
let rec subst v (t,s) =
(* substitute in t all occurrences of v by s *)
let rec subst t (v,s) =
let module C ( X : BoolVar.S ) =
struct
let atom_aux ?f v (s,ss) =
......@@ -545,8 +545,8 @@ let rec subst v (t,s) =
List.fold_left (fun acc (left,rigth) ->
let deep_subst l =
List.fold_left (fun acc (t1,t2) ->
let d1 = cons (subst v (descr t1,s)) in
let d2 = cons (subst v (descr t2,s)) in
let d1 = cons (subst (descr t1) (v,s)) in
let d2 = cons (subst (descr t2) (v,s)) in
BoolPair.cap acc (BoolPair.atom (`Atm (Pair.atom (d1,d2))))
) BoolPair.full l
in
......@@ -2567,25 +2567,24 @@ module Tallying = struct
if b then aux v (any,s) acc else aux v (s,empty) acc
) m CS.E.empty
in
try CS.S.fold (fun m acc -> (aux m)::acc) s []
with UnSatConstr -> []
CS.S.fold (fun m acc -> (aux m)::acc) s []
let unify e =
let rec aux acc e =
if CS.E.is_empty e then acc
let rec aux (sol,acc) e =
if CS.E.is_empty e then sol
else
let (alpha,t) = CS.E.min_binding e in
let (alpha,t) = CS.E.max_binding e in
let e1 = CS.E.remove alpha e in
let x = Var.fresh () in
(* XXX ... let x = Var.fresh () in *)
(* replace in e1 all occurrences of a by ... *)
let es = CS.E.fold (fun beta t acc -> CS.E.add beta (subst alpha (t,var x)) acc) e1 CS.E.empty in
aux (CS.E.add alpha (subst alpha (t,var x)) acc) es
let es = CS.E.fold (fun beta s acc -> CS.E.add beta (subst s (alpha,t)) acc) e1 CS.E.empty in
aux (((alpha,t)::sol),(CS.E.add alpha (subst t (alpha,t)) acc)) es
in
aux CS.E.empty e
aux ([],CS.E.empty) e
let tallying l =
let n = List.fold_left (fun acc (s,t) -> CS.cap acc (norm(diff s t))) CS.S.empty l in
let m = CS.S.fold (fun c acc -> acc @ (solve (merge c))) n [] in
let m = CS.S.fold (fun c acc -> try acc @ (solve (merge c)) with UnSatConstr -> acc) n [] in
List.fold_left (fun acc e -> (unify e)::acc) [] m
end
......@@ -142,7 +142,7 @@ val rec_of_list: bool -> (bool * Ns.Label.t * t) list -> t
val empty_closed_record: t
val empty_open_record: t
val subst : Var.var -> t * t -> t
val subst : t -> Var.var * t -> t
(** Positive systems and least solutions **)
......@@ -390,8 +390,8 @@ module Tallying : sig
val norm : t -> CS.s
val merge : CS.m -> CS.s
val solve : CS.s -> CS.e list
val unify : CS.e -> CS.e
val tallying : (t * t) list -> CS.e list
val unify : CS.e -> (Var.var * t) list
val tallying : (t * t) list -> (Var.var * t) list list
end
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment