Commit 8b9f098a authored by Pietro Abate's avatar Pietro Abate
Browse files

More unit test. Tallying.tallying returns a set, not a list

parent 11c9e701
......@@ -32,6 +32,26 @@ let norm_tests = [
(mk_pos ("A","Int"))
(mk_neg ("Bool","B"))
);
"(Int -> Int) | (Bool -> Bool)", "(`$A -> `$B)",
cup
(mk_pos ("A","Empty"))
(cup
(cup
(cap
(mk_pos ("A","Empty"))
(mk_neg ("Int","B"))
)
(cap
(mk_pos ("A","Empty"))
(mk_neg ("Bool","B"))
)
)
(cap
(mk_pos ("A","Empty"))
(mk_neg ("Int | Bool","B"))
)
);
]
let m_compare = Tallying.CS.M.compare Types.compare
......@@ -115,7 +135,7 @@ module EList = OUnitDiff.ListSimpleMake (struct
let pp_print_sep = OUnitDiff.pp_comma_separator
end)
module SubList = OUnitDiff.ListSimpleMake (struct
module SubSet = OUnitDiff.SetMake (struct
type t = EList.t
let compare a b = EList.compare a b
let pp_printer ppf l = EList.pp_printer ppf l
......@@ -131,20 +151,25 @@ let tallying_tests = [
[("((Int | Bool) -> Int)", "(`$A -> `$B)"); ("(`$A -> Bool)","(`$B -> `$B)")], mk_e [
[("A","Empty");("B","Empty")];
[("A","Int | Bool");("B","Int | Bool")]
]
];
[("(Int -> Int) | (Bool -> Bool)", "(`$A -> `$B)")], mk_e [
[("A","Empty")];
[("A","Empty");("B","Empty")];
];
[("((Int,Int) , (Int | Bool))","(`$A,Int) | ((`$B,Int),Bool)")], mk_e [[("A", "(Int,Int)"); ("B","Int")]];
]
let test_tallying =
let print_test l =
String.concat ";" (List.map (fun (s,t) -> Printf.sprintf " %s \\ %s" s t) l)
in
"test tallying merge" >:::
"test tallying" >:::
List.map (fun (l,expected) ->
(print_test l) >:: (fun _ ->
let l = List.map (fun (s,t) -> (parse_typ s,parse_typ t)) l in
let result = Tallying.tallying l in
let elem s = SubList.of_list (List.map (fun l -> EList.of_list (List.sort compare_constr l)) s) in
SubList.assert_equal (elem expected) (elem result)
let elem s = SubSet.of_list (List.map (fun l -> EList.of_list (List.sort compare_constr l)) s) in
SubSet.assert_equal (elem expected) (elem result)
)
) tallying_tests
;;
......
......@@ -2291,7 +2291,8 @@ module Tallying = struct
end
(* equation set : (s < alpha < t) stored as ( alpha -> (s,t) ) *)
(* equation set : (s < alpha < t) stored as
* { alpha -> ((s v beta) ^ t) } with beta fresh *)
module E = struct
include Map.Make(struct
type t = Var.var
......@@ -2305,6 +2306,16 @@ module Tallying = struct
end
(* Set of equation sets *)
module ES = struct
include Set.Make(struct
type t = Descr.s E.t
let compare = E.compare Descr.compare
end)
let print ppf s = print_lst E.print ppf (elements s)
end
(* Set of constraint sets *)
module S = struct
include Set.Make(struct
......@@ -2319,6 +2330,7 @@ module Tallying = struct
type s = S.t
type m = Descr.s M.t
type e = Descr.s E.t
type es = ES.t
let singleton = function
|Pos (v,s) -> S.singleton (M.singleton (true,v) s)
......@@ -2342,6 +2354,9 @@ module Tallying = struct
let sat = S.singleton M.empty
let unsat = S.empty
let cup = S.union
(* cartesian product of two sets of contraints sets where each
* resulting constraint set is than merged *)
let cap x y =
match S.is_empty x,S.is_empty y with
|true,true -> S.empty
......@@ -2492,7 +2507,7 @@ module Tallying = struct
List.fold_left (fun acc (_,p,n) -> CS.cap acc (norm_rec (p,n))) CS.S.empty (get_record t)
(* arrow(p,{t1 -> t2}) = [t1] ^ arrow'(t1,any \\ t2,p)
* arrow'(t1,acc,p) =
* arrow'(t1,acc,{s1 -> s2} v p) =
([t1\s1] ^ arrow'(t1\s1,acc,p)) v
([acc ^ {s2} \ t2] ^ arrow'(t1,acc ^ {s2},p))
......@@ -2546,28 +2561,44 @@ module Tallying = struct
in
if CS.S.is_empty mm then CS.S.singleton m else mm
(* returns a constraint set or UnSatConstr *)
let merge m = merge (m,DescrSet.empty)
let solve s =
let aux v (s,t) acc =
let b = var (Var.fresh ()) in
CS.E.add v (cap (cup s b) t) acc
if CS.E.mem v acc then assert false else
if equal s empty && equal t any then
let b = var (Var.fresh ()) in
CS.E.add v b acc
else if equal t empty then CS.E.add v empty acc
else if equal s any then CS.E.add v t acc
else
let b = var (Var.fresh ()) in
CS.E.add v (cap (cup s b) t) acc
in
let aux m =
let cache = Hashtbl.create (CS.M.cardinal m) in
CS.M.fold (fun (b,v) s acc ->
try
let t = CS.M.find (not b,v) m in
if t.toplvars.b && (Var.Set.cardinal t.toplvars.s) = 1 then begin
let z = Var.Set.max_elt t.toplvars.s in
let acc1 = if b then aux v (empty,t) acc else aux v (t,any) acc in
if b then aux z (empty,any) acc else aux v (empty,any) acc1
end else
if b then aux v (t,s) acc else aux v (s,t) acc
with Not_found ->
if b then aux v (any,s) acc else aux v (s,empty) acc
if Hashtbl.mem cache v then acc else begin
Hashtbl.add cache v ();
try
let t = CS.M.find (not b,v) m in
(* if t containts only a toplevel variable and nothing else *)
if t.toplvars.b && (Var.Set.cardinal t.toplvars.s) = 1 then begin
if b then
let z = Var.Set.max_elt t.toplvars.s in
aux z (empty,any) acc
else
let acc1 = if b then aux v (empty,t) acc else aux v (t,any) acc in
aux v (empty,any) acc1
end else
if b then aux v (t,s) acc else aux v (s,t) acc
with Not_found ->
if b then aux v (any,s) acc else aux v (s,empty) acc
end
) m CS.E.empty
in
CS.S.fold (fun m acc -> (aux m)::acc) s []
CS.S.fold (fun m acc -> CS.ES.add (aux m) acc) s CS.ES.empty
let unify e =
let rec aux (sol,acc) e =
......@@ -2578,13 +2609,16 @@ module Tallying = struct
(* XXX ... let x = Var.fresh () in *)
(* replace in e1 all occurrences of a by ... *)
let es = CS.E.fold (fun beta s acc -> CS.E.add beta (subst s (alpha,t)) acc) e1 CS.E.empty in
aux (((alpha,t)::sol),(CS.E.add alpha (subst t (alpha,t)) acc)) es
aux ((CS.E.add alpha t sol),(CS.E.add alpha (subst t (alpha,t)) acc)) es
in
aux ([],CS.E.empty) e
aux (CS.E.empty,CS.E.empty) e
let tallying l =
let n = List.fold_left (fun acc (s,t) -> CS.cap acc (norm(diff s t))) CS.S.empty l in
let m = CS.S.fold (fun c acc -> try acc @ (solve (merge c)) with UnSatConstr -> acc) n [] in
List.fold_left (fun acc e -> (unify e)::acc) [] m
let n = List.fold_left (fun acc (s,t) -> try CS.cap acc (norm(diff s t)) with UnSatConstr -> acc) CS.S.empty l in
if CS.S.is_empty n then raise UnSatConstr else
let m = CS.S.fold (fun c acc -> try CS.ES.union (solve (merge c)) acc with UnSatConstr -> acc) n CS.ES.empty in
if CS.ES.is_empty m then raise UnSatConstr else
let el = CS.ES.fold (fun e acc -> CS.ES.add (unify e) acc) m CS.ES.empty in
List.map (CS.E.bindings) (CS.ES.elements el)
end
......@@ -366,6 +366,10 @@ module Tallying : sig
include Map.S with type key = Var.var
val print : Format.formatter -> descr t -> unit
end
module ES : sig
include Set.S with type elt = descr E.t
val print : Format.formatter -> t -> unit
end
module S : sig
include Set.S with type elt = descr M.t
val print : Format.formatter -> t -> unit
......@@ -374,6 +378,7 @@ module Tallying : sig
type s = S.t
type m = t M.t
type e = t E.t
type es = ES.t
val print : Format.formatter -> s -> unit
val print_m : Format.formatter -> m -> unit
......@@ -389,8 +394,8 @@ module Tallying : sig
val norm : t -> CS.s
val merge : CS.m -> CS.s
val solve : CS.s -> CS.e list
val unify : CS.e -> (Var.var * t) list
val solve : CS.s -> CS.es
val unify : CS.e -> CS.e
val tallying : (t * t) list -> (Var.var * t) list list
end
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment