Commit a8ba6ab6 authored by Kim Nguyễn's avatar Kim Nguyễn
Browse files

Implement a cache for the positive equation solver (partial results are stored in nodes)

parent 083edd02
......@@ -41,7 +41,7 @@ let balance ( Unbalanced('a) -> Rtree('a) ; 'b & RBtree('a) -> 'b & RBtree('a) )
| x -> x
;;
let [] = []
(* *)
(* Version 2: restrict the first branch to Unbalanced trees whatever *)
(* type it contains *)
......@@ -183,7 +183,7 @@ let cardinal ( RBtree('a) -> Int ) (* better type: [] -> 0, Any\[] -> [1--*] *
| <(c) elem=e>[ l r ] ->
(<black elem=e>[ l (balance (redify r)) ], (c = `black))
(*
let remove(x : 'a)(t : RBtree('a) ) : RBtree('a) =
let remove_aux(RBtree('a) -> (RBtree('a),Bool) )
| [] ->
......@@ -207,3 +207,4 @@ let remove(x : 'a)(t : RBtree('a) ) : RBtree('a) =
if d then bubble_left tree else (tree, `false)
in
let (sol,_) = remove_aux t in sol
*)
......@@ -64,6 +64,7 @@ sig
val mapi: (Elem.t -> 'a -> 'b) -> 'a map -> 'b map
val constant: 'a -> t -> 'a map
val num: int -> t -> int map
val init : (Elem.t -> 'a) -> t -> 'a map
val map_to_list: ('a -> 'b) -> 'a map -> 'b list
val mapi_to_list: (Elem.t -> 'a -> 'b) -> 'a map -> 'b list
val assoc: Elem.t -> 'a map -> 'a
......@@ -410,6 +411,10 @@ module Make(X : Custom.T) = struct
| (x,y)::l -> (f y)::(map_to_list f l)
| [] -> []
let rec init f = function
[] -> []
| x :: l -> (x, f x) :: (init f l)
let rec assoc v = function
| (x,y)::l ->
let c = Elem.compare x v in
......
......@@ -64,6 +64,7 @@ sig
val mapi: (Elem.t -> 'a -> 'b) -> 'a map -> 'b map
val constant: 'a -> t -> 'a map
val num: int -> t -> int map
val init : (Elem.t -> 'a) -> t -> 'a map
val map_to_list: ('a -> 'b) -> 'a map -> 'b list
val mapi_to_list: (Elem.t -> 'a -> 'b) -> 'a map -> 'b list
val assoc: Elem.t -> 'a map -> 'a
......
......@@ -2636,7 +2636,8 @@ module Positive = struct
|`Xml of v * v
|`Record of bool * (bool * Ns.Label.t * v) list
]
and v = { mutable def : rhs; mutable node : node option; }
and v = { mutable def : rhs; mutable node : node option;
mutable descr : Descr.t option}
module MemoHash = Hashtbl.Make( struct
type t = v
......@@ -2670,8 +2671,10 @@ module Positive = struct
aux ppf v
let printf = pp Format.std_formatter
let rec make_descr seen v =
match v.descr with
| Some d -> d
| None ->
if List.memq v seen then empty
else
let seen = v :: seen in
......@@ -2697,13 +2700,13 @@ module Positive = struct
n
(* We shadow the corresponding definitions in the outer module *)
let forward () = { def = `Cup []; node = None; }
let forward () = { def = `Cup []; node = None; descr = None }
let def v d = v.def <- d
let cons d = let v = forward () in def v d; v
let ty d = cons (`Type d)
let var d = cons (`Variable d)
let neg v = cons (`Neg v)
let rec cup vl = cons (`Cup vl)
let cup vl = cons (`Cup vl)
let cap vl = cons (`Cap vl)
let times v1 v2 = cons (`Times (v1,v2))
let arrow v1 v2 = cons (`Arrow (v1,v2))
......@@ -2762,6 +2765,7 @@ module Positive = struct
and decompose_type t =
try DescrHash.find memo t
with Not_found ->
let r =
if no_var t then ty t
else
match check_var t with
......@@ -2782,135 +2786,152 @@ module Positive = struct
in
node_t.def <- (cup descr_t).def; node_t
in
decompose_type t
let solve v = internalize (make_node v)
(* [map_var f v] applies returns the type
[v{ 'a <- f 'a}] for all ['a] in [v]
*)
let map_var subst v =
let memo = MemoHash.create 17 in
let rec aux v subst =
try MemoHash.find memo v
with Not_found ->
let node_v = forward () in
let () = MemoHash.add memo v node_v in
let new_v =
match v.def with
|`Type d -> `Type d
(* |`Variable d when Var.Set.mem d delta -> v.def *)
|`Variable d -> (subst d).def
|`Cup vl -> `Cup (List.map (fun v -> aux v subst) vl)
|`Cap vl -> `Cap (List.map (fun v -> aux v subst) vl)
|`Times (v1,v2) -> `Times (aux v1 subst, aux v2 subst)
|`Arrow (v1,v2) -> `Arrow (aux v1 subst, aux v2 subst)
|`Xml (v1,v2) -> `Xml (aux v1 subst, aux v2 subst)
|`Record (b, flst) ->
`Record (b, List.map (fun (b,l,v) -> (b,l,aux v subst)) flst)
|`Neg v -> `Neg (aux v subst)
in
node_v.def <- new_v;
node_v
r.descr <- Some t;
r
in
aux v subst
decompose_type t
let apply_subst ?(subst=(fun v -> var v)) ?(after=(fun x -> x)) t =
if no_var t then t else
let res = map_var subst (decompose t) in
let res = after res in
descr (solve res)
let solve v = (*match v.descr with
None -> *)internalize (make_node v)
(*| Some t -> T.cons t *)
(* Given a type t and a polymorphic variable 'a occuring in t,
returns the type s which is the solution of 'a = t *)
let solve_rectype t alpha =
let x = forward () in
let subst d = if Var.equal d alpha then x else var d in
apply_subst ~subst:subst ~after:(fun y -> define x y;x) t
end
(* Pre-condition : alpha \not\in \delta *)
module MemoSubst =
struct
include Hashtbl.Make (struct
type t = descr * (Var.t * descr) list
module Substitution =
struct
module Map = Var.Set.Map
type t = Descr.t Map.map
type order = int Map.map
let identity = Map.empty
let add v t m =
if is_var t && Var.(equal v (Set.choose (all_vars t))) then m
else Map.add v t m
let of_list l =
List.fold_left (fun acc (v, t) -> add v t acc) identity l
module Memo = Hashtbl.Make
(struct
type subst = t
type t = Descr.t * subst
let equal ((t1, l1) as k1) ((t2, l2) as k2) =
k1 == k2
|| ((t1 == t2 || Descr.equal t1 t2)
&& (l1 == l2 || Map.equal Descr.equal l1 l2))
let hash (t, l) =
List.fold_left
(fun acc (v,t) -> Var.hash v + 17 * Descr.hash t + 31 * acc)
(Descr.hash t) l
let equal (t1, l1) (t2, l2) =
Descr.equal t1 t2 && (try List.for_all2 (fun (v1, t1) (v2, t2) ->
Var.equal v1 v2 && Descr.equal t1 t2) l1 l2 with _ -> false)
(Descr.hash t + 31 * Map.hash Descr.hash l) land 0x3fff_ffff
end)
let global_memo = Memo.create 17
end
let memo_subst = MemoSubst.create 17
let rec substitute_list t l =
if no_var t || l == [] then t else
let k = (t,l) in
let rec apply_subst ?(after = (fun x -> x)) ?(do_var= fun x -> Positive.ty x) t subst =
let open Positive in
if subst == identity then descr (solve t) else
let memo = MemoHash.create 17 in
let todo = ref [] in
let rec aux v =
let found, update, v =
match v.descr with
| None -> false, None, v
| Some d ->
let vars = all_vars d in
if Var.Set.is_empty vars then true, None, ty d
else
let subst' = Map.restrict subst vars in
let key = (d, subst') in
try
MemoSubst.find memo_subst k
let d = Memo.find global_memo key in
true, None , ty d
with
Not_found ->
false, Some (key), v
in
if found then v else
let res =
try MemoHash.find memo v
with Not_found ->
let r =
let subst d =
try
ty
@@ snd
@@ List.find (fun (alpha,_) -> Var.equal d alpha) l
with Not_found -> var d
match v.def with
|`Variable d ->
let res =
(try
do_var (Map.assoc d subst)
with Not_found -> { forward () with def = v.def })
in
apply_subst ~subst:subst t
MemoHash.add memo v res; res
| x ->
let node_v = forward () in
let () = MemoHash.add memo v node_v in
let res =
match x with
| `Type _ -> x
| `Cup vl -> `Cup (List.map (fun v -> aux v) vl)
| `Cap vl -> `Cap (List.map (fun v -> aux v) vl)
| `Times (v1,v2) -> `Times (aux v1, aux v2)
| `Arrow (v1,v2) -> `Arrow (aux v1, aux v2)
| `Xml (v1,v2) -> `Xml (aux v1, aux v2)
| `Record (b, flst) ->
`Record (b, List.map (fun (b,l,v) -> (b,l,aux v)) flst)
| `Neg v -> `Neg (aux v)
| `Variable _ -> assert false
in
node_v.def <- res;
node_v
in
let () =
match update with
None -> ()
| Some key -> todo := (key, res) :: !todo
in
res
in
let res = aux t in
let res = after res in
let tres = descr (solve res) in
List.iter (fun ((d, subst) as key, res) ->
match res.node with
Some t -> begin
try
let (cu, name, subst) = DescrMap.find t !Print.named in
let _nsubst =
List.map (fun (v, vt) -> v, substitute_list vt l) subst
let (cu, name, al) =
DescrMap.find d !Print.named
in
Print.register_global (cu, name, _nsubst) r;
with Not_found -> ()
let nal = List.map (fun (v,t) -> v,apply_subst ~do_var ~after (Positive.decompose t) subst) al
in
MemoSubst.add memo_subst k r;
r
Print.register_global (cu, name, nal) d
with Not_found -> () end;
Memo.add global_memo key (descr t)
| _ -> () ) !todo;
tres
let apply t l =
if no_var t then t else
apply_subst (Positive.decompose t) (of_list l)
let substitute t s = substitute_list t [s]
let apply_single t s = apply t [s]
let substitute_free delta t =
let h = Hashtbl.create 17 in
let subst d =
if Var.Set.mem delta d then var d else
try
Hashtbl.find h d
with Not_found ->
let x = var (Var.fresh d) in
Hashtbl.add h d x ;
x
in
apply_subst ~subst:subst t
let refresh_type delta t =
if no_var t then t else
let vars = Var.Set.diff (all_vars t) delta in
let subst = Map.init (fun v -> var (Var.fresh v)) vars in
apply_subst (Positive.decompose t) subst
let substitute_kind delta kind t =
let subst d =
if Var.Set.mem delta d then var d else
var (Var.set_kind kind d)
in
apply_subst ~subst:subst t
if no_var t then t else
let vars = Var.Set.diff (all_vars t) delta in
let subst = Map.init (fun v -> var (Var.set_kind kind v)) vars in
apply_subst (Positive.decompose t) subst
(* We cannot use the variance annotation of variables to simplify them,
since variables are shared amongst types. If we have two types
A -> A and (A,A) (produced by the algorithm) then we can still simplify the
latter but the variance annotation tells us that A is invariant. *)
let collect_variables delta v =
let open Positive in
(* we memoize based on the pair (pos, v), since v can occur both
positively and negatively. and we want to manage the variables
differently in both cases. We do not need to memoize on delta as
the memoization is local and delta does not change *)
let module Memo =
Hashtbl.Make (struct
type t = bool * v
type t = bool * Positive.v
let hash = Hashtbl.hash
let equal (a,b) (c,d) = a == c && b == d
end)
......@@ -2926,8 +2947,8 @@ module Positive = struct
in
let vars = Hashtbl.create 17 in
let memo = Memo.create 17 in
let t_emp = cup [] in
let t_any = cap [] in
let t_emp = Positive.cup [] in
let t_any = Positive.ty any in
let idx = ref 0 in
let is_internal x =
let s = Var.id x in
......@@ -2936,7 +2957,7 @@ module Positive = struct
let rec aux pos v =
if not (Memo.mem memo (pos, v)) then
let () = Memo.add memo (pos,v) () in
match v.def with
match v.Positive.def with
|`Type d -> ()
|`Variable d when Var.Set.mem delta d || (not (is_internal d) && not pos) ->
Hashtbl.replace vars d v
......@@ -2959,24 +2980,26 @@ module Positive = struct
vars
let clean_type delta t =
if no_var t then t
else begin
let dec = decompose t in
let h = collect_variables delta dec in
let new_t =
map_var (fun d ->
try
Hashtbl.find h d
with Not_found -> assert false
) dec
in
descr (solve new_t)
end
let dump ppf t = pp ppf (decompose t)
if no_var t then t else
let vars = Var.Set.diff (all_vars t) delta in
if Var.Set.is_empty vars then t else
let v = Positive.decompose t in
let var_map = collect_variables delta v in
let sub : t =
Map.init (fun v -> descr (Positive.solve (Hashtbl.find var_map v))) vars
in
apply_subst v sub
let solve_fixpoint t v =
let subst = of_list [ (v, Descr.empty) ] in
let x = Positive.forward () in
let do_var _ = x in
let after t = Positive.define x t; t in
apply_subst ~after ~do_var (Positive.decompose t) subst
end
module Tallying = struct
type constr =
......@@ -3513,16 +3536,16 @@ module Tallying = struct
(* remove from E \ { (alpha,t) } every occurrences of alpha
* by mu X . (t{X/alpha}) with X fresh . X is a recursion variale *)
(* solve_rectype remove also all previously introduced fresh variables *)
let x = Positive.solve_rectype t alpha in
let x = Substitution.solve_fixpoint t alpha in
(* Format.printf "X = %a %a %a\n" Var.pp alpha Print.print x dump t; *)
let es =
CS.E.fold (fun beta s acc ->
CS.E.add beta (Positive.substitute s (alpha,x)) acc
CS.E.add beta (Substitution.apply_single s (alpha,x)) acc
) e1 CS.E.empty
in
(* Format.printf "es = %a\n" CS.print_e es; *)
let sigma = aux ((CS.E.add alpha x sol)) es in
let talpha = CS.E.fold (fun v sub acc -> Positive.substitute acc (v,sub)) sigma x in
let talpha = CS.E.fold (fun v sub acc -> Substitution.apply_single acc (v,sub)) sigma x in
CS.E.add alpha talpha sigma
end
in
......@@ -3557,7 +3580,7 @@ module Tallying = struct
(CS.ES.elements el)
(* apply sigma to t *)
let (>>) t si = CS.E.fold (fun v sub acc -> Positive.substitute acc (v,sub)) si t
let (>>) t si = CS.E.fold (fun v sub acc -> Substitution.apply_single acc (v,sub)) si t
type symsubst = I | S of CS.sigma | A of (symsubst * symsubst)
......@@ -3655,7 +3678,7 @@ let squaresubtype delta s t =
try
let ss =
if i = 0 then s
else (cap (Positive.substitute_free delta s) (get ai (i-1)))
else (cap (Substitution.refresh_type delta s) (get ai (i-1)))
in
set ai i ss;
tallying i;
......@@ -3674,8 +3697,8 @@ exception FoundApply of t * int * int * Tallying.CS.sl
let apply_raw delta s t =
Tallying.NormMemoHash.clear Tallying.memo_norm;
let s = Positive.substitute_kind delta Var.function_kind s in
let t = Positive.substitute_kind delta Var.argument_kind t in
let s = Substitution.substitute_kind delta Var.function_kind s in
let t = Substitution.substitute_kind delta Var.argument_kind t in
let vgamma = Var.mk "Gamma" in
let gamma = var vgamma in
let cgamma = cons gamma in
......@@ -3688,7 +3711,7 @@ let apply_raw delta s t =
let t = arrow (cons (get aj j)) cgamma in
let sl = Tallying.tallying delta [ (s,t) ] in
let new_res =
Positive.clean_type delta (
Substitution.clean_type delta (
List.fold_left (fun tacc si ->
cap tacc (Tallying.(gamma >> si))
) any sl
......@@ -3704,8 +3727,8 @@ let apply_raw delta s t =
(* Format.printf "Starting expansion %i @\n@." i; *)
let (ss,tt) =
if i = 0 then (s,t) else
((cap (Positive.substitute_free delta s) (get ai (i-1))),
(cap (Positive.substitute_free delta t) (get aj (i-1))))
((cap (Substitution.refresh_type delta s) (get ai (i-1))),
(cap (Substitution.refresh_type delta t) (get aj (i-1))))
in
set ai i ss;
set aj i tt;
......
......@@ -155,13 +155,16 @@ module Positive : sig
val xml: v -> v -> v
val solve: v -> Node.t
val substitute : t -> (Var.var * t) -> t
val substitute_list : t -> (Var.var * t) list -> t
val solve_rectype : t -> Var.var -> t
val substitute_free : Var.Set.t -> t -> t
val clean_type : Var.Set.t -> t -> t
end
module Substitution :
sig
val apply : t -> (Var.var * t) list -> t
val apply_single : t -> (Var.var * t) -> t
val refresh_type : Var.Set.t -> t -> t
val solve_fixpoint : t -> Var.var -> t
end
(** Normalization **)
module Product : sig
......
......@@ -548,7 +548,7 @@ module IType = struct
(Printf.sprintf "Wrong number of parameters for parametric type %s" (U.to_string id));
| Error s -> raise_loc_generic loc s
in
mk_type (Types.Positive.substitute_list t l)
mk_type ((*Types.Positive.substitute_list*) Types.Substitution.apply t l)
with Not_found ->
assert (rest == []);
if args != [] then
......@@ -618,15 +618,20 @@ module IType = struct
raise_loc_generic loc
(Printf.sprintf "Definition of type %s contains unbound type variables"
(Ident.to_string v));
let nargs = List.map (fun x ->
let v = (Var.mk (U.to_string x)) in v, Types.var v
) args in
(*
Not needed ?
let vars_mapping = (* create a sequence 'a -> 'a_0 for all variables *)
List.map (fun v -> let vv = Var.mk (U.to_string v) in vv, Var.fresh vv) args
in
let sub_list = List.map (fun (v,vt) -> v, Types.var vt) vars_mapping in
let t_rhs =
Types.Positive.substitute_list t_rhs sub_list
(*Types.Positive.substitute_list t_rhs*)Types.Substitution.apply t_rhs sub_list
in
let nargs = List.map2 (fun (_, v) (_, vt) -> v, vt) vars_mapping sub_list
in
in *)
(v,t_rhs,nargs)
) (List.rev b)
in
......@@ -662,7 +667,7 @@ module IType = struct
current_params := (idx,params,map);
type_defs env b
) env b in
clean_params (); r
clean_params ();r
with exn -> clean_on_err (); raise exn
let typ env t =
......@@ -1140,8 +1145,8 @@ and type_check' loc env ed constr precise = match ed with
(fun v ->
let open Types in
match v with
| Val t -> Val (Positive.substitute_free env.delta t)
| EVal (a,b,t) -> EVal (a,b,Positive.substitute_free env.delta t)
| Val t -> Val (Substitution.refresh_type env.delta t)
| EVal (a,b,t) -> EVal (a,b,Substitution.refresh_type env.delta t)
| x -> x)
env.ids }
in
......@@ -1217,7 +1222,7 @@ and type_check' loc env ed constr precise = match ed with
| Apply (e1,e2) ->
let t1 = type_check env e1 Types.Arrow.any true in
let t1arrow = Types.Arrow.get t1 in
let t1 = Types.Positive.substitute_free env.delta t1 in
let t1 = Types.Substitution.refresh_type env.delta t1 in
(* t [_delta 0 -> 1 *)
begin try
ignore(Types.Tallying.tallying env.delta [(t1,Types.Arrow.any)])
......@@ -1227,7 +1232,7 @@ and type_check' loc env ed constr precise = match ed with
let dom = Types.Arrow.domain(t1arrow) in
let t2 = type_check env e2 Types.any true in
let t2 = Types.Positive.substitute_free env.delta t2 in
let t2 = Types.Substitution.refresh_type env.delta t2 in
let (sl,res) =
if not (Types.no_var dom) ||
not (Types.no_var t2) then
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment