Commit 3e06d10f by Pietro Abate

### Fix memoization for Tallying.tallying

- delta is now kept into consideration when memoizing
parent 0e6843bf
 ... @@ -333,7 +333,7 @@ let test_tallying = ... @@ -333,7 +333,7 @@ let test_tallying = (print_test l) >:: (fun _ -> (print_test l) >:: (fun _ -> try try let l = List.map (fun (s,t) -> (parse_typ s,parse_typ t)) l in let l = List.map (fun (s,t) -> (parse_typ s,parse_typ t)) l in let sigma = Tallying.tallying l in let sigma = Tallying.tallying Var.Set.empty l in List.iter (fun (s,t) -> List.iter (fun (s,t) -> List.iter (fun sigma -> List.iter (fun sigma -> let s_sigma = Tallying.(s $$sigma) in let s_sigma = Tallying.(s$$ sigma) in ... ...
 ... @@ -2642,20 +2642,22 @@ module Positive = struct ... @@ -2642,20 +2642,22 @@ module Positive = struct latter but the variance annotation tells us that A is invariant. latter but the variance annotation tells us that A is invariant. *) *) let rec pretty_name i acc = let ni,nm = i/26, i mod 26 in let acc = acc ^ (String.make 1 (OldChar.chr (OldChar.code 'a' + nm))) in if ni == 0 then acc else pretty_name ni acc let collect_variables delta v = let collect_variables delta v = (* we memoize based on the pair (pos, v), since v can occur both (* we memoize based on the pair (pos, v), since v can occur both * positively and negatively. and we want to manage the variables positively and negatively. and we want to manage the variables * differently in both cases *) differently in both cases. We do not need to memoize on delta as let module Memo = Hashtbl.Make (struct the memoization is local and delta does not change *) type t = bool * v let module Memo = let hash = Hashtbl.hash Hashtbl.Make (struct let equal (a,b) (c,d) = a == c && b == d type t = bool * v end) let hash = Hashtbl.hash let equal (a,b) (c,d) = a == c && b == d end) in let rec pretty_name i acc = let ni,nm = i/26, i mod 26 in let acc = acc ^ (String.make 1 (OldChar.chr (OldChar.code 'a' + nm))) in if ni == 0 then acc else pretty_name ni acc in in let vars = Hashtbl.create 17 in let vars = Hashtbl.create 17 in let memo = Memo.create 17 in let memo = Memo.create 17 in ... @@ -3159,16 +3161,24 @@ module Tallying = struct ... @@ -3159,16 +3161,24 @@ module Tallying = struct in in big_prod delta norm_arrow CS.sat (Pair.get t) big_prod delta norm_arrow CS.sat (Pair.get t) let memo_norm = DescrHash.create 17 module NormMemoHash = Hashtbl.Make( struct type t = (Descr.t * Var.Set.t) let hash (v,d) = Descr.hash v + Var.Set.hash d let equal (v1,d1) (v2,d2) = Descr.equal v1 v2 && Var.Set.equal d1 d2 end ) let memo_norm = NormMemoHash.create 17 (* XXX here I hash over a set . this might lead to a (* XXX here I hash over a set . this might lead to a * conflict in the hash function if the set is too large *) * conflict in the hash function if the set is too large *) let norm delta t = let norm delta t = try DescrHash.find memo_norm t try NormMemoHash.find memo_norm (t,delta) with Not_found -> begin with Not_found -> begin let res = norm (t,delta,DescrSet.empty) in let res = norm (t,delta,DescrSet.empty) in DescrHash.add memo_norm t res; res NormMemoHash.add memo_norm (t,delta) res; res end end (* merge needs delta because it calls norm recursively *) let rec merge (m,delta,mem) = let rec merge (m,delta,mem) = let res = let res = CS.M.fold (fun v (inf, sup) acc -> CS.M.fold (fun v (inf, sup) acc -> ... @@ -3252,7 +3262,7 @@ module Tallying = struct ... @@ -3252,7 +3262,7 @@ module Tallying = struct exception Step1Fail exception Step1Fail exception Step2Fail exception Step2Fail let tallying ?(delta=Var.Set.empty) l = let tallying delta l = let n = let n = List.fold_left (fun acc (s,t) -> List.fold_left (fun acc (s,t) -> let d = diff s t in let d = diff s t in ... @@ -3357,12 +3367,12 @@ let get a i = if i < 0 then any else (!a).(i) ... @@ -3357,12 +3367,12 @@ let get a i = if i < 0 then any else (!a).(i) exception FoundSquareSub of Tallying.CS.sl exception FoundSquareSub of Tallying.CS.sl let squaresubtype delta s t = let squaresubtype delta s t = DescrHash.clear Tallying.memo_norm; Tallying.NormMemoHash.clear Tallying.memo_norm; let ai = ref [| |] in let ai = ref [| |] in let tallying i = let tallying i = try try let s = get ai i in let s = get ai i in let sl = Tallying.tallying ~delta [ (s,t) ] in let sl = Tallying.tallying delta [ (s,t) ] in raise (FoundSquareSub sl) raise (FoundSquareSub sl) with with Tallying.Step1Fail -> (assert (i == 0); raise (Tallying.UnSatConstr "apply_raw step1")) Tallying.Step1Fail -> (assert (i == 0); raise (Tallying.UnSatConstr "apply_raw step1")) ... @@ -3389,7 +3399,7 @@ exception FoundApply of t * int * int * Tallying.CS.sl ... @@ -3389,7 +3399,7 @@ exception FoundApply of t * int * int * Tallying.CS.sl (** find two sets of type substitutions I,J such that (** find two sets of type substitutions I,J such that s @@ sigma_i < t @@ sigma_j for all i \in I, j \in J *) s @@ sigma_i < t @@ sigma_j for all i \in I, j \in J *) let apply_raw delta s t = let apply_raw delta s t = DescrHash.clear Tallying.memo_norm; Tallying.NormMemoHash.clear Tallying.memo_norm; let gamma = var (Var.mk "Gamma") in let gamma = var (Var.mk "Gamma") in let cgamma = cons gamma in let cgamma = cons gamma in (* cell i of ai contains /\k<=i s_k, cell j of aj contains /\k<=j t_k *) (* cell i of ai contains /\k<=i s_k, cell j of aj contains /\k<=j t_k *) ... @@ -3399,15 +3409,10 @@ let apply_raw delta s t = ... @@ -3399,15 +3409,10 @@ let apply_raw delta s t = try try let s = get ai i in let s = get ai i in let t = arrow (cons (get aj j)) cgamma in let t = arrow (cons (get aj j)) cgamma in (* Format.printf "Tallying s=%a < t=%a\n" Print.pp_type s Print.pp_type t; *) let sl = Tallying.tallying delta [ (s,t) ] in let sl = Tallying.tallying ~delta [ (s,t) ] in let new_res = let new_res = Positive.clean_type delta ( Positive.clean_type delta ( List.fold_left (fun tacc si -> List.fold_left (fun tacc si -> (* let a = (Tallying.(gamma $$si)) in let b = Positive.clean_type delta a in Format.printf "dirty %a \n clean %a\n" Print.pp_type a Print.pp_type b; *) cap tacc (Tallying.(gamma$$ si)) cap tacc (Tallying.(gamma $$si)) ) any sl ) any sl ) ) ... ...  ... @@ -432,7 +432,7 @@ module Tallying : sig ... @@ -432,7 +432,7 @@ module Tallying : sig (* [s1 ... sn] . si is a solution for tallying problem (* [s1 ... sn] . si is a solution for tallying problem if si # delta and for all (s,t) in C si @ s < si @ t *) if si # delta and for all (s,t) in C si @ s < si @ t *) val tallying : ?delta : Var.Set.t -> (t * t) list -> CS.sl val tallying : Var.Set.t -> (t * t) list -> CS.sl val ($$) : t -> CS.sigma -> t val () : t -> CS.sigma -> t ... ...
 ... @@ -13,7 +13,7 @@ val fresh : ?pre:string -> unit -> var ... @@ -13,7 +13,7 @@ val fresh : ?pre:string -> unit -> var val id : var -> string val id : var -> string module Set : sig module Set : sig type t include Custom.T val dump : Format.formatter -> t -> unit val dump : Format.formatter -> t -> unit val pp : Format.formatter -> t -> unit val pp : Format.formatter -> t -> unit val printf : t -> unit val printf : t -> unit ... ...
 ... @@ -956,7 +956,7 @@ and type_check' loc env ed constr precise = match ed with ... @@ -956,7 +956,7 @@ and type_check' loc env ed constr precise = match ed with (* t [_delta 0 -> 1 *) (* t [_delta 0 -> 1 *) begin try begin try ignore(Types.Tallying.tallying ~delta:env.delta [(t1,Types.Arrow.any)]) ignore(Types.Tallying.tallying env.delta [(t1,Types.Arrow.any)]) with Types.Tallying.UnSatConstr _ -> with Types.Tallying.UnSatConstr _ -> raise_loc loc (Constraint (t1, Types.Arrow.any)) raise_loc loc (Constraint (t1, Types.Arrow.any)) end; end; ... ...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment