Commit e66b9dc7 authored by Pietro Abate's avatar Pietro Abate
Browse files

Add unit test for merge function

- separate the merge and solve function in the Tallying module
parent 7f614b4d
...@@ -36,7 +36,7 @@ let norm_tests = [ ...@@ -36,7 +36,7 @@ let norm_tests = [
let m_compare = Tallying.CS.M.compare Types.compare let m_compare = Tallying.CS.M.compare Types.compare
module ECS = OUnitDiff.ListSimpleMake (struct module MList = OUnitDiff.ListSimpleMake (struct
type t = Tallying.CS.m type t = Tallying.CS.m
let compare = m_compare let compare = m_compare
let pp_printer = Tallying.CS.print_m let pp_printer = Tallying.CS.print_m
...@@ -51,14 +51,69 @@ let test_norm = ...@@ -51,14 +51,69 @@ let test_norm =
let t = parse_typ t in let t = parse_typ t in
let result = Tallying.norm (diff s t) in let result = Tallying.norm (diff s t) in
let elem s = List.sort m_compare (Tallying.CS.S.elements s) in let elem s = List.sort m_compare (Tallying.CS.S.elements s) in
ECS.assert_equal (elem expected) (elem result) MList.assert_equal (elem expected) (elem result)
) )
) norm_tests ) norm_tests
;; ;;
let mk_e l =
List.fold_left (fun acc (v,t) ->
Tallying.CS.E.add (`Var v) (parse_typ t) acc
) Tallying.CS.E.empty l
let merge_tests = [
[("`$A", "Empty");("`$B", "Empty")], cap (mk_pos ("A", "Empty")) (mk_pos ("B", "Empty"));
[("`$A", "Int | Bool");("Int","`$B");("`$B", "Empty")], Tallying.CS.unsat;
[("Bool","`$B"); ("`$B", "`$A"); ("`$A", "Empty")], Tallying.CS.unsat;
[("Bool","`$B"); ("Int","`$B"); ("`$B","`$A"); ("`$A", "Int | Bool")],
cap
(mk_neg ("`$B","A")) (
cap
(mk_pos ("A", "Int | Bool")) (
cap
(mk_neg ("Int | Bool","B"))
(mk_pos ("B","Int | Bool"))
)
);
[("`$A", "`$B")], mk_pos ("A","`$B");
[("`$B", "Empty")], mk_pos ("B","Empty");
]
let e_compare = Tallying.CS.E.compare Types.compare
module EList = OUnitDiff.ListSimpleMake (struct
type t = Tallying.CS.e
let compare = e_compare
let pp_printer = Tallying.CS.print_e
let pp_print_sep = OUnitDiff.pp_comma_separator
end)
let test_merge =
let print_test l =
String.concat ";" (List.map (fun (s,t) -> Printf.sprintf " %s \\ %s" s t) l)
in
"test tallying merge" >:::
List.map (fun (l,expected) ->
(print_test l) >:: (fun _ ->
let n = List.fold_left (fun acc (s,t) ->
let s = parse_typ s in
let t = parse_typ t in
Tallying.CS.cap acc (Tallying.norm(diff s t))) Tallying.CS.S.empty l
in
let result = Tallying.CS.S.fold (fun c acc ->
try cup (Tallying.merge c) acc with Tallying.UnSatConstr -> acc
) n Tallying.CS.S.empty
in
let elem s = List.sort m_compare (Tallying.CS.S.elements s) in
MList.assert_equal (elem expected) (elem result)
)
) merge_tests
;;
let all = let all =
"all tests" >::: [ "all tests" >::: [
test_norm; test_norm;
test_merge;
] ]
let main () = let main () =
......
...@@ -2284,9 +2284,9 @@ module Tallying = struct ...@@ -2284,9 +2284,9 @@ module Tallying = struct
let print ppf m = let print ppf m =
print_lst (fun ppf -> fun ((b,`Var v),s) -> print_lst (fun ppf -> fun ((b,`Var v),s) ->
if b then if b then
Format.fprintf ppf "(`$%s,%a)" v dump s Format.fprintf ppf "(`$%s,%a)" v Print.print s
else else
Format.fprintf ppf "(%a,`$%s)" dump s v Format.fprintf ppf "(%a,`$%s)" Print.print s v
) ppf (bindings m); ) ppf (bindings m);
end end
...@@ -2546,13 +2546,14 @@ module Tallying = struct ...@@ -2546,13 +2546,14 @@ module Tallying = struct
in in
if CS.S.is_empty mm then CS.S.singleton m else mm if CS.S.is_empty mm then CS.S.singleton m else mm
(* merge and solve *) let merge m = merge (m,DescrSet.empty)
let merge m =
let solve s =
let aux v (s,t) acc = let aux v (s,t) acc =
let b = var (Var.fresh ()) in let b = var (Var.fresh ()) in
CS.E.add v (cap (cup s b) t) acc CS.E.add v (cap (cup s b) t) acc
in in
let solve m = let aux m =
CS.M.fold (fun (b,v) s acc -> CS.M.fold (fun (b,v) s acc ->
try try
let t = CS.M.find (not b,v) m in let t = CS.M.find (not b,v) m in
...@@ -2566,9 +2567,7 @@ module Tallying = struct ...@@ -2566,9 +2567,7 @@ module Tallying = struct
if b then aux v (any,s) acc else aux v (s,empty) acc if b then aux v (any,s) acc else aux v (s,empty) acc
) m CS.E.empty ) m CS.E.empty
in in
try try CS.S.fold (fun m acc -> (aux m)::acc) s []
let l = merge (m,DescrSet.empty) in
CS.S.fold (fun m acc -> (solve m)::acc) l []
with UnSatConstr -> [] with UnSatConstr -> []
let unify e = let unify e =
...@@ -2586,7 +2585,7 @@ module Tallying = struct ...@@ -2586,7 +2585,7 @@ module Tallying = struct
let tallying l = let tallying l =
let n = List.fold_left (fun acc (s,t) -> CS.cap acc (norm(diff s t))) CS.S.empty l in let n = List.fold_left (fun acc (s,t) -> CS.cap acc (norm(diff s t))) CS.S.empty l in
let m = CS.S.fold (fun c acc -> acc @ (merge (c))) n [] in let m = CS.S.fold (fun c acc -> acc @ (solve (merge c))) n [] in
List.fold_left (fun acc e -> (unify e)::acc) [] m List.fold_left (fun acc e -> (unify e)::acc) [] m
end end
...@@ -355,6 +355,8 @@ module Tallying : sig ...@@ -355,6 +355,8 @@ module Tallying : sig
|Pos of (Var.var * t) (** alpha <= t | alpha \in P *) |Pos of (Var.var * t) (** alpha <= t | alpha \in P *)
|Neg of (t * Var.var) (** t <= alpha | alpha \in N *) |Neg of (t * Var.var) (** t <= alpha | alpha \in N *)
exception UnSatConstr
module CS : sig module CS : sig
module M : sig module M : sig
include Map.S with type key = (bool * Var.var) include Map.S with type key = (bool * Var.var)
...@@ -386,7 +388,8 @@ module Tallying : sig ...@@ -386,7 +388,8 @@ module Tallying : sig
end end
val norm : t -> CS.s val norm : t -> CS.s
val merge : CS.m -> CS.e list val merge : CS.m -> CS.s
val solve : CS.s -> CS.e list
val unify : CS.e -> CS.e val unify : CS.e -> CS.e
val tallying : (t * t) list -> CS.e list val tallying : (t * t) list -> CS.e list
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment