Commit 521779d3 authored by Pietro Abate's avatar Pietro Abate
Browse files

Better preter for types

parent e3fd6338
......@@ -21,7 +21,7 @@ module ESet = OUnitDiff.SetMake (struct
if (v1,t1) == (v2,t2) then 0
else let c = Var.compare v1 v2 in if c <> 0 then c
else Types.compare (diff t1 a) (diff t2 a)
let pp_printer ppf (v,t) = Format.fprintf ppf "(%a = %s)" Var.dump v (to_string Print.print t)
let pp_printer ppf (v,t) = Format.fprintf ppf "(%a = %s)" Var.print v (to_string Print.print t)
let pp_print_sep = OUnitDiff.pp_comma_separator
end)
......
......@@ -139,16 +139,29 @@ module BoolChars : BoolVar.S with
module TLV = struct
module Set = Set.Make(
struct
type t = (Var.var * bool)
let compare (v1,p1) (v2,p2) =
let c = Var.compare v1 v2 in
if c == 0 then
if p1 == p2 then 0
else if p1 then 1 else -1
else c
end)
module Set = struct
include Set.Make(
struct
type t = (Var.var * bool)
let compare (v1,p1) (v2,p2) =
let c = Var.compare v1 v2 in
if c == 0 then
if p1 == p2 then 0
else if p1 then 1 else -1
else c
end)
let print sep ppf s =
let aux1 ppf = function
|(v,true) -> Format.fprintf ppf "%a" Var.print v
|(v,false) -> Format.fprintf ppf "~ %a" Var.print v
in
let rec aux ppf = function
|[] -> ()
|[h] -> aux1 ppf h
|h::l -> Format.fprintf ppf "%a %s %a" aux1 h sep aux l
in
aux ppf (elements s)
end
(* s : top level variables
* f : all free variables in the subtree
......@@ -189,8 +202,14 @@ module TLV = struct
(* true if it contains only one variable *)
let is_single x = x.b && (Var.Set.cardinal x.f = 1) && (Set.cardinal x.s = 1)
let no_variables x = (x.b == false) && (Var.Set.cardinal x.f = 0) && (Set.cardinal x.s = 0)
let no_toplevel x = (Set.cardinal x.s = 0)
let print sep ppf x = if x.b then Set.print sep ppf x.s
let mem v x = Set.mem v x.s
end
module rec Descr :
......@@ -478,7 +497,7 @@ let cap x y =
chars = BoolChars.cap x.chars y.chars;
abstract = Abstract.cap x.abstract y.abstract;
absent= x.absent && y.absent;
toplvars = TLV.inter x.toplvars y.toplvars
toplvars = TLV.union x.toplvars y.toplvars
}
let diff x y =
......@@ -1098,7 +1117,7 @@ struct
(t & t1, s - s1) | ... | (t & tn, s - sn) | (t - (t1|...|tn), s)
*)
let get_aux any_right d =
let partition any_right d =
let accu = ref [] in
let line (left,right) =
let (d1,d2) = cap_product any any_right left in
......@@ -1117,7 +1136,7 @@ struct
) right in
if non_empty !resid1 then accu := (!resid1, d2) :: !accu
in
List.iter line (Pair.get (BoolPair.leafconj d));
List.iter line (Pair.get d);
!accu
(* Maybe, can improve this function with:
(t,s) \ (t1,s1) = (t&t',s\s') | (t\t',s),
......@@ -1125,13 +1144,8 @@ struct
let get ?(kind=`Normal) d =
match kind with
| `Normal -> get_aux any d.times
| `XML -> get_aux any_pair d.xml
let getpair ?(kind=`Normal) d =
match kind with
| `Normal -> d.times
| `XML -> d.xml
| `Normal -> partition any (BoolPair.leafconj d.times)
| `XML -> partition any_pair (BoolPair.leafconj d.xml)
let pi1 = List.fold_left (fun acc (t1,_) -> cup acc t1) empty
let pi2 = List.fold_left (fun acc (_,t2) -> cup acc t2) empty
......@@ -1147,7 +1161,7 @@ struct
type normal = t
module Memo = Map.Make(BoolPair)
module Memo = Map.Make(Pair)
(* TODO: try with an hashtable *)
(* Also, avoid lookup for simple products (t1,t2) *)
......@@ -1156,7 +1170,7 @@ struct
try Memo.find d !memo
with
Not_found ->
let gd = get_aux any d in
let gd = partition any d in
let n = normal_aux gd in
(* Could optimize this call to normal_aux because one already
know that each line is normalized ... *)
......@@ -1168,15 +1182,15 @@ struct
try Memo.find d !memo_xml
with
Not_found ->
let gd = get_aux any_pair d in
let gd = partition any_pair d in
let n = normal_aux gd in
memo_xml := Memo.add d n !memo_xml;
n
let normal ?(kind=`Normal) d =
match kind with
| `Normal -> normal_times d.times
| `XML -> normal_xml d.xml
| `Normal -> normal_times (BoolPair.leafconj d.times)
| `XML -> normal_xml (BoolPair.leafconj d.xml)
(*
......@@ -1469,6 +1483,7 @@ struct
| Xml of [ `Tag of (Format.formatter -> unit) | `Type of nd ] * nd * nd
| Record of (bool * nd) label_map * bool * bool
| Arrows of (nd * nd) list * (nd * nd) list
| Intersection of nd
| Neg of nd
| Abs of nd
let compare x y = x.id - y.id
......@@ -1555,72 +1570,120 @@ struct
let d = lookup d in
try DescrHash.find memo d
with Not_found ->
try
let n = DescrMap.find d !named in
let s = alloc [] in
s.state <- `GlobalName n;
DescrHash.add memo d s;
s
with Not_found ->
if d.absent then alloc [Abs (prepare ({d with absent=false}))]
else if worth_complement d
then alloc [Neg (prepare (neg d))]
else let slot = alloc [] in
if not (worth_abbrev d) then slot.state <- `Expand;
DescrHash.add memo d slot;
try begin
let n = DescrMap.find d !named in
let s = alloc [] in
s.state <- `GlobalName n;
DescrHash.add memo d s;
s
end with Not_found ->
if d.absent then
alloc [Abs (prepare ({d with absent=false}))]
else if worth_complement d then
alloc [Neg (prepare (neg d))]
else
let slot = alloc [] in
if not (worth_abbrev d) then slot.state <- `Expand;
DescrHash.add memo d slot;
let (seq,not_seq) =
if (subtype { empty with times = d.times } seqs_descr) then
(cap d seqs_descr, diff d seqs_descr)
else
(empty, d) in
(empty, d)
in
let add u = slot.def <- u :: slot.def in
let prepare_boolvar ?(t=false) get print tlv bdd =
List.iter (fun (p,n) ->
let l1 =
List.fold_left (fun acc -> function
|(`Var v) as x ->
begin match (t, (TLV.mem (x,true) tlv)) with
|(true,true)
|(_,false) -> (Atomic (fun ppf -> Var.print ppf x))::acc
|(false,true) -> acc end
|`Atm bdd -> (print bdd) @ acc
) [] p
in
let l2 =
List.fold_left (fun acc -> function
|(`Var v) as x ->
begin match (t, (TLV.mem (x,false) tlv)) with
|(true,true)
|(_,false) -> (Atomic (fun ppf -> Format.fprintf ppf "~ %a" Var.print x))::acc
|(false,true) -> acc end
|`Atm bdd -> assert false
) [] n
in
match (l1@l2) with
|[] -> ()
|l -> add (Intersection (alloc (List.rev l)))
) (get bdd)
in
if (non_empty seq) then add (Regexp (decompile seq));
(* base types *)
prepare_boolvar ~t:true BoolIntervals.get (fun x ->
List.map (fun x -> (Atomic x)) (Intervals.print x)
) not_seq.toplvars not_seq.ints;
prepare_boolvar BoolChars.get (fun x ->
match Chars.is_char x with
| Some c -> [(Char c)]
| None -> List.map (fun x -> (Atomic x)) (Chars.print x)
) not_seq.toplvars not_seq.chars;
prepare_boolvar BoolAtoms.get (fun x ->
List.map (fun x -> (Atomic x)) (Atoms.print x)
) not_seq.toplvars not_seq.atoms;
(* pairs *)
List.iter (fun (t1,t2) ->
add (Pair (prepare t1, prepare t2))
) (Product.get not_seq);
prepare_boolvar BoolPair.get (fun x ->
List.map (fun (t1,t2) ->
(Pair (prepare t1, prepare t2))
) (Product.partition any x)) not_seq.toplvars not_seq.times;
(* xml pairs *)
List.iter (fun (t1,t2) ->
try
let n = DescrPairMap.find (t1,t2) !named_xml in
add (Name n)
with Not_found ->
let tag =
match Atoms.print_tag (BoolAtoms.leafconj t1.atoms) with
| Some a when is_empty { t1 with atoms = BoolAtoms.empty } -> `Tag a
| _ -> `Type (prepare t1) in
assert (equal { t2 with times = empty.times } empty);
List.iter
(fun (ta,tb) ->
add (Xml (tag, prepare ta, prepare tb)))
(Product.get t2);
) (Product.get ~kind:`XML not_seq);
prepare_boolvar BoolPair.get (fun x ->
List.flatten (
List.map (fun (t1,t2) ->
try let n = DescrPairMap.find (t1,t2) !named_xml in [(Name n)]
with Not_found ->
let tag =
match Atoms.print_tag (BoolAtoms.leafconj t1.atoms) with
| Some a when is_empty { t1 with atoms = BoolAtoms.empty } -> `Tag a
| _ -> `Type (prepare t1)
in
assert (equal { t2 with times = empty.times } empty);
List.map (fun (ta,tb) ->
(Xml (tag, prepare ta, prepare tb))
) (Product.get t2);
) (Product.partition any_pair x)
)) not_seq.toplvars not_seq.xml;
(* arrows *)
prepare_boolvar BoolPair.get (fun x ->
List.map (fun (p,n) ->
let aux (t,s) = prepare (descr t), prepare (descr s) in
let p = List.map aux p and n = List.map aux n in
(Arrows (p,n))
) (Pair.get x)) not_seq.toplvars not_seq.arrow;
(* records *)
List.iter (fun (r,some,none) ->
let r = LabelMap.map (fun (o,t) -> (o, prepare t)) r in
add (Record (r,some,none))
) (Record.get not_seq);
(match Chars.is_char (BoolChars.leafconj not_seq.chars) with
| Some c -> add (Char c)
| None ->
List.iter (fun x -> add (Atomic x)) (BoolChars.print not_seq.chars));
List.iter (fun x -> add (Atomic x)) (BoolIntervals.print not_seq.ints);
List.iter (fun x -> add (Atomic x)) (BoolAtoms.print not_seq.atoms);
(*
prepare_boolvar BoolRec.get (fun x ->
List.iter (fun (r,some,none) ->
let r = LabelMap.map (fun (o,t) -> (o, prepare t)) r in
add (Record (r,some,none))
) (Record.get x)) not_seq.toplvars not_seq;
*)
List.iter (fun x -> add (Atomic x)) (Abstract.print not_seq.abstract);
(* arrows *)
List.iter (fun (p,n) ->
let aux (t,s) = prepare (descr t), prepare (descr s) in
let p = List.map aux p and n = List.map aux n in
add (Arrows (p,n))
) (Pair.get (BoolPair.leafconj not_seq.arrow));
if not_seq.absent then add (Atomic (fun ppf -> Format.fprintf ppf "#ABSENT"));
slot.def <- List.rev slot.def;
slot
......@@ -1672,6 +1735,7 @@ struct
| Name _ | Char _ | Atomic _ -> ()
| Regexp r -> assign_name_regexp r
| Pair (t1,t2) -> assign_name t1; assign_name t2
| Intersection t -> assign_name t
| Xml (tag,t2,t3) ->
(match tag with `Type t -> assign_name t | _ -> ());
assign_name t2;
......@@ -1691,16 +1755,16 @@ struct
let print_gname ppf (cu,n) =
Format.fprintf ppf "%s%a" cu Ns.QName.print n
let rec do_print_slot pri ppf s =
let rec do_print_slot ?(sep="|") pri ppf s =
match s.state with
| `Named n -> U.print ppf n
| `GlobalName n -> print_gname ppf n
| _ -> do_print_slot_real pri ppf s.def
and do_print_slot_real pri ppf def =
| _ -> do_print_slot_real ~sep pri ppf s.def
and do_print_slot_real ?(sep="|") pri ppf def =
let rec aux ppf = function
| [] -> Format.fprintf ppf "Empty"
| [ h ] -> (do_print pri) ppf h
| h :: t -> Format.fprintf ppf "%a |@ %a" (do_print pri) h aux t
| h :: t -> Format.fprintf ppf "%a %s@ %a" (do_print pri) h sep aux t
in
if (pri >= 2) && (List.length def >= 2)
then Format.fprintf ppf "@[(%a)@]" aux def
......@@ -1714,6 +1778,7 @@ struct
| Char c -> Chars.V.print ppf c
| Regexp r -> Format.fprintf ppf "@[[ %a ]@]" (do_print_regexp 0) r
| Atomic a -> a ppf
| Intersection a -> Format.fprintf ppf "@[%a@]" (do_print_slot ~sep:"&" 0) a
| Pair (t1,t2) ->
Format.fprintf ppf "@[(%a,%a)@]"
(do_print_slot 0) t1
......@@ -2133,6 +2198,7 @@ struct
decompose_aux atom (BoolAtoms.get t.atoms);
decompose_aux interval (BoolIntervals.get t.ints);
decompose_aux char (BoolChars.get t.chars);
(* XXX XXX record is not threated here yet !!! *)
decompose_aux ~noderec:(subpairs arrow)
(fun p -> { empty with arrow = BoolPair.atom (`Atm p) }) (BoolPair.get t.arrow);
decompose_aux ~noderec:(subpairs xml)
......@@ -2211,13 +2277,7 @@ struct
end
in
let x = forward () in
Printf.printf "before decompose \n%!";
let s = decompose t in
Printf.printf "after decompose \n%!";
let mux = (substitute_aux s (subst h)) in
Printf.printf "after substitute \n%!";
define x mux;
Printf.printf "after define \n%!";
define x (substitute_aux (decompose t) (subst h));
descr(solve x)
end
......@@ -2296,8 +2356,6 @@ struct
in
aux (Pair.get (BoolPair.leafconj s.arrow))
let getpair d = d.arrow
type t = descr * (descr * descr) list list
let get t =
......@@ -2319,17 +2377,17 @@ struct
let rec aux result accu1 accu2 = function
| (t1,s1)::left ->
let result =
let accu1 = diff accu1 t1 in
let accu1 = diff accu1 t1 in
if non_empty accu1 then aux result accu1 accu2 left
else result in
else result
in
let result =
let accu2 = cap accu2 s1 in
aux result accu1 accu2 left in
result
let accu2 = cap accu2 s1 in
aux result accu1 accu2 left
in
result
| [] ->
if subtype accu2 result
then result
else cup result accu2
if subtype accu2 result then result else cup result accu2
in
aux result t any left
......@@ -2489,6 +2547,7 @@ module Tallying = struct
type m = Descr.s M.t
type e = Descr.s E.t
type es = ES.t
type sl = (Var.var * t) list list
let singleton = function
|Pos (v,s) -> S.singleton (M.singleton (true,v) s)
......@@ -2871,7 +2930,7 @@ exception KeepGoing
let apply t1 t2 =
DescrHash.clear Tallying.memo_norm;
let q = Queue.create () in
let gamma = var (Var.fresh ~pre:"gamma" ~variance:`ContraVariant ()) in
let gamma = var (Var.mk ~variance:`ContraVariant "gamma") in
let rec aux (i,acc1) (j,acc2) t1 t2 () =
let acc1 = Lazy.force acc1 and acc2 = Lazy.force acc2 in
try Tallying.tallying [(acc1,arrow (cons acc2) (cons gamma))]
......@@ -2882,7 +2941,6 @@ let apply t1 t2 =
Queue.add (aux (i,lazy(acc1)) (j+1,lazy(cap acc2 (Positive.substitutefree t2))) t1 t2) q;
raise KeepGoing
end
in
Queue.add (aux (0,lazy(t1)) (1,lazy(Positive.substitutefree t2)) t1 t2) q;
Queue.add (aux (1,lazy(Positive.substitutefree t1)) (0,lazy(t2)) t1 t2) q;
......
......@@ -181,7 +181,6 @@ module Product : sig
type t = (descr * descr) list
val is_empty: t -> bool
val get: ?kind:pair_kind -> descr -> t
val getpair: ?kind:pair_kind -> descr -> BoolPair.t
val pi1: t -> descr
val pi2: t -> descr
val pi2_restricted: descr -> t -> descr
......@@ -270,8 +269,6 @@ module Arrow : sig
val get: descr -> t
(* Always succeed; no check <= Arrow.any *)
val getpair: descr -> BoolPair.t
val domain: t -> descr
val apply: t -> descr -> descr
(* Always succeed; no check on the domain *)
......@@ -390,6 +387,7 @@ module Tallying : sig
type m = t M.t
type e = t E.t
type es = ES.t
type sl = (Var.var * t) list list
val print : Format.formatter -> s -> unit
val print_m : Format.formatter -> m -> unit
......@@ -407,8 +405,8 @@ module Tallying : sig
val merge : CS.m -> CS.s
val solve : CS.s -> CS.es
val unify : CS.e -> CS.e
val tallying : (t * t) list -> (Var.var * t) list list
val tallying : (t * t) list -> CS.sl
end
val apply : t -> t -> (Var.var * t) list list
val apply : t -> t -> Tallying.CS.sl
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment