Commit 7932c071 authored by Kim Nguyễn's avatar Kim Nguyễn
Browse files

Improve the complexity of constraint normalisation by using a hash

table instead of a 'recursion set' like in the paper.  The hash table
is tricky:

 - its keys should be both the type we are normalizing and
delta

 - we should store a flag together with the constraint set, indicating
whether the computation of the normalization for the corresponding
type has finished. If it has, we can use the associated constraint set
instead of CS.sat and stop recursion. If it has not, we can return
CS.sat (like the previous base case).

We also update the test case files to check that everything is in order:

- part2.cd has been rewritten to make use of the new syntax and remove
  the red-black trees examples that are now in a separate file

- red-black.cd is a fully typechecking file
- rb-fail.cd has the type definition and the wrong balance function.
parent 42dda81a
...@@ -26,23 +26,6 @@ let max (x : 'a) (y : 'a) : 'a = if x >> y then x else y;; ...@@ -26,23 +26,6 @@ let max (x : 'a) (y : 'a) : 'a = if x >> y then x else y;;
max 42;; max 42;;
type RBtree = Btree | Rtree
type Btree = <blk elem='a>[ RBtree RBtree ] | []
type Rtree = <red elem='a>[ Btree Btree ]
type Unbal = <blk elem='a>( [ Wrong RBtree ]
| [ RBtree Wrong ])
type Wrong = <red elem='a>( [ Rtree Btree ]
| [ Btree Rtree ])
let balance ( Unbal ->Rtree ; ('b \ Unbal) ->('b \ Unbal) )
| <blk (z)>[ <red (y)>[ <red (x)>[ a b ] c ] d ]
| <blk (z)>[ <red (x)>[ a <red (y)>[ b c ] ] d ]
| <blk (x)>[ a <red (z)>[ <red (y)>[ b c ] d ] ]
| <blk (x)>[ a <red (y)>[ b <red (z)>[ c d ] ] ]
-> <red (y)>[ <blk (x)>[ a b ] <blk (z)>[ c d ] ]
| x -> x
;;
let r = balance <blk elem=1>[ <red elem=1>[ <red elem=1>[ 1 2] 3 ]4];;
(* some tricky examples *) (* some tricky examples *)
...@@ -54,27 +37,19 @@ let x = id even (even mmap) even;; (* same type as map_even *) ...@@ -54,27 +37,19 @@ let x = id even (even mmap) even;; (* same type as map_even *)
let twisted = id even (even mmap) even (mmap max [1 2 3 4 5 6]);; let twisted = id even (even mmap) even (mmap max [1 2 3 4 5 6]);;
let apply_to_3 (f: Int -> 'a): 'a = f 3 in let apply_to_3 (f: Int -> 'a): 'a = f 3 in
mmap apply_to_3 twisted mmap apply_to_3 twisted
;; ;;
type A = <a>'a type A('a) = <a>'a
type B = <b>[(A|B)];; type B('a) = <b>[(A('a)|B('a))];;
let f (_ : 'a -> 'a -> 'a)(z : 'a)(_ : A|B) : A = <a>z;;
let f (_ : 'a -> 'a -> 'a)(z : 'a)(_ : A('a)|B('a)) : A('a) = <a>z;;
let sum (x : Int) (y : Int) : Int = x + y;;
let x = f sum;; let x = f sum;;
(* Some expressions that are ill typed *) (* Some expressions that are ill typed *)
let balance (Unbal ->Rtree ; 'a -> 'a )
| <blk (z)>[ <red (y)>[ <red (x)>[ a b ] c ] d ]
| <blk (z)>[ <red (x)>[ a <red (y)>[ b c ] ] d ]
| <blk (x)>[ a <red (z)>[ <red (y)>[ b c ] d ] ]
| <blk (x)>[ a <red (y)>[ b <red (z)>[ c d ] ] ]
-> <red (y)>[ <blk (x)>[ a b ] <blk (z)>[ c d ] ]
| x -> x
;;
let id ('a -> 'a) let id ('a -> 'a)
Int -> "foo" Int -> "foo"
| x -> x | x -> x
......
type RBtree('a) = Btree('a) | Rtree('a)
(* Black rooted RB tree: *)
type Btree('a) = [] | <black elem='a>[ RBtree('a) RBtree('a) ]
(* Red rooted RB tree: *)
type Rtree('a) = <red elem='a>[ Btree('a) Btree('a) ]
type Wrongtree('a) = <red elem='a>( [ Rtree('a) Btree('a) ]
| [ Btree('a) Rtree('a) ])
type Unbalanced('a) = <black elem='a>( [ Wrongtree('a) RBtree('a) ]
| [ RBtree('a) Wrongtree('a) ])
;;
let balance ( Unbalanced('a) -> Rtree('a) ; 'b\Unbalanced('a) -> 'b\Unbalanced('a) )
| <black (z)>[ <red (y)>[ <red (x)>[ a b ] c ] d ]
| <black (z)>[ <red (x)>[ a <red (y)>[ b c ] ] d ]
| <black (x)>[ a <red (z)>[ <red (y)>[ b c ] d ] ]
| <black (x)>[ a <red (y)>[ b <red (z)>[ c d ] ] ] ->
<red (y)>[ <black (x)>[ a b ] <black (z)>[ c d ] ]
| x -> x
;;
...@@ -14,7 +14,10 @@ type Unbalanced('a) = <black elem='a>( [ Wrongtree('a) RBtree('a) ] ...@@ -14,7 +14,10 @@ type Unbalanced('a) = <black elem='a>( [ Wrongtree('a) RBtree('a) ]
| [ RBtree('a) Wrongtree('a) ]) | [ RBtree('a) Wrongtree('a) ])
;; ;;
(* does not type *) (***************
ill typed, see rb-fail.cd
let balance ( Unbalanced('a) -> Rtree('a) ; 'b\Unbalanced('a) -> 'b\Unbalanced('a) ) let balance ( Unbalanced('a) -> Rtree('a) ; 'b\Unbalanced('a) -> 'b\Unbalanced('a) )
| <black (z)>[ <red (y)>[ <red (x)>[ a b ] c ] d ] | <black (z)>[ <red (y)>[ <red (x)>[ a b ] c ] d ]
| <black (z)>[ <red (x)>[ a <red (y)>[ b c ] ] d ] | <black (z)>[ <red (x)>[ a <red (y)>[ b c ] ] d ]
...@@ -23,6 +26,8 @@ let balance ( Unbalanced('a) -> Rtree('a) ; 'b\Unbalanced('a) -> 'b\Unbalanced(' ...@@ -23,6 +26,8 @@ let balance ( Unbalanced('a) -> Rtree('a) ; 'b\Unbalanced('a) -> 'b\Unbalanced('
<red (y)>[ <black (x)>[ a b ] <black (z)>[ c d ] ] <red (y)>[ <black (x)>[ a b ] <black (z)>[ c d ] ]
| x -> x | x -> x
;; ;;
************)
let balance ( Unbalanced('a) -> Rtree('a) ; 'b & RBtree('a) -> 'b & RBtree('a) ) let balance ( Unbalanced('a) -> Rtree('a) ; 'b & RBtree('a) -> 'b & RBtree('a) )
| <black (z)>[ <red (y)>[ <red (x)>[ a b ] c ] d ] | <black (z)>[ <red (y)>[ <red (x)>[ a b ] c ] d ]
......
...@@ -3230,44 +3230,58 @@ module Tallying = struct ...@@ -3230,44 +3230,58 @@ module Tallying = struct
(* norm generates a constraint set for the costraint t <= 0 *) (* norm generates a constraint set for the costraint t <= 0 *)
module NormMemoHash = Hashtbl.Make(Custom.Pair(Descr)(Var.Set))
let rec norm (t,delta,mem) = let rec norm (t,delta,mem) =
(* if we already evaluated it, it is sat *) if is_empty t then CS.sat
let res = else
if DescrSet.mem t mem || is_empty t then CS.sat try
else if is_var t then begin let finished, cst = NormMemoHash.find mem (t, delta) in
(* if there is only one variable then is it A <= 0 or 1 <= A *) if finished then cst else CS.sat
let (v,p) = extract_variable t in with
if Var.Set.mem v delta then CS.unsat Not_found ->
else begin
let s = if p then (Pos (v,empty)) else (Neg (any,v)) in let res =
CS.singleton s if is_var t then
(* if there are no vars, and it is not empty then unsat *) begin
end else if no_var t then CS.unsat (* if there is only one variable then is it A <= 0 or 1 <= A *)
else begin let (v,p) = extract_variable t in
let mem = DescrSet.add t mem in if Var.Set.mem v delta then CS.unsat
let aux single norm_aux acc l = else
big_prod delta (toplevel delta single norm_aux mem) acc l let s = if p then (Pos (v,empty)) else (Neg (any,v)) in
in CS.singleton s
let acc = aux single_atoms normatoms CS.sat (BoolAtoms.get t.atoms) in (* if there are no vars, and it is not empty then unsat *)
let acc = aux single_chars normchars acc (BoolChars.get t.chars) in end
let acc = aux single_ints normints acc (BoolIntervals.get t.ints) in else if no_var t then CS.unsat
let acc = aux single_times normpair acc (BoolPair.get t.times) in else begin
let acc = aux single_xml normpair acc (BoolPair.get t.xml) in let mem = NormMemoHash.add mem (t,delta) (false, CS.sat); mem in
let acc = aux single_arrow normarrow acc (BoolPair.get t.arrow) in let aux single norm_aux acc l =
let acc = aux single_abstract normabstract acc (BoolAbstracts.get t.abstract) in big_prod delta (toplevel delta single norm_aux mem) acc l
(* XXX normrec is not tested at all !!! *) in
let res = aux single_record normrec acc (BoolRec.get t.record) in let acc = aux single_atoms normatoms CS.sat (BoolAtoms.get t.atoms) in
let res = (* Simplify the constraints on that type *) let acc = aux single_chars normchars acc (BoolChars.get t.chars) in
CS.S.filter let acc = aux single_ints normints acc (BoolIntervals.get t.ints) in
(fun m -> CS.M.for_all (fun v (s, t) -> not (Var.Set.mem v delta) || let acc = aux single_times normpair acc (BoolPair.get t.times) in
let x = var v in subtype s x && subtype x t let acc = aux single_xml normpair acc (BoolPair.get t.xml) in
) m) let acc = aux single_arrow normarrow acc (BoolPair.get t.arrow) in
res let acc = aux single_abstract normabstract acc (BoolAbstracts.get t.abstract) in
in (* XXX normrec is not tested at all !!! *)
res let acc = aux single_record normrec acc (BoolRec.get t.record) in
end let acc = (* Simplify the constraints on that type *)
in CS.S.filter
(* Format.printf "Normalizing %a yields %a\n%!" Print.pp_type t CS.pp_s res; *) res (fun m -> CS.M.for_all (fun v (s, t) -> not (Var.Set.mem v delta) ||
let x = var v in subtype s x && subtype x t
) m)
acc
in
acc
end
in
NormMemoHash.replace mem (t, delta) (true,res); res
end
(* Format.printf "Normalizing %a yields %a\n%!" Print.pp_type t CS.pp_s res; *)
(* (t1,t2) = intersection of all (fst pos,snd pos) \in P (* (t1,t2) = intersection of all (fst pos,snd pos) \in P
* (s1,s2) \in N * (s1,s2) \in N
...@@ -3362,15 +3376,13 @@ module Tallying = struct ...@@ -3362,15 +3376,13 @@ module Tallying = struct
in in
big_prod delta norm_arrow CS.sat (Pair.get t) big_prod delta norm_arrow CS.sat (Pair.get t)
module NormMemoHash = Hashtbl.Make(Custom.Pair(Descr)(Var.Set))
let memo_norm = NormMemoHash.create 17 let memo_norm = NormMemoHash.create 17
let norm delta t = let norm delta t =
try NormMemoHash.find memo_norm (t,delta) try NormMemoHash.find memo_norm (t,delta)
with Not_found -> begin with Not_found -> begin
let res = norm (t,delta,DescrSet.empty) in let res = norm (t,delta,NormMemoHash.create 17) in
NormMemoHash.add memo_norm (t,delta) res; res NormMemoHash.add memo_norm (t,delta) res; res
end end
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment