Commit e66b9dc7 authored by Pietro Abate's avatar Pietro Abate
Browse files

Add unit test for merge function

- separate the merge and solve function in the Tallying module
parent 7f614b4d
......@@ -36,7 +36,7 @@ let norm_tests = [
let m_compare = Tallying.CS.M.compare Types.compare
module ECS = OUnitDiff.ListSimpleMake (struct
module MList = OUnitDiff.ListSimpleMake (struct
type t = Tallying.CS.m
let compare = m_compare
let pp_printer = Tallying.CS.print_m
......@@ -51,14 +51,69 @@ let test_norm =
let t = parse_typ t in
let result = Tallying.norm (diff s t) in
let elem s = List.sort m_compare (Tallying.CS.S.elements s) in
ECS.assert_equal (elem expected) (elem result)
MList.assert_equal (elem expected) (elem result)
)
) norm_tests
;;
let mk_e l =
List.fold_left (fun acc (v,t) ->
Tallying.CS.E.add (`Var v) (parse_typ t) acc
) Tallying.CS.E.empty l
let merge_tests = [
[("`$A", "Empty");("`$B", "Empty")], cap (mk_pos ("A", "Empty")) (mk_pos ("B", "Empty"));
[("`$A", "Int | Bool");("Int","`$B");("`$B", "Empty")], Tallying.CS.unsat;
[("Bool","`$B"); ("`$B", "`$A"); ("`$A", "Empty")], Tallying.CS.unsat;
[("Bool","`$B"); ("Int","`$B"); ("`$B","`$A"); ("`$A", "Int | Bool")],
cap
(mk_neg ("`$B","A")) (
cap
(mk_pos ("A", "Int | Bool")) (
cap
(mk_neg ("Int | Bool","B"))
(mk_pos ("B","Int | Bool"))
)
);
[("`$A", "`$B")], mk_pos ("A","`$B");
[("`$B", "Empty")], mk_pos ("B","Empty");
]
let e_compare = Tallying.CS.E.compare Types.compare
module EList = OUnitDiff.ListSimpleMake (struct
type t = Tallying.CS.e
let compare = e_compare
let pp_printer = Tallying.CS.print_e
let pp_print_sep = OUnitDiff.pp_comma_separator
end)
let test_merge =
let print_test l =
String.concat ";" (List.map (fun (s,t) -> Printf.sprintf " %s \\ %s" s t) l)
in
"test tallying merge" >:::
List.map (fun (l,expected) ->
(print_test l) >:: (fun _ ->
let n = List.fold_left (fun acc (s,t) ->
let s = parse_typ s in
let t = parse_typ t in
Tallying.CS.cap acc (Tallying.norm(diff s t))) Tallying.CS.S.empty l
in
let result = Tallying.CS.S.fold (fun c acc ->
try cup (Tallying.merge c) acc with Tallying.UnSatConstr -> acc
) n Tallying.CS.S.empty
in
let elem s = List.sort m_compare (Tallying.CS.S.elements s) in
MList.assert_equal (elem expected) (elem result)
)
) merge_tests
;;
let all =
"all tests" >::: [
test_norm;
test_merge;
]
let main () =
......
......@@ -2284,9 +2284,9 @@ module Tallying = struct
let print ppf m =
print_lst (fun ppf -> fun ((b,`Var v),s) ->
if b then
Format.fprintf ppf "(`$%s,%a)" v dump s
Format.fprintf ppf "(`$%s,%a)" v Print.print s
else
Format.fprintf ppf "(%a,`$%s)" dump s v
Format.fprintf ppf "(%a,`$%s)" Print.print s v
) ppf (bindings m);
end
......@@ -2546,13 +2546,14 @@ module Tallying = struct
in
if CS.S.is_empty mm then CS.S.singleton m else mm
(* merge and solve *)
let merge m =
let merge m = merge (m,DescrSet.empty)
let solve s =
let aux v (s,t) acc =
let b = var (Var.fresh ()) in
CS.E.add v (cap (cup s b) t) acc
in
let solve m =
let aux m =
CS.M.fold (fun (b,v) s acc ->
try
let t = CS.M.find (not b,v) m in
......@@ -2566,9 +2567,7 @@ module Tallying = struct
if b then aux v (any,s) acc else aux v (s,empty) acc
) m CS.E.empty
in
try
let l = merge (m,DescrSet.empty) in
CS.S.fold (fun m acc -> (solve m)::acc) l []
try CS.S.fold (fun m acc -> (aux m)::acc) s []
with UnSatConstr -> []
let unify e =
......@@ -2586,7 +2585,7 @@ module Tallying = struct
let tallying l =
let n = List.fold_left (fun acc (s,t) -> CS.cap acc (norm(diff s t))) CS.S.empty l in
let m = CS.S.fold (fun c acc -> acc @ (merge (c))) n [] in
let m = CS.S.fold (fun c acc -> acc @ (solve (merge c))) n [] in
List.fold_left (fun acc e -> (unify e)::acc) [] m
end
......@@ -355,6 +355,8 @@ module Tallying : sig
|Pos of (Var.var * t) (** alpha <= t | alpha \in P *)
|Neg of (t * Var.var) (** t <= alpha | alpha \in N *)
exception UnSatConstr
module CS : sig
module M : sig
include Map.S with type key = (bool * Var.var)
......@@ -386,7 +388,8 @@ module Tallying : sig
end
val norm : t -> CS.s
val merge : CS.m -> CS.e list
val merge : CS.m -> CS.s
val solve : CS.s -> CS.e list
val unify : CS.e -> CS.e
val tallying : (t * t) list -> CS.e list
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment