Commit 3e06d10f authored by Pietro Abate's avatar Pietro Abate

Fix memoization for Tallying.tallying

- delta is now kept into consideration when memoizing
parent 0e6843bf
......@@ -333,7 +333,7 @@ let test_tallying =
(print_test l) >:: (fun _ ->
try
let l = List.map (fun (s,t) -> (parse_typ s,parse_typ t)) l in
let sigma = Tallying.tallying l in
let sigma = Tallying.tallying Var.Set.empty l in
List.iter (fun (s,t) ->
List.iter (fun sigma ->
let s_sigma = Tallying.(s $$ sigma) in
......
......@@ -2642,20 +2642,22 @@ module Positive = struct
latter but the variance annotation tells us that A is invariant.
*)
let rec pretty_name i acc =
let ni,nm = i/26, i mod 26 in
let acc = acc ^ (String.make 1 (OldChar.chr (OldChar.code 'a' + nm))) in
if ni == 0 then acc else pretty_name ni acc
let collect_variables delta v =
(* we memoize based on the pair (pos, v), since v can occur both
* positively and negatively. and we want to manage the variables
* differently in both cases *)
let module Memo = Hashtbl.Make (struct
type t = bool * v
let hash = Hashtbl.hash
let equal (a,b) (c,d) = a == c && b == d
end)
positively and negatively. and we want to manage the variables
differently in both cases. We do not need to memoize on delta as
the memoization is local and delta does not change *)
let module Memo =
Hashtbl.Make (struct
type t = bool * v
let hash = Hashtbl.hash
let equal (a,b) (c,d) = a == c && b == d
end)
in
let rec pretty_name i acc =
let ni,nm = i/26, i mod 26 in
let acc = acc ^ (String.make 1 (OldChar.chr (OldChar.code 'a' + nm))) in
if ni == 0 then acc else pretty_name ni acc
in
let vars = Hashtbl.create 17 in
let memo = Memo.create 17 in
......@@ -3159,16 +3161,24 @@ module Tallying = struct
in
big_prod delta norm_arrow CS.sat (Pair.get t)
let memo_norm = DescrHash.create 17
module NormMemoHash = Hashtbl.Make(
struct
type t = (Descr.t * Var.Set.t)
let hash (v,d) = Descr.hash v + Var.Set.hash d
let equal (v1,d1) (v2,d2) = Descr.equal v1 v2 && Var.Set.equal d1 d2
end )
let memo_norm = NormMemoHash.create 17
(* XXX here I hash over a set . this might lead to a
* conflict in the hash function if the set is too large *)
let norm delta t =
try DescrHash.find memo_norm t
try NormMemoHash.find memo_norm (t,delta)
with Not_found -> begin
let res = norm (t,delta,DescrSet.empty) in
DescrHash.add memo_norm t res; res
NormMemoHash.add memo_norm (t,delta) res; res
end
(* merge needs delta because it calls norm recursively *)
let rec merge (m,delta,mem) =
let res =
CS.M.fold (fun v (inf, sup) acc ->
......@@ -3252,7 +3262,7 @@ module Tallying = struct
exception Step1Fail
exception Step2Fail
let tallying ?(delta=Var.Set.empty) l =
let tallying delta l =
let n =
List.fold_left (fun acc (s,t) ->
let d = diff s t in
......@@ -3357,12 +3367,12 @@ let get a i = if i < 0 then any else (!a).(i)
exception FoundSquareSub of Tallying.CS.sl
let squaresubtype delta s t =
DescrHash.clear Tallying.memo_norm;
Tallying.NormMemoHash.clear Tallying.memo_norm;
let ai = ref [| |] in
let tallying i =
try
let s = get ai i in
let sl = Tallying.tallying ~delta [ (s,t) ] in
let sl = Tallying.tallying delta [ (s,t) ] in
raise (FoundSquareSub sl)
with
Tallying.Step1Fail -> (assert (i == 0); raise (Tallying.UnSatConstr "apply_raw step1"))
......@@ -3389,7 +3399,7 @@ exception FoundApply of t * int * int * Tallying.CS.sl
(** find two sets of type substitutions I,J such that
s @@ sigma_i < t @@ sigma_j for all i \in I, j \in J *)
let apply_raw delta s t =
DescrHash.clear Tallying.memo_norm;
Tallying.NormMemoHash.clear Tallying.memo_norm;
let gamma = var (Var.mk "Gamma") in
let cgamma = cons gamma in
(* cell i of ai contains /\k<=i s_k, cell j of aj contains /\k<=j t_k *)
......@@ -3399,15 +3409,10 @@ let apply_raw delta s t =
try
let s = get ai i in
let t = arrow (cons (get aj j)) cgamma in
(* Format.printf "Tallying s=%a < t=%a\n" Print.pp_type s Print.pp_type t; *)
let sl = Tallying.tallying ~delta [ (s,t) ] in
let sl = Tallying.tallying delta [ (s,t) ] in
let new_res =
Positive.clean_type delta (
List.fold_left (fun tacc si ->
(*
let a = (Tallying.(gamma $$ si)) in
let b = Positive.clean_type delta a in
Format.printf "dirty %a \n clean %a\n" Print.pp_type a Print.pp_type b; *)
cap tacc (Tallying.(gamma $$ si))
) any sl
)
......
......@@ -432,7 +432,7 @@ module Tallying : sig
(* [s1 ... sn] . si is a solution for tallying problem
if si # delta and for all (s,t) in C si @ s < si @ t *)
val tallying : ?delta : Var.Set.t -> (t * t) list -> CS.sl
val tallying : Var.Set.t -> (t * t) list -> CS.sl
val ($$) : t -> CS.sigma -> t
......
......@@ -13,7 +13,7 @@ val fresh : ?pre:string -> unit -> var
val id : var -> string
module Set : sig
type t
include Custom.T
val dump : Format.formatter -> t -> unit
val pp : Format.formatter -> t -> unit
val printf : t -> unit
......
......@@ -956,7 +956,7 @@ and type_check' loc env ed constr precise = match ed with
(* t [_delta 0 -> 1 *)
begin try
ignore(Types.Tallying.tallying ~delta:env.delta [(t1,Types.Arrow.any)])
ignore(Types.Tallying.tallying env.delta [(t1,Types.Arrow.any)])
with Types.Tallying.UnSatConstr _ ->
raise_loc loc (Constraint (t1, Types.Arrow.any))
end;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment