Commit 3e06d10f authored by Pietro Abate's avatar Pietro Abate
Browse files

Fix memoization for Tallying.tallying

- delta is now kept into consideration when memoizing
parent 0e6843bf
...@@ -333,7 +333,7 @@ let test_tallying = ...@@ -333,7 +333,7 @@ let test_tallying =
(print_test l) >:: (fun _ -> (print_test l) >:: (fun _ ->
try try
let l = List.map (fun (s,t) -> (parse_typ s,parse_typ t)) l in let l = List.map (fun (s,t) -> (parse_typ s,parse_typ t)) l in
let sigma = Tallying.tallying l in let sigma = Tallying.tallying Var.Set.empty l in
List.iter (fun (s,t) -> List.iter (fun (s,t) ->
List.iter (fun sigma -> List.iter (fun sigma ->
let s_sigma = Tallying.(s $$ sigma) in let s_sigma = Tallying.(s $$ sigma) in
......
...@@ -2642,21 +2642,23 @@ module Positive = struct ...@@ -2642,21 +2642,23 @@ module Positive = struct
latter but the variance annotation tells us that A is invariant. latter but the variance annotation tells us that A is invariant.
*) *)
let rec pretty_name i acc =
let ni,nm = i/26, i mod 26 in
let acc = acc ^ (String.make 1 (OldChar.chr (OldChar.code 'a' + nm))) in
if ni == 0 then acc else pretty_name ni acc
let collect_variables delta v = let collect_variables delta v =
(* we memoize based on the pair (pos, v), since v can occur both (* we memoize based on the pair (pos, v), since v can occur both
* positively and negatively. and we want to manage the variables positively and negatively. and we want to manage the variables
* differently in both cases *) differently in both cases. We do not need to memoize on delta as
let module Memo = Hashtbl.Make (struct the memoization is local and delta does not change *)
let module Memo =
Hashtbl.Make (struct
type t = bool * v type t = bool * v
let hash = Hashtbl.hash let hash = Hashtbl.hash
let equal (a,b) (c,d) = a == c && b == d let equal (a,b) (c,d) = a == c && b == d
end) end)
in in
let rec pretty_name i acc =
let ni,nm = i/26, i mod 26 in
let acc = acc ^ (String.make 1 (OldChar.chr (OldChar.code 'a' + nm))) in
if ni == 0 then acc else pretty_name ni acc
in
let vars = Hashtbl.create 17 in let vars = Hashtbl.create 17 in
let memo = Memo.create 17 in let memo = Memo.create 17 in
let t_emp = cup [] in let t_emp = cup [] in
...@@ -3159,16 +3161,24 @@ module Tallying = struct ...@@ -3159,16 +3161,24 @@ module Tallying = struct
in in
big_prod delta norm_arrow CS.sat (Pair.get t) big_prod delta norm_arrow CS.sat (Pair.get t)
let memo_norm = DescrHash.create 17 module NormMemoHash = Hashtbl.Make(
struct
type t = (Descr.t * Var.Set.t)
let hash (v,d) = Descr.hash v + Var.Set.hash d
let equal (v1,d1) (v2,d2) = Descr.equal v1 v2 && Var.Set.equal d1 d2
end )
let memo_norm = NormMemoHash.create 17
(* XXX here I hash over a set . this might lead to a (* XXX here I hash over a set . this might lead to a
* conflict in the hash function if the set is too large *) * conflict in the hash function if the set is too large *)
let norm delta t = let norm delta t =
try DescrHash.find memo_norm t try NormMemoHash.find memo_norm (t,delta)
with Not_found -> begin with Not_found -> begin
let res = norm (t,delta,DescrSet.empty) in let res = norm (t,delta,DescrSet.empty) in
DescrHash.add memo_norm t res; res NormMemoHash.add memo_norm (t,delta) res; res
end end
(* merge needs delta because it calls norm recursively *)
let rec merge (m,delta,mem) = let rec merge (m,delta,mem) =
let res = let res =
CS.M.fold (fun v (inf, sup) acc -> CS.M.fold (fun v (inf, sup) acc ->
...@@ -3252,7 +3262,7 @@ module Tallying = struct ...@@ -3252,7 +3262,7 @@ module Tallying = struct
exception Step1Fail exception Step1Fail
exception Step2Fail exception Step2Fail
let tallying ?(delta=Var.Set.empty) l = let tallying delta l =
let n = let n =
List.fold_left (fun acc (s,t) -> List.fold_left (fun acc (s,t) ->
let d = diff s t in let d = diff s t in
...@@ -3357,12 +3367,12 @@ let get a i = if i < 0 then any else (!a).(i) ...@@ -3357,12 +3367,12 @@ let get a i = if i < 0 then any else (!a).(i)
exception FoundSquareSub of Tallying.CS.sl exception FoundSquareSub of Tallying.CS.sl
let squaresubtype delta s t = let squaresubtype delta s t =
DescrHash.clear Tallying.memo_norm; Tallying.NormMemoHash.clear Tallying.memo_norm;
let ai = ref [| |] in let ai = ref [| |] in
let tallying i = let tallying i =
try try
let s = get ai i in let s = get ai i in
let sl = Tallying.tallying ~delta [ (s,t) ] in let sl = Tallying.tallying delta [ (s,t) ] in
raise (FoundSquareSub sl) raise (FoundSquareSub sl)
with with
Tallying.Step1Fail -> (assert (i == 0); raise (Tallying.UnSatConstr "apply_raw step1")) Tallying.Step1Fail -> (assert (i == 0); raise (Tallying.UnSatConstr "apply_raw step1"))
...@@ -3389,7 +3399,7 @@ exception FoundApply of t * int * int * Tallying.CS.sl ...@@ -3389,7 +3399,7 @@ exception FoundApply of t * int * int * Tallying.CS.sl
(** find two sets of type substitutions I,J such that (** find two sets of type substitutions I,J such that
s @@ sigma_i < t @@ sigma_j for all i \in I, j \in J *) s @@ sigma_i < t @@ sigma_j for all i \in I, j \in J *)
let apply_raw delta s t = let apply_raw delta s t =
DescrHash.clear Tallying.memo_norm; Tallying.NormMemoHash.clear Tallying.memo_norm;
let gamma = var (Var.mk "Gamma") in let gamma = var (Var.mk "Gamma") in
let cgamma = cons gamma in let cgamma = cons gamma in
(* cell i of ai contains /\k<=i s_k, cell j of aj contains /\k<=j t_k *) (* cell i of ai contains /\k<=i s_k, cell j of aj contains /\k<=j t_k *)
...@@ -3399,15 +3409,10 @@ let apply_raw delta s t = ...@@ -3399,15 +3409,10 @@ let apply_raw delta s t =
try try
let s = get ai i in let s = get ai i in
let t = arrow (cons (get aj j)) cgamma in let t = arrow (cons (get aj j)) cgamma in
(* Format.printf "Tallying s=%a < t=%a\n" Print.pp_type s Print.pp_type t; *) let sl = Tallying.tallying delta [ (s,t) ] in
let sl = Tallying.tallying ~delta [ (s,t) ] in
let new_res = let new_res =
Positive.clean_type delta ( Positive.clean_type delta (
List.fold_left (fun tacc si -> List.fold_left (fun tacc si ->
(*
let a = (Tallying.(gamma $$ si)) in
let b = Positive.clean_type delta a in
Format.printf "dirty %a \n clean %a\n" Print.pp_type a Print.pp_type b; *)
cap tacc (Tallying.(gamma $$ si)) cap tacc (Tallying.(gamma $$ si))
) any sl ) any sl
) )
......
...@@ -432,7 +432,7 @@ module Tallying : sig ...@@ -432,7 +432,7 @@ module Tallying : sig
(* [s1 ... sn] . si is a solution for tallying problem (* [s1 ... sn] . si is a solution for tallying problem
if si # delta and for all (s,t) in C si @ s < si @ t *) if si # delta and for all (s,t) in C si @ s < si @ t *)
val tallying : ?delta : Var.Set.t -> (t * t) list -> CS.sl val tallying : Var.Set.t -> (t * t) list -> CS.sl
val ($$) : t -> CS.sigma -> t val ($$) : t -> CS.sigma -> t
......
...@@ -13,7 +13,7 @@ val fresh : ?pre:string -> unit -> var ...@@ -13,7 +13,7 @@ val fresh : ?pre:string -> unit -> var
val id : var -> string val id : var -> string
module Set : sig module Set : sig
type t include Custom.T
val dump : Format.formatter -> t -> unit val dump : Format.formatter -> t -> unit
val pp : Format.formatter -> t -> unit val pp : Format.formatter -> t -> unit
val printf : t -> unit val printf : t -> unit
......
...@@ -956,7 +956,7 @@ and type_check' loc env ed constr precise = match ed with ...@@ -956,7 +956,7 @@ and type_check' loc env ed constr precise = match ed with
(* t [_delta 0 -> 1 *) (* t [_delta 0 -> 1 *)
begin try begin try
ignore(Types.Tallying.tallying ~delta:env.delta [(t1,Types.Arrow.any)]) ignore(Types.Tallying.tallying env.delta [(t1,Types.Arrow.any)])
with Types.Tallying.UnSatConstr _ -> with Types.Tallying.UnSatConstr _ ->
raise_loc loc (Constraint (t1, Types.Arrow.any)) raise_loc loc (Constraint (t1, Types.Arrow.any))
end; end;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment