(* bool.ml *)
(** Signature shared by the implementations of boolean combinations of
    atoms found in this file. *)
module type S =
sig
  (** Type of the atoms. *)
  type elem

  include Custom.T

  (** Disjunctive view of a combination: a list of pairs
      [(positive atoms, negative atoms)], one pair per conjunction. *)
  val get: t -> (elem list * elem list) list

  (** The two constant combinations. *)
  val empty : t
  val full  : t

  (** Union, intersection and difference of two combinations. *)
  val cup   : t -> t -> t
  val cap   : t -> t -> t
  val diff  : t -> t -> t

  (** Combination made of a single positive atom. *)
  val atom  : elem -> t

  (** Applies a function to every atom stored in the combination. *)
  val iter: (elem-> unit) -> t -> unit

  (** Interprets a combination in another algebra described by the labelled
      arguments. *)
  val compute: empty:'b -> full:'b -> cup:('b -> 'b -> 'b)
    -> cap:('b -> 'b -> 'b) -> diff:('b -> 'b -> 'b) ->
    atom:(elem -> 'b) -> t -> 'b

  (** [print any pp_elem t] returns a list of printers; the string [any] is
      what gets printed for the full combination. *)
  val print: string -> (Format.formatter -> elem -> unit) -> t ->
    (Format.formatter -> unit) list

  (** Cheap disjointness test; the implementations in this file do not try
      to decide every case. *)
  val trivially_disjoint: t -> t -> bool
end
(** Implementation of boolean combinations as ternary decision trees.

    A node [Split (h, x, p, i, n)] carries a cached hash [h], an atom [x] and
    three subtrees.  As spelled out by [compute], it denotes
    [(x cap p) cup i cup (n minus x)]: [p] is used when [x] holds, [n] when
    it does not, and [i] in both cases. *)
module MakeOld(X : Custom.T) =
struct
  type elem = X.t
  type t =
    | True
    | False
    | Split of int * elem * t * t * t

  (* Structural equality, short-circuited by physical equality and by the
     cached hash values. *)
  let rec equal a b =
    (a == b) ||
    match (a,b) with
      | Split (h1,x1, p1,i1,n1), Split (h2,x2, p2,i2,n2) ->
          (h1 == h2) &&
          (equal p1 p2) && (equal i1 i2) &&
          (equal n1 n2) && (X.equal x1 x2)
      | _ -> false

(* Idea: add a mutable "unique" identifier and set it to
   the minimum of the two when equality ... *)

  (* Total order: True, then False, then Split nodes; Split nodes are ordered
     by cached hash, then atom, then the three subtrees. *)
  let rec compare a b =
    if (a == b) then 0
    else match (a,b) with
      | Split (h1,x1, p1,i1,n1), Split (h2,x2, p2,i2,n2) ->
          if h1 < h2 then -1 else if h1 > h2 then 1
          else let c = X.compare x1 x2 in if c <> 0 then c
          else let c = compare p1 p2 in if c <> 0 then c
          else let c = compare i1 i2 in if c <> 0 then c
          else compare n1 n2
      | True,_  -> -1
      | _, True -> 1
      | False,_ -> -1
      | _,False -> 1

  (* Constant-time hash: reads the value cached in the node. *)
  let hash = function
    | True -> 1
    | False -> 0
    | Split(h, _,_,_,_) -> h

  (* Hash of a node made of atom [x] and subtrees [p], [i], [n]. *)
  let compute_hash x p i n =
    (X.hash x) + 17 * (hash p) + 257 * (hash i) + 16637 * (hash n)

  (* Debugging aid: asserts that cached hashes are correct and that the atom
     of a node is strictly smaller than the atoms of its direct children. *)
  let rec check = function
    | True | False -> ()
    | Split (h,x,p,i,n) ->
        assert (h = compute_hash x p i n);
        (match p with Split (_,y,_,_,_) -> assert (X.compare x y < 0) | _ -> ());
        (match i with Split (_,y,_,_,_) -> assert (X.compare x y < 0) | _ -> ());
        (match n with Split (_,y,_,_,_) -> assert (X.compare x y < 0) | _ -> ());
        X.check x; check p; check i; check n

  let atom x =
    let h = X.hash x + 17 in (* partial evaluation of compute_hash... *)
    Split (h, x,True,False,False)

  let neg_atom x =
    let h = X.hash x + 16637 in (* partial evaluation of compute_hash... *)
    Split (h, x,False,False,True)

  (* Applies [f] to every atom of the tree: node first, then the three
     subtrees in the order p, i, n. *)
  let rec iter f = function
    | Split (_, x, p,i,n) -> f x; iter f p; iter f i; iter f n
    | _ -> ()

  (* Raw dump of the tree structure. *)
  let rec dump ppf = function
    | True -> Format.fprintf ppf "+"
    | False -> Format.fprintf ppf "-"
    | Split (_,x, p,i,n) ->
        Format.fprintf ppf "%a(@[%a,%a,%a@])"
        X.dump x dump p dump i dump n

  (* Prints the tree as a disjunction of its non-empty branches; [b] emits
     the separator before every branch but the first one. *)
  let rec print f ppf = function
    | True -> Format.fprintf ppf "Any"
    | False -> Format.fprintf ppf "Empty"
    | Split (_, x, p,i, n) ->
        let flag = ref false in
        let b () = if !flag then Format.fprintf ppf " | " else flag := true in
        (match p with
           | True -> b(); Format.fprintf ppf "%a" f x
           | False -> ()
           | _ -> b (); Format.fprintf ppf "%a & @[(%a)@]" f x (print f) p );
        (match i with
           | True -> assert false
           | False -> ()
           | _ -> b(); print f ppf i);
        (match n with
           | True -> b (); Format.fprintf ppf "@[~%a@]" f x
           | False -> ()
           | _ -> b (); Format.fprintf ppf "@[~%a@] & @[(%a)@]" f x (print f) n)

  (* Public printer: [a] is printed for True, nothing is produced for False,
     and any other tree yields a single printer. *)
  let print a f = function
    | True -> [ fun ppf -> Format.fprintf ppf "%s" a ]
    | False -> []
    | c -> [ fun ppf -> print f ppf c ]

  (* Collects one (positive atoms, negative atoms) pair per path that ends
     on True. *)
  let rec get accu pos neg = function
    | True -> (pos,neg) :: accu
    | False -> accu
    | Split (_,x, p,i,n) ->
        (*OPT: can avoid creating this list cell when pos or neg =False *)
        let accu = get accu (x::pos) neg p in
        let accu = get accu pos (x::neg) n in
        let accu = get accu pos neg i in
        accu

  let get x = get [] [] [] x

  (* Interprets the tree in the algebra given by the labelled arguments. *)
  let compute ~empty ~full ~cup ~cap ~diff ~atom b =
    let rec aux = function
      | True -> full
      | False -> empty
      | Split(_,x, p,i,n) ->
          let p = cap (atom x) (aux p)
          and i = aux i
          and n = diff (aux n) (atom x) in
          cup (cup p i) n
    in
    aux b

(* Invariant: correct hash value *)

  let split x pos ign neg =
    Split (compute_hash x pos ign neg, x, pos, ign, neg)

  let empty = False
  let full = True

(* Invariants:
     Split (x, pos,ign,neg) ==>  (ign <> True);
     (pos <> False or neg <> False)
*)

  let split x pos ign neg =
    if ign = True then True
    else if (pos = False) && (neg = False) then ign
    else split x pos ign neg

(* Invariant:
   - no subsumption
*)

(*
  let rec simplify a b =
    if equal a b then False
    else match (a,b) with
      | False,_ | _, True -> False
      | a, False -> a
      | True, _ -> True
      | Split (_,x1,p1,i1,n1), Split (_,x2,p2,i2,n2) ->
          let c = X.compare x1 x2 in
          if c = 0 then
            let p1' = simplify (simplify p1 i2) p2
            and i1' = simplify i1 i2
            and n1' = simplify (simplify n1 i2) n2 in
            if (p1 != p1') || (n1 != n1') || (i1 != i1')
            then split x1 p1' i1' n1'
            else a
          else if c > 0 then
            simplify a i2
          else
            let p1' = simplify p1 b
            and i1' = simplify i1 b
            and n1' = simplify n1 b in
            if (p1 != p1') || (n1 != n1') || (i1 != i1')
            then split x1 p1' i1' n1'
            else a
*)

  (* [simplify a l] simplifies [a] against the list of trees [l]; it yields
     False as soon as [l] contains True or a tree physically equal to [a].
     The three helpers walk [l] while tracking what has to be propagated to
     each of the three subtrees of [a]. *)
  let rec simplify a l =
    if (a = False) then False else simpl_aux1 a [] l
  and simpl_aux1 a accu = function
    | [] ->
        if accu = [] then a else
          (match a with
             | True -> True
             | False -> assert false
             | Split (_,x,p,i,n) -> simpl_aux2 x p i n [] [] [] accu)
    | False :: l -> simpl_aux1 a accu l
    | True :: _ -> False
    | b :: l -> if a == b then False else simpl_aux1 a (b::accu) l
  and simpl_aux2 x p i n ap ai an = function
    | [] -> split x (simplify p ap) (simplify i ai) (simplify n an)
    | (Split (_,x2,p2,i2,n2) as b) :: l ->
        let c = X.compare x2 x in
        if c < 0 then
          simpl_aux3 x p i n ap ai an l i2
        else if c > 0 then
          simpl_aux2 x p i n (b :: ap) (b :: ai) (b :: an) l
        else
          simpl_aux2 x p i n (p2 :: i2 :: ap) (i2 :: ai) (n2 :: i2 :: an) l
    | _ -> assert false
  and simpl_aux3 x p i n ap ai an l = function
    | False -> simpl_aux2 x p i n ap ai an l
    | True -> assert false
    | Split (_,x2,p2,i2,n2) as b ->
        let c = X.compare x2 x in
        if c < 0 then
          simpl_aux3 x p i n ap ai an l i2
        else if c > 0 then
          simpl_aux2 x p i n (b :: ap) (b :: ai) (b :: an) l
        else
          simpl_aux2 x p i n (p2 :: i2 :: ap) (i2 :: ai) (n2 :: i2 :: an) l

  (* Same constructor, after simplifying [p] and [n] against [i]. *)
  let split x p i n =
    split x (simplify p [i]) i (simplify n [i])

  (* Union. *)
  let rec ( ++ ) a b =
(*    if equal a b then a *)
    if a == b then a
    else match (a,b) with
      | True, _ | _, True -> True
      | False, a | a, False -> a
      | Split (_,x1, p1,i1,n1), Split (_,x2, p2,i2,n2) ->
          let c = X.compare x1 x2 in
          if c = 0 then
            split x1 (p1 ++ p2) (i1 ++ i2) (n1 ++ n2)
          else if c < 0 then
            split x1 p1 (i1 ++ b) n1
          else
            split x2 p2 (i2 ++ a) n2

(* Pseudo-Invariant:
   - pos <> neg
*)

  let split x pos ign neg =
    if equal pos neg then (neg ++ ign) else split x pos ign neg

(* seems better not to make ++ and this split mutually recursive;
   is the invariant still enforced ? *)

  (* Intersection. *)
  let rec ( ** ) a b =
    (*    if equal a b then a *)
    if a == b then a
    else match (a,b) with
      | True, a | a, True -> a
      | False, _ | _, False -> False
      | Split (_,x1, p1,i1,n1), Split (_,x2, p2,i2,n2) ->
          let c = X.compare x1 x2 in
          if c = 0 then
(*          split x1
              (p1 ** (p2 ++ i2) ++ (p2 ** i1))
              (i1 ** i2)
              (n1 ** (n2 ++ i2) ++ (n2 ** i1))  *)
            split x1
              ((p1 ++ i1) ** (p2 ++ i2))
              False
              ((n1 ++ i1) ** (n2 ++ i2))
          else if c < 0 then
            split x1 (p1 ** b) (i1 ** b) (n1 ** b)
          else
            split x2 (p2 ** a) (i2 ** a) (n2 ** a)

  (* Cheap disjointness test that follows the same recursion as the
     intersection without building the resulting tree. *)
  let rec trivially_disjoint a b =
    if a == b then a = False
    else match (a,b) with
      | True, a | a, True -> a = False
      | False, _ | _, False -> true
      | Split (_,x1, p1,i1,n1), Split (_,x2, p2,i2,n2) ->
          let c = X.compare x1 x2 in
          if c = 0 then
(* try expanding -> p1 p2; p1 i2; i1 p2; i1 i2 ... *)
            trivially_disjoint (p1 ++ i1) (p2 ++ i2) &&
            trivially_disjoint (n1 ++ i1) (n2 ++ i2)
          else if c < 0 then
            trivially_disjoint p1 b &&
            trivially_disjoint i1 b &&
            trivially_disjoint n1 b
          else
            trivially_disjoint p2 a &&
            trivially_disjoint i2 a &&
            trivially_disjoint n2 a

  (* Complement. *)
  let rec neg = function
    | True -> False
    | False -> True
(*    | Split (_,x, p,i,False) -> split x False (neg (i ++ p)) (neg i)
    | Split (_,x, False,i,n) -> split x (neg i) (neg (i ++ n)) False
    | Split (_,x, p,False,n) -> split x (neg p) (neg (p ++ n)) (neg n)  *)
    | Split (_,x, p,i,n) -> split x (neg (i ++ p)) False (neg (i ++ n))

  (* Difference. *)
  let rec ( // ) a b =
(*    if equal a b then False  *)
    if a == b then False
    else match (a,b) with
      | False,_ | _, True -> False
      | a, False -> a
      | True, b -> neg b
      | Split (_,x1, p1,i1,n1), Split (_,x2, p2,i2,n2) ->
          let c = X.compare x1 x2 in
          if c = 0 then
            split x1
              ((p1 ++ i1) // (p2 ++ i2))
              False
              ((n1 ++ i1) // (n2 ++ i2))
          else if c < 0 then
            split x1 (p1 // b) (i1 // b) (n1 // b)
(*          split x1 ((p1 ++ i1)// b) False ((n1 ++i1) // b)  *)
          else
            split x2 (a // (i2 ++ p2)) False (a // (i2 ++ n2))


  let cup = ( ++ )
  let cap = ( ** )
  let diff = ( // )

  (* Wire format: a boolean telling whether the tree is a constant, followed
     either by the constant or by the atom and the three subtrees. *)
  let rec serialize t = function
    | (True | False) as b ->
        Serialize.Put.bool t true; Serialize.Put.bool t (b = True)
    | Split (_,x,p,i,n) ->
        Serialize.Put.bool t false;
        X.serialize t x;
        serialize t p;
        serialize t i;
        serialize t n

  (* Intersection of [a] with the atom [x] (when [pos] is true) or with its
     negation (when [pos] is false). *)
  let rec cap_atom x pos a = (* Assume that x does not appear in a *)
    match a with
      | False -> False
      | True -> if pos then atom x else neg_atom x
      | Split (_,y,p,i,n) ->
          let c = X.compare x y in
          assert (c <> 0);
          if (c < 0) then
            if pos then split x a False False
            else split x False False a
          else split y (cap_atom x pos p) (cap_atom x pos i) (cap_atom x pos n)

  (* Inverse of [serialize]; the tree is rebuilt with the set operations
     rather than with [split]. *)
  let rec deserialize t =
    if Serialize.Get.bool t then
      if Serialize.Get.bool t then True else False
    else
      let x = X.deserialize t in
      let p = deserialize t in
      let i = deserialize t in
      let n = deserialize t in
      (cap_atom x true p) ++ i ++ (cap_atom x false n)
      (* split x p i n is not ok, because order of keys might have changed! *)

(*
  let diff x y =
    let d = diff x y in
    Format.fprintf Format.std_formatter "X = %a@\nY = %a@\nX\\Z = %a@\n"
      dump x dump y dump d;
    d

  let cap x y =
    let d = cap x y in
    Format.fprintf Format.std_formatter "X = %a@\nY = %a@\nX**Z = %a@\n"
      dump x dump y dump d;
    d
*)
end


module Make(X : Custom.T) =
struct
  (* Atoms are the values of the functor argument. *)
  type elem = X.t
  (* Decision nodes.  The arguments of [Branch] are, in order: the atom, the
     subtree used when the atom holds, the subtree used when it does not,
     the identifier returned by [id], and the node returned by [neg]. *)
  type t =
    | Zero
    | One
    | Branch of elem * t * t * int * t
  type node = t

  (* Complement of a node: the constants are swapped, and a branch simply
     hands back the node stored in its last component. *)
  let neg node =
    match node with
    | Branch (_, _, _, _, complement) -> complement
    | One -> Zero
    | Zero -> One

  (* Integer identifier of a node: 0 and 1 for the constants, the stored
     number for a branch. *)
  let id node =
    match node with
    | Branch (_, _, _, ident, _) -> ident
    | One -> 1
    | Zero -> 0

(* Internalization + detection of useless branching *)
  (* Next identifier handed out by [mk].  0 and 1 are the identifiers of
     [Zero] and [One], and [mk] uses 0 to recognise a branch that has not
     been numbered yet. *)
  let max_id = ref 2 (* Must be >= 2 *)
  (* Weak hash set of branch nodes; the constants are never stored in it,
     hence the [assert false] cases. *)
  module W = Weak.Make(
    struct
      type t = node
	  
      (* Depends on the atom and on the identifiers of the two children only,
         not on the identifier of the node itself. *)
      let hash = function
	| Zero | One -> assert false
	| Branch (v,yes,no,_,_) -> 
	    1 + 17*X.hash v + 257*(id yes) + 65537*(id no)

      (* Two physically distinct nodes are compared structurally only when at
         least one of them still has identifier 0; children are compared
         physically. *)
      let equal x y = (x == y) || match x,y with
	| Branch (v1,yes1,no1,id1,_), Branch (v2,yes2,no2,id2,_) ->
	    (id1 == 0 || id2 == 0) && X.equal v1 v2 && 
	      (yes1 == yes2) && (no1 == no2)
	| _ -> assert false
    end)
  let table = W.create 16383
  (* Record view of a [Branch] node, with the same fields in the same order.
     [mk] casts a node to this type with [Obj.magic] so that it can assign the
     identifier after the node has been built; this relies on the runtime
     representing the record and the constructor block identically. *)
  type branch =
      { v : X.t; yes : node; no : node; mutable id : int; neg : branch }

  (* [mk v yes no] returns the unique node that tests [v] and continues with
     [yes] or [no].  When both children are the same node the test is useless
     and the child itself is returned.

     A fresh node is built together with its complement (each one points to
     the other) and then internalized through the weak table.  If the node
     was not known yet it receives the next identifier [n] and its complement
     receives [-n].

     The complement is registered in the table as well: otherwise a later
     call [mk v (neg yes) (neg no)] would not find it and would build a
     second, physically different node for the same function, which defeats
     the physical comparisons performed by [equal] and [cup]. *)
  let mk v yes no =
    if yes == no then yes
    else
      let rec pos = Branch (v,yes,no,0,Branch (v,neg yes,neg no,0,pos)) in
      let x = W.merge table pos in
      let pos : branch = Obj.magic x in
      if (pos.id == 0)
      then (let n = !max_id in
            max_id := succ n;
            pos.id <- n;
            pos.neg.id <- (-n);
            (* No equal element can be present already: it would have been
               registered together with its own complement, and the lookup
               above would then have succeeded. *)
            W.add table (neg x));
      x

  (* Node holding exactly when the atom [v] holds. *)
  let atom v = mk v One Zero

  (* Direct-mapped memoization cache for [cup].  [dummy] is a private value
     used to mark empty slots of the key array; slot [h] holds a key (a pair
     of nodes), the union computed for it, and a counter. *)
  let dummy = Obj.magic (ref 0)
  let memo_size = 16383
  let memo_keys = Array.make memo_size (Obj.magic dummy)
  let memo_vals = Array.make memo_size (Obj.magic dummy)
  let memo_occ = Array.make memo_size 0

  (* Physical equality of two pairs, component by component. *)
  let eg2 (x1,y1) (x2,y2) = x1 == x2 && y1 == y2
  (* Union of two nodes.  Constants and physically equal arguments are handled
     directly, and the union of a node with the node stored as its complement
     is [One].  Otherwise the result is looked up in the cache and, on a miss,
     computed by recursion on the smaller of the two atoms. *)
  let rec cup x1 x2 = if (x1 == x2) then x1 else match x1,x2 with
    | One, x | x, One -> One
    | Zero, x | x, Zero -> x
    | Branch (v1,yes1,no1,id1,neg1), Branch (v2,yes2,no2,id2,neg2) ->
	if (x1 == neg2) then One
	else
	  (* The key is the pair ordered by identifier, so that both argument
	     orders share one cache entry. *)
	  let k,h = 
	    if id1<id2 then (x1,x2),id1+65537*id2 else (x2,x1),id2+65537*id1 in
	  (* Identifiers may be negative: clear the sign bit before reducing. *)
	  let h = (h land max_int) mod memo_size in
	  let i = memo_occ.(h) in
	  let k' = memo_keys.(h) in
	  if (k' != dummy) && (eg2 k k') 
	  (* Hit: bump the counter of the slot and reuse the stored result. *)
	  then (memo_occ.(h) <- succ i; memo_vals.(h))
	  else 
	    let r = 
              let c = X.compare v1 v2 in
	      if (c = 0) then mk v1 (cup yes1 yes2) (cup no1 no2)
	      else if (c < 0) then mk v1 (cup yes1 x2) (cup no1 x2)
	      else mk v2 (cup yes2 x1) (cup no2 x1) in
	    (* Miss: take over the slot only when its counter (as read before
	       the recursive calls) was zero; otherwise decrement it. *)
	    if (i = 0) then (memo_keys.(h) <- k; memo_vals.(h) <- r;
			     memo_occ.(h) <- 1)
	    else memo_occ.(h) <- pred i;
	    r
  
	  

  (* Intersection, obtained from the union by De Morgan duality. *)
  let cap x1 x2 =
    let not_x1 = neg x1 and not_x2 = neg x2 in
    neg (cup not_x1 not_x2)

  (* Difference: the complement of (complement of [x1]) union [x2]. *)
  let diff x1 x2 =
    let not_x1 = neg x1 in
    neg (cup not_x1 x2)


  (* Applies [f] to the atom of every branch: the node itself first, then
     its first subtree, then its second one. *)
  let rec iter f node =
    match node with
    | Zero | One -> ()
    | Branch (x, yes, no, _, _) ->
        f x;
        iter f yes;
        iter f no

  (* Raw dump of the structure of a node. *)
  let rec dump ppf node =
    match node with
    | Zero -> Format.fprintf ppf "-"
    | One -> Format.fprintf ppf "+"
    | Branch (x, yes, no, _, _) ->
        Format.fprintf ppf "%a(@[%a,%a@])" X.dump x dump yes dump no


  (* Prints a node as the disjunction of its non-empty alternatives.
     [separator] prints the bar before every alternative except the first
     one. *)
  let rec print f ppf node =
    match node with
    | One -> Format.fprintf ppf "Any"
    | Zero -> Format.fprintf ppf "Empty"
    | Branch (x, yes, no, _, _) ->
        let started = ref false in
        let separator () =
          if !started then Format.fprintf ppf " | " else started := true
        in
        begin match yes with
          | Zero -> ()
          | One ->
              separator ();
              Format.fprintf ppf "%a" f x
          | Branch _ ->
              separator ();
              Format.fprintf ppf "%a & @[(%a)@]" f x (print f) yes
        end;
        begin match no with
          | Zero -> ()
          | One ->
              separator ();
              Format.fprintf ppf "@[~%a@]" f x
          | Branch _ ->
              separator ();
              Format.fprintf ppf "@[~%a@] & @[(%a)@]" f x (print f) no
        end

  (* Public printer: the string [a] is printed for [One], nothing is produced
     for [Zero], and any other node yields a single printer. *)
  let print a f = function
    | Zero -> []
    | One -> [ (fun ppf -> Format.fprintf ppf "%s" a) ]
    | (Branch _) as c -> [ (fun ppf -> print f ppf c) ]
	
  (* Accumulates one (positive atoms, negative atoms) pair per path that ends
     on [One]; the first subtree is explored before the second one. *)
  let rec get acc pos neg node =
    match node with
    | Zero -> acc
    | One -> (pos, neg) :: acc
    | Branch (x, yes, no, _, _) ->
        (*OPT: can avoid creating this list cell when pos or neg =False *)
        let after_yes = get acc (x :: pos) neg yes in
        get after_yes pos (x :: neg) no

  (* Entry point: starts from empty accumulators. *)
  let get x = get [] [] [] x
		
  (* Interprets a node in the algebra given by the labelled arguments: a
     branch on [x] becomes
     [cup (cap (atom x) yes) (diff no (atom x))]. *)
  let compute ~empty ~full ~cup ~cap ~diff ~atom b =
    let rec eval node =
      match node with
      | Zero -> empty
      | One -> full
      | Branch (x, yes, no, _, _) ->
          let on_yes = cap (atom x) (eval yes)
          and on_no = diff (eval no) (atom x) in
          cup on_yes on_no
    in
    eval b
      
  (* The two constants: [compute] maps [Zero] to its [~empty] argument and
     [One] to its [~full] argument. *)
  let empty = Zero
  let full = One

  (* Wire format: a first boolean tells whether the node is a constant.  A
     constant is followed by one more boolean (true for [One]); a branch is
     followed by its atom and its two subtrees. *)
  let rec serialize t node =
    match node with
    | Zero ->
        Serialize.Put.bool t true;
        Serialize.Put.bool t false
    | One ->
        Serialize.Put.bool t true;
        Serialize.Put.bool t true
    | Branch (x, yes, no, _, _) ->
        Serialize.Put.bool t false;
        X.serialize t x;
        serialize t yes;
        serialize t no

  (* Inverse of [serialize].  A branch is rebuilt with the set operations
     instead of [mk]. *)
  let rec deserialize t =
    let is_constant = Serialize.Get.bool t in
    if is_constant then begin
      let is_one = Serialize.Get.bool t in
      if is_one then One else Zero
    end else begin
      let x = X.deserialize t in
      let yes = deserialize t in
      let no = deserialize t in
      let a = atom x in
      cup (cap a yes) (cap (neg a) no)
    end

      (* mk x p n is not ok, because order of keys might have changed!
         OPT TODO: detect when this is ok *)

  (* Cheap disjointness test.  It answers true when one operand is [Zero] or
     when [y] is the very node stored as the complement of [x]; any other
     pair yields false, whether or not the two nodes are actually
     disjoint.  The [Zero] cases bring it in line with
     [MakeOld.trivially_disjoint], which accepts [False] on either side. *)
  let trivially_disjoint x y = x == Zero || y == Zero || neg x == y

  (* Ordering, equality and hashing all go through node identity: nodes are
     compared by identifier, or physically. *)
  let compare x y = compare (id x) (id y)
  let equal x y = x == y
  let hash x = id x

  (* Nothing to verify for this representation. *)
  let check _ = ()
end