(* type_tallying.ml *)
open Types

(* [cap_t d t] intersects [d] with [descr t]. *)
let cap_t d t =
  let td = descr t in
  cap d td

(* [cap_product any_left any_right l] intersects component-wise all the
   pairs of [l]: the first components are intersected into [any_left] and
   the second ones into [any_right] (each through [cap_t]). An empty [l]
   yields [(any_left, any_right)]. *)
let cap_product any_left any_right l =
  let rec go left right = function
    | [] -> (left, right)
    | (t1, t2) :: rest -> go (cap_t left t1) (cap_t right t2) rest
  in
  go any_left any_right l

(* Raised when a constraint problem is found to be unsatisfiable; the string
   is a short label naming the place that detected it. *)
exception UnSatConstr of string

13
(* Prints a list of substitutions on [ppf], each element printed with
   [Subst.print], the whole list delimited by "{" and "}". *)
let pp_sl ppf l =
  let delim = ("{", "}") in
  Utils.pp_list ~delim Subst.print ppf l
15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35

(* A constraint attached to a variable: the pair of its lower bound and
   its upper bound, in that order. *)
type constr = Types.t * Types.t (* lower and upper bounds *)

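    (* A comparison function between types that is compatible with
       subtyping: two types that are subtypes of each other compare equal,
       and a strict subtype compares smaller. If the types are not in a
       subtyping relation, fall back to the implementation-defined order
       given by Types.compare.
    *)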

let compare_type t1 t2 =
  let below = Types.subtype t1 t2 in
  let above = Types.subtype t2 t1 in
  match below, above with
  | true, true -> 0 (* mutually subtypes: considered equal *)
  | true, false -> -1
  | false, true -> 1
  | false, false ->
     (* unrelated types: rely on the structural order, which must then
        distinguish them *)
     let c = Types.compare t1 t2 in
     assert (c <> 0);
     c


(* A line is a conjunction of constraints *)
36

37 38
module Line = struct

  (* A line binds each constrained variable to its (lower, upper) bounds. *)
  type t = constr Var.Map.map

  let singleton = Var.Map.singleton
  let is_empty = Var.Map.is_empty
  let length = Var.Map.length

  (* A set of constraints m1 subsumes a set of constraints m2 (that is,
     the solutions of m1 contain all the solutions of m2) if:
     for every i1 <= v <= s1 in m1,
     there exists i2 <= v <= s2 in m2 such that i1 <= i2 <= v <= s2 <= s1.
     A variable of map1 that is absent from map2 makes the test fail. *)
  let subsumes map1 map2 =
    let covered (v, (i1, s1)) =
      try
        let i2, s2 = Var.Map.assoc v map2 in
        subtype i1 i2 && subtype s2 s1
      with Not_found -> false
    in
    List.for_all covered (Var.Map.get map1)

  (* Prints every binding as "lower <= variable <= upper", the whole line
     being delimited by "{" and "}". *)
  let print ppf map =
    let pp_binding ppf (v, (i, s)) =
      Format.fprintf ppf "%a <= %a <= %a" Print.pp_type i Var.print v Print.pp_type s
    in
    Utils.pp_list ~delim:("{","}") pp_binding ppf (Var.Map.get map)

  (* Orders two lines with [Var.Map.compare], comparing bounds with
     [compare_type]: lower bounds first, upper bounds to break ties. *)
  let compare map1 map2 =
    let compare_bounds (i1, s1) (i2, s2) =
      match compare_type i1 i2 with
      | 0 -> compare_type s1 s2
      | c -> c
    in
    Var.Map.compare compare_bounds map1 map2

  (* [add v (inf, sup) map] records the constraint inf <= v <= sup. If [v]
     is already bound, the lower bounds are joined with [cup] and the upper
     bounds are met with [cap]. *)
  let add v (inf, sup) map =
    let bounds =
      try
        let old_i, old_s = Var.Map.assoc v map in
        (cup old_i inf, cap old_s sup)
      with Not_found -> (inf, sup)
    in
    Var.Map.replace v bounds map

  (* Adds every binding of [map1] to [map2] through [add]. *)
  let join map1 map2 = Var.Map.fold add map1 map2
  let fold = Var.Map.fold
  let empty = Var.Map.empty

  (* [for_all f m] holds when [f k v] holds for every binding of [m]. *)
  let for_all f m =
    let bindings = Var.Map.get m in
    List.for_all (fun (k, v) -> f k v) bindings
end
85

86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
module ConstrSet =
struct

  (* A set of constraint-sets is just a list of Lines that are pairwise
     "non-subsumable". *)
  type t = Line.t list
  let elements t = t
  let empty = []

  (* [add m l] inserts the line [m] in [l]. The scan stops at the first
     element comparable with [m]: if [m] subsumes it, [m] takes its place;
     if it subsumes [m], it is kept and [m] is dropped. The order of the
     resulting list is not preserved. *)
  let add m l =
    let rec insert pending kept =
      match pending with
      | [] -> m :: kept
      | line :: others ->
         if Line.subsumes m line then List.rev_append others (m :: kept)
         else if Line.subsumes line m then List.rev_append others (line :: kept)
         else insert others (line :: kept)
    in
    insert l []

  (* No line at all: unsatisfiable. A single empty line: trivially
     satisfiable. *)
  let unsat = empty
  let sat = [ Line.empty ]

  let is_unsat m =
    match m with
    | [] -> true
    | _ :: _ -> false

  let is_sat m =
    match m with
    | [ l ] -> Line.is_empty l
    | _ -> false

  let print ppf s = Utils.pp_list ~delim:("{","}") Line.print ppf s

  (* [fold f l a] folds [f line acc] over the lines of [l], from left to
     right, starting from [a]. *)
  let fold f l a = List.fold_left (fun acc line -> f line acc) a l

  let is_empty l =
    match l with
    | [] -> true
    | _ :: _ -> false

  (* Square union : *)
  let union s1 s2 =
    match s1, s2 with
    | [], _ -> s2
    | _, [] -> s1
    | _ ->
       (* Invariant: all elements of s1 (resp. s2) are pairwise
          incomparable (they don't subsume one another).
          Let e1 be an element of s1:
          - if e1 subsumes no element of s2, add e1 to the result;
          - if e1 subsumes an element e2 of s2, add e1 to the result and
            remove e2 from s2;
          - if an element e2 of s2 subsumes e1, add e2 to the result and
            remove e2 from s2 (and discard e1).
          Once this is done for every e1, add the remaining elements of s2
          to the result. *)
       let place e1 candidates result =
         let rec scan todo skipped =
           match todo with
           | [] -> (skipped, e1 :: result)
           | e2 :: rest ->
              if Line.subsumes e1 e2 then
                (List.rev_append rest skipped, e1 :: result)
              else if Line.subsumes e2 e1 then
                (List.rev_append rest skipped, e2 :: result)
              else scan rest (e2 :: skipped)
         in
         scan candidates []
       in
       let rec merge_all pending remaining result =
         match pending with
         | [] -> List.rev_append remaining result
         | e1 :: rest ->
            let remaining', result' = place e1 remaining result in
            merge_all rest remaining' result'
       in
       merge_all s1 s2 []

  (* Square intersection *)
  let inter s1 s2 =
    match s1, s2 with
    | [], _ | _, [] -> []
    | _ ->
       (* Perform the cartesian product. For each constraint m1 in s1 and
          m2 in s2, we add Line.join m1 m2 to the result.
          Optimisations:
          - we use add so that we do not add something that subsumes a
            constraint set already present in the result;
          - if m1 subsumes m2, then whenever m2 holds so does m1, so we
            only add m2 (note that the condition is reversed w.r.t.
            union). *)
       let combine m1 m2 =
         if Line.subsumes m1 m2 then m2
         else if Line.subsumes m2 m1 then m1
         else Line.join m1 m2
       in
       fold (fun m1 acc1 ->
         fold (fun m2 acc2 -> add (combine m1 m2) acc2) s2 acc1
       ) s1 []

  let filter = List.filter

  (* The constraint set made of the single line binding [v] to [cs]. *)
  let singleton v cs = [ Line.singleton v cs ]

end
187 188


189 190 191 192 193 194 195 196 197
(* Normalisation for a kind handled by [V]: the result is [ConstrSet.sat]
   when [V.leafconj (V.proj t)] is empty according to [V.Atom.is_empty],
   and [ConstrSet.unsat] otherwise. *)
let norm_simple (type a) (module V : VarType with type Atom.t = a) t =
  let leaf = V.(leafconj (proj t)) in
  if V.Atom.is_empty leaf then ConstrSet.sat
  else ConstrSet.unsat

(* Normalisation entry points for the kinds handled by [norm_simple].
   The first two arguments (passed as [delta] and [mem] by [toplevel]) are
   ignored; they are only there to share the calling convention of the
   other normalisation functions. *)
let norm_atoms _delta _mem t = norm_simple (module VarAtoms) t
let norm_chars _delta _mem t = norm_simple (module VarChars) t
let norm_ints _delta _mem t = norm_simple (module VarIntervals) t
let norm_abstract _delta _mem t = norm_simple (module VarAbstracts) t

198

199 200 201 202 203 204 205 206
(* [single (module V) b v lpos lneg] builds the intersection of all the
   elements of [lpos] and of the negations of all the elements of [lneg].
   An element is either [`Var x] (turned into a type with [var]) or an
   [`Atm _] (injected with [V.inj (V.atom _)]). The result is negated when
   [b] is true. The argument [v] is not used by the body. *)
let single (type a) (module V : VarType with type Atom.t = a) b v lpos lneg =
  let to_type = function
    | `Var x -> var x
    | (`Atm _) as atm -> V.(inj (atom atm))
  in
  (* intersection, starting from [any], of [polarity] applied to each
     element of [l] *)
  let conj polarity l =
    List.fold_left (fun acc va -> cap acc (polarity (to_type va))) any l
  in
  let t = cap (conj (fun x -> x) lpos) (conj neg lneg) in
  if b then neg t else t
211 212


213 214
let toplevel
    (type a) (module V : VarType with type Atom.t = a)
    delta norm_rec mem lpos lneg
    =
  (* Builds the constraint [lower <= x <= upper]. Constraints over
     monomorphic variables (those of [delta]) must be trivial, that is true
     for all substitutions; otherwise the result is unsatisfiable. *)
  let constrain x ((lower, upper) as bounds) =
    let vx = var x in
    if Var.Set.mem delta x && not (subtype lower vx && subtype vx upper)
    then ConstrSet.unsat
    else ConstrSet.singleton x bounds
  in
  (* [x] receives a lower bound computed from the other literals *)
  let from_below x pos_l neg_l =
    let t = single (module V) false x pos_l neg_l in
    constrain x (t, any)
  in
  (* [x] receives an upper bound computed from the other literals *)
  let from_above x pos_l neg_l =
    let t = single (module V) true x pos_l neg_l in
    constrain x (empty, t)
  in
  match lpos, lneg with
  | [], (`Var x) :: neg_rest -> from_below x [] neg_rest
  | (`Var x) :: pos_rest, [] -> from_above x pos_rest []
  | (`Var x) :: pos_rest, (`Var y) :: neg_rest ->
     (* both sides start with a variable: the smaller one according to
        [Var.compare] is the one that gets constrained *)
     if Var.compare x y < 0 then from_above x pos_rest lneg
     else from_below y lpos neg_rest
  | [ `Atm _ ], (`Var x) :: neg_rest -> from_below x lpos neg_rest
  | [ (`Atm _) as a ], [] -> norm_rec delta mem V.(inj (atom a))
  | _ -> assert false
249

250 251

(* [big_prod f acc l] intersects [acc] (with [ConstrSet.inter]) with
   [f pos neg] for every pair [(pos, neg)] of [l], from left to right. *)
let big_prod f acc l =
  let rec go acc = function
    | [] -> acc
    | (pos, neg) :: rest -> go (ConstrSet.inter acc (f pos neg)) rest
  in
  go acc l

(* Hash tables whose keys are pairs made of a type descriptor and a set of
   variables. *)
module NormMemoHash = Hashtbl.Make(Custom.Pair(Descr)(Var.Set))

(* Global memoisation table for normalisation results, keyed by
   (type, delta). It is read by the recursive [norm], filled by the
   top-level [norm] wrapper and cleared at the start of [squaresubtype]
   and [apply_raw]. *)
let memo_norm = NormMemoHash.create 17

260
let rec norm delta mem t =
  (* [norm delta mem t] computes the constraint set attached to the type
     [t]. [delta] is the set of monomorphic variables; [mem] is the local
     memo table, whose entries are pairs (finished, result). *)
  DEBUG normrec (
    Format.eprintf
      " @[Entering norm rec(%a):@\n" Print.pp_type t);
  let key = (t, delta) in
  (* Inductive case: the type is neither empty nor a single variable. An
     unfinished entry is recorded first, so that a recursive occurrence of
     the same key is answered with [ConstrSet.sat]. *)
  let decompose () =
    DEBUG normrec (Format.eprintf "@[ - Inductive case:@\n");
    NormMemoHash.add mem key (false, ConstrSet.sat);
    let simplified = Iter.simplify t in
    let aux (module V : VarType) t acc =
      let res =
        big_prod
          (toplevel (module V) delta (norm_dispatch V.kind) mem)
          acc
          V.(get (proj t))
      in
      (* NOTE(review): "%s" is followed by two arguments here
         ([pp_type_kind] and [V.kind]) -- confirm this type-checks when
         the normrec debug flag is enabled. *)
      DEBUG normrec
        (Format.eprintf "@[ - After %s constraints: %a @]@\n"
           pp_type_kind V.kind ConstrSet.print res);
      res
    in
    let acc = Iter.fold aux simplified ConstrSet.sat in
    DEBUG normrec (Format.eprintf "@]@\n");
    acc
  in
  (* Computes the result for a key found in neither table, then marks it
     as finished in the local table. *)
  let compute () =
    let res =
      (* base cases *)
      if is_empty t then begin
        DEBUG normrec (Format.eprintf "@[ - Empty type case @]@\n");
        ConstrSet.sat
      end else if no_var t then begin
        DEBUG normrec (Format.eprintf "@[ - No var case @]@\n");
        ConstrSet.unsat
      end else if is_var t then begin
        let (v, p) = Variable.extract t in
        if Var.Set.mem delta v then begin
          DEBUG normrec (Format.eprintf "@[ - Monomorphic var case @]@\n");
          ConstrSet.unsat (* if it is monomorphic, unsat *)
        end else begin
          DEBUG normrec (Format.eprintf "@[ - Polymorphic var case @]@\n");
          (* otherwise, create a single constraint according to its
             polarity *)
          ConstrSet.singleton v (if p then (empty, empty) else (any, any))
        end
      end else
        (* type is not empty and is not a variable *)
        decompose ()
    in
    NormMemoHash.replace mem key (true, res);
    res
  in
  let res =
    try
      (* If we find it in the global hashtable, we are finished *)
      let res = NormMemoHash.find memo_norm key in
      DEBUG normrec (Format.eprintf
                       "@[ - Result found in global table @]@\n");
      res
    with Not_found ->
      try
        let finished, cst = NormMemoHash.find mem key in
        DEBUG normrec (Format.eprintf
                         "@[ - Result found in local table, finished = %b @]@\n" finished);
        (* an unfinished entry means we are looping on this key *)
        if finished then cst else ConstrSet.sat
      with Not_found -> compute ()
  in
  DEBUG normrec (Format.eprintf
                   "Leaving norm rec(%a) = %a@]@\n%!"
                   Print.pp_type t
                   ConstrSet.print res
  );
  res

332 333 334 335 336 337 338 339 340 341
and norm_dispatch = function
  (* Selects the normalisation function matching a kind tag. *)
  | `atoms -> norm_atoms
  | `chars -> norm_chars
  | `intervals -> norm_ints
  | `abstracts -> norm_abstract
  (* products and XML elements share the pair normalisation *)
  | `times -> norm_pair (module VarTimes : VarType with type Atom.t = Pair.t)
  | `xml -> norm_pair (module VarXml : VarType with type Atom.t = Pair.t)
  | `arrow -> norm_arrow
  | `record -> norm_rec
342 343 344 345 346 347 348 349 350 351
  (* (t1,t2) = intersection of all (fst pos,snd pos) \in P
   * (s1,s2) \in N
   * [t1] v [t2] v ( [t1 \ s1] ^ [t2 \ s2] ^
   * [t1 \ s1 \ s1_1] ^ [t2 \ s2 \ s2_1 ] ^
   * [t1 \ s1 \ s1_1 \ s1_2] ^ [t2 \ s2 \ s2_1 \ s2_2 ] ^ ... )
   *
   * prod(p,n) = [t1] v [t2] v prod'(t1,t2,n)
   * prod'(t1,t2,{s1,s2} v n) = ([t1\s1] v prod'(t1\s1,t2,n)) ^
   *                            ([t2\s2] v prod'(t1,t2\s2,n))
   * *)
352

353
and norm_pair (module V : VarType with type Atom.t = Pair.t) delta mem t =
  (* Constraints for one conjunction: [pos] are the positive products and
     [neg] the negative ones. *)
  let norm_prod pos neg =
    (* [split left right negs] is prod' of the comment above: each
       negative product is subtracted either from the left component or
       from the right one. *)
    let rec split left right negs =
      match negs with
      | [] -> ConstrSet.unsat
      | (s1, s2) :: others ->
         let left' = diff left (descr s1) in
         let right' = diff right (descr s2) in
         let c_left = norm delta mem left' in
         let c_right = norm delta mem right' in
         let rec_left = split left' right others in
         let rec_right = split left right' others in
         let via_left = ConstrSet.union c_left rec_left in
         let via_right = ConstrSet.union c_right rec_right in
         ConstrSet.inter via_left via_right
    in
    (* cap_product returns the intersection of all (fst pos, snd pos) *)
    let (t1, t2) = cap_product any any pos in
    let c1 = norm delta mem t1 in
    let c2 = norm delta mem t2 in
    let c_neg = split t1 t2 neg in
    ConstrSet.union (ConstrSet.union c1 c2) c_neg
  in
  let t = V.leafconj (V.proj t) in
  big_prod norm_prod ConstrSet.sat (Pair.get t)
378

379 380
and norm_rec delta mem t =
  (* Constraints for one conjunction of record literals:
     [(oleft, left)] comes from the positive part and [rights] from the
     negative part, as returned by [get_record]. *)
  let norm_rec_array ((oleft, left), rights) =
    let rec aux accus seen fields =
      match fields with
      | [] -> ConstrSet.sat
      | (false, _) :: rest when oleft -> aux accus seen rest
      | (_, field) :: rest ->
         let right = seen @ rest in
         (* for the i-th component [t] of [field]: subtract it from the
            i-th accumulator and recurse on the remaining fields *)
         let step (i, acc) t =
           let di = diff accus.(i) t in
           let accus' = Array.copy accus in
           accus'.(i) <- di;
           (i + 1,
            ConstrSet.inter acc
              (ConstrSet.inter (norm delta mem di) (aux accus' [] right)))
         in
         snd (Array.fold_left step (0, ConstrSet.sat) field)
    in
    let c =
      Array.fold_left
        (fun acc t -> ConstrSet.union (norm delta mem t) acc)
        ConstrSet.sat left
    in
    ConstrSet.inter (aux left [] rights) c
  in
  let t = VarRec.leafconj (VarRec.proj t) in
  (* intersect the constraints of every conjunction; once the accumulator
     is unsatisfiable the remaining conjunctions are not normalised *)
  let rec conj acc = function
    | [] -> acc
    | (_, p, n) :: rest ->
       let acc =
         if ConstrSet.is_unsat acc then acc
         else ConstrSet.inter acc (norm_rec_array (p, n))
       in
       conj acc rest
  in
  conj ConstrSet.sat (get_record t)
400
    
401 402 403 404 405 406 407 408 409 410 411 412 413
  (* arrow(p,{t1 -> t2}) = [t1] v arrow'(t1,any \ t2,p)
   * arrow'(t1,acc,{s1 -> s2} v p) =
   * ([t1\s1] v arrow'(t1\s1,acc,p)) ^
   * ([acc ^ {s2}] v arrow'(t1,acc ^ {s2},p))

   * t1 -> t2 \ s1 -> s2 =
   * [t1] v (([t1\s1] v {[]}) ^ ([t2\s2] v {[]}))

   * Bool -> Bool \ Int -> A =
   * [Bool] v (([Bool\Int] v {[]}) ^ ([Bool\A] v {[]})

   * P(Q v {a}) = {a} v P(Q) v {X v {a} | X \in P(Q) }
   *)
414
and norm_arrow delta mem t =
  (* [by_negative pos neg] tries each negative arrow of [neg] in turn and
     takes the union of the resulting constraint sets. *)
  let rec by_negative pos neg =
    match neg with
    | [] -> ConstrSet.unsat
    | (t1, t2) :: others ->
       let dom_t = descr t1 in
       let cod_t = descr t2 in
       let c_dom = norm delta mem dom_t in (* [t1] *)
       let c_pos = against dom_t (diff any cod_t) pos in
       let c_rest = by_negative pos others in
       ConstrSet.union (ConstrSet.union c_dom c_pos) c_rest
  (* [against t1 acc p] is arrow' of the comment above: each positive arrow
     either shrinks the domain [t1] or is intersected into [acc]. *)
  and against t1 acc positives =
    match positives with
    | [] -> ConstrSet.unsat
    | (s1, s2) :: p ->
       let t1s1 = diff t1 (descr s1) in
       let acc1 = cap acc (descr s2) in
       let c_dom = norm delta mem t1s1 in (* [t1 \ s1] *)
       let c_cod = norm delta mem acc1 in (* [(Any \ t2) ^ s2] *)
       let rec_dom = against t1s1 acc p in
       let rec_cod = against t1 acc1 p in
       let via_dom = ConstrSet.union c_dom rec_dom in
       let via_cod = ConstrSet.union c_cod rec_cod in
       ConstrSet.inter via_dom via_cod
  in
  let t = VarArrow.leafconj (VarArrow.proj t) in
  big_prod by_negative ConstrSet.sat (Pair.get t)
439 440 441 442 443 444



(* Memoising entry point of the normalisation: a result already present in
   [memo_norm] for (t, delta) is returned as is; otherwise the recursive
   [norm] is run with a fresh local table and its result is recorded. *)
let norm delta t =
  let key = (t, delta) in
  try NormMemoHash.find memo_norm key
  with Not_found ->
    let local_mem = NormMemoHash.create 17 in
    let res = norm delta local_mem t in
    NormMemoHash.add memo_norm key res;
    res

  (* merge needs delta because it calls norm recursively *)
450
let rec merge delta cache m =
  (* the constraint set containing only the line [m] *)
  let only_m () = ConstrSet.add m ConstrSet.empty in
  (* Processes the bounds of one variable of [m]. Nothing is added when
     inf <= sup already holds or when [inf \ sup] is already in [cache];
     otherwise [inf \ sup] is normalised, combined with [m], and the
     resulting lines are merged recursively. *)
  let refine _v (inf, sup) acc =
    (* no need to add new constraints *)
    if subtype inf sup then acc
    else
      let x = diff inf sup in
      match Cache.lookup x cache with
      | Some _ -> acc
      | None ->
         let cache, _ = Cache.find ignore x cache in
         let n = norm delta x in
         if ConstrSet.is_unsat n then raise (UnSatConstr "merge2");
         let c1 = ConstrSet.inter (only_m ()) n in
         let c2 =
           ConstrSet.fold
             (fun m1 acc -> ConstrSet.union acc (merge delta cache m1))
             c1 ConstrSet.empty
         in
         ConstrSet.union c2 acc
  in
  let res = Line.fold refine m ConstrSet.empty in
  (* nothing was produced: [m] itself is the result *)
  if ConstrSet.is_unsat res then only_m ()
  else res
475

476
(* Entry point of [merge], starting from the cache [Cache.emp]. *)
let merge delta m =
  let initial_cache = Cache.emp in
  merge delta initial_cache m
477

478 479 480 481 482
(* [solve delta s] turns every line of the constraint set [s] into a set of
   equations, represented as a map from variables to types. *)
let solve delta s =
  (* binds [alpha] to (s | beta) & t, where beta is a variable created for
     the occasion, internal and named "#" followed by alpha's identifier *)
  let add_eq alpha s t acc =
    let beta =
      var Var.(mk ~internal:true ("#" ^ (ident alpha)))
    in
    let rhs = cap (cup s beta) t in
    Var.Map.replace alpha rhs acc
  in
  (* when a bound is itself a variable that is not in [delta], give that
     variable the equation built from the bounds (empty, any) *)
  let extra_var t acc =
    if not (is_var t) then acc
    else
      let v, _ = Variable.extract t in
      if Var.Set.mem delta v then acc
      else add_eq v empty any acc
  in
  let equation alpha (s, t) acc =
    let acc = extra_var t acc in
    let acc = extra_var s acc in
    add_eq alpha s t acc
  in
  let to_eq_set m = Line.fold equation m Var.Map.empty in
  ConstrSet.fold (fun m acc -> to_eq_set m :: acc) s []
500 501

(* Same as the [solve] defined above, with a debug trace of the arguments
   and of the result. *)
let solve delta s =
  let equations = solve delta s in
  DEBUG solve (Format.eprintf "Calling solve (%a, %a), got %a@\n%!"
                 Var.Set.print delta ConstrSet.print s
                 pp_sl equations);
  equations
507

508 509 510
(* Executed when the module is initialised: sets the margin of the standard
   error formatter to 200 columns. *)
let () = Format.pp_set_margin Format.err_formatter 200

(* [unify eq_set] solves the set of equations [eq_set] (a map from
   variables to types) and returns the resulting substitution. Equations
   are eliminated one at a time, starting with the one returned by
   [Var.Map.remove_min]. *)
let unify (eq_set : Descr.t Var.Map.map) =
  let rec loop eq_set accu =
    DEBUG unify
      (Format.eprintf "@[<hov 2>  Entering unify_loop @[(%a, %a)@]@\n"
         Subst.print eq_set Subst.print accu);
    let res =
      if Var.Map.is_empty eq_set then accu
      else eliminate eq_set accu
    in
    DEBUG unify (Format.eprintf "Leaving unify_loop @[(%a, %a)@], @[<hov 4>got %a@]@]@\n"
                   Subst.print eq_set Subst.print accu Subst.print res);
    res
  (* Solves one equation alpha = t, propagates its solution in the other
     equations, solves those, and finally applies the substitution obtained
     to the solution of alpha. *)
  and eliminate eq_set accu =
    let (alpha, t), remaining = Var.Map.remove_min eq_set in
    DEBUG unify (Format.eprintf
                   "@[<hov 2>- Choosing equation @[%a = %a@]@]@\n"
                   Var.print alpha Types.Print.pp_type t);
    let x = Subst.solve_rectype t alpha in
    DEBUG unify (Format.eprintf
                   "@[<hov 2>- Solving equation @[%a = %a@]@]@\n"
                   Var.print alpha Types.Print.pp_type x);
    let remaining =
      Var.Map.map (fun s -> Subst.single s (alpha, x)) remaining
    in
    let sigma = loop remaining (Var.Map.replace alpha x accu) in
    let t_alpha = Subst.full x sigma in
    DEBUG unify (Format.eprintf
                   "@[<hov 2>- Applying substitution @[%a@] to @[%a@], @[<hov 4>got %a@]@]@\n"
                   Subst.print sigma Print.pp_type x Print.pp_type t_alpha);
    Var.Map.replace alpha t_alpha sigma
  in
  let res = loop eq_set Var.Map.empty in
  DEBUG unify (Format.(pp_print_flush err_formatter ()));
  res
544

545

546 547 548 549 550 551 552
(* Raised by [tallying] when the normalised constraint set is
   unsatisfiable. *)
exception Step1Fail
(* Raised by [tallying] when merging and solving the normalised constraints
   yields no set of equations at all. *)
exception Step2Fail

(* [tallying delta l] solves the list of constraints [l], each pair (s, t)
   standing for s <= t, where [delta] is the set of monomorphic variables.
   It returns one substitution (computed by [unify]) per set of equations
   produced by [solve].
   Raises [Step1Fail] when normalisation shows the constraints to be
   unsatisfiable, and [Step2Fail] when no set of equations survives the
   merge/solve step. *)
let tallying delta l =
  let n =
    List.fold_left (fun acc (s,t) ->
      let d = diff s t in
      (* A trivially satisfied constraint adds nothing: keep what has been
         accumulated so far. Returning [ConstrSet.sat] here would forget
         the constraints of the previous pairs, and would even turn an
         unsatisfiable accumulator back into a satisfiable one. *)
      if is_empty d then acc
      else if no_var d then ConstrSet.unsat
      else ConstrSet.inter acc (norm delta d)
    ) ConstrSet.sat l
  in
  if ConstrSet.is_unsat n then raise Step1Fail else
    let m =
      ConstrSet.fold (fun c acc ->
        (* a line whose merge is unsatisfiable is simply dropped *)
        try
          (solve delta (merge delta c)) @ acc
        with UnSatConstr _ -> acc
      ) n []
    in
    match m with
    | [] -> raise Step2Fail
    | _ :: _ -> List.map unify m
568 569


570
(* Symbolic substitutions. *)
type symsubst =
  | I                           (* applying it leaves the type unchanged *)
  | S of Descr.t Var.Map.map    (* a map from variables to types *)
  | A of (symsubst * symsubst)  (* a pair of symbolic substitutions *)
573 574 575

(* Set of the variables bound by a symbolic substitution: the domain of the
   map for [S], the union of both domains for [A], nothing for [I]. *)
let rec dom sigma =
  match sigma with
  | I -> Var.Set.empty
  | S si -> Var.Map.domain si
  | A (si, sj) -> Var.Set.cup (dom si) (dom sj)

579
(* Composition of two sets of symbolic substitutions sigmaI and sigmaJ: the
   Cartesian product, pairing every element of sigmaI with every element of
   sigmaJ. *)
580 581 582 583 584 585 586 587
let (++) sI sJ =
  (* every [si] is paired with every [sj], keeping the order of [sI] first
     and of [sJ] second *)
  let pairs_with si = List.map (fun sj -> A (si, sj)) sJ in
  List.flatten (List.map pairs_with sI)

588
(* [t >> s] applies to [t], one binding after the other, every binding of
   the map [s] as a single-variable substitution ([Subst.single]). *)
let (>>) t s =
  let apply_binding v image acc = Subst.single acc (v, image) in
  Var.Map.fold apply_binding s t
591 592

(* apply a symbolic substitution si to a type t *)
593 594 595 596
let (@@) t si =
  (* Scratch cells shared by the two guards below: [overlaps] fills them
     and [overlaps_outside] reads them, so the order in which the guards
     are evaluated matters. *)
  let vars_t = ref Var.Set.empty in
  let dom_si = ref Var.Set.empty in
  (* does [t] mention a variable of the domain of [si]? *)
  let overlaps t si =
    dom_si := dom (S si);
    vars_t := all_vars t;
    not (Var.Set.is_empty (Var.Set.cap !vars_t !dom_si))
  in
  (* Same test restricted to the variables that are not in the domain of
     [sj]. Only the cells written by the last call to [overlaps] are used;
     the first two arguments are ignored. *)
  let overlaps_outside _t _si sj =
    let dom_sj = dom (S sj) in
    not (Var.Set.is_empty
           (Var.Set.cap !vars_t (Var.Set.diff !dom_si dom_sj)))
  in
  let rec apply t sigma =
    match sigma with
    | I -> t
    | S si -> t >> si
    | A (S si, _) when overlaps t si -> t >> si
    | A (S si, S sj) when overlaps_outside t si sj -> (t >> sj) >> si
    | A (si, sj) -> apply (apply t sj) si
  in
  apply t si

(* Union of the domains of all the substitutions of the list [sl]. *)
let domain sl =
  let add_domain acc si = Var.Set.cup acc (Var.Map.domain si) in
  List.fold_left add_domain Var.Set.empty sl
616 617 618

(* Set of all the variables ([all_vars]) occurring in the types that the
   substitutions of the list [ll] map their variables to. *)
let codomain ll =
  let add_image _ image acc = Var.Set.cup (all_vars image) acc in
  let add_subst acc e = Var.Map.fold add_image e acc in
  List.fold_left add_subst Var.Set.empty ll

624 625
(* A list of substitutions is the identity when all its maps are empty
   (this includes the empty list). *)
let is_identity sl = List.for_all Var.Map.is_empty sl

(* The canonical identity: a single empty substitution. *)
let identity = [ Var.Map.empty ]
626 627 628 629

(* [filter f sl] keeps, in every substitution of [sl], the bindings whose
   variable satisfies [f], and drops the substitutions left empty. An
   identity list is returned unchanged; otherwise the substitutions kept
   come out in reverse order. *)
let filter f sl =
  if is_identity sl then sl
  else
    let rec keep acc = function
      | [] -> acc
      | si :: rest ->
         let e = Var.Map.filter (fun v _ -> f v) si in
         keep (if Var.Map.is_empty e then acc else e :: acc) rest
    in
    keep [] sl



(* [set a i v] stores [v] in cell [i] of the growable array [a] (a
   reference to an array). When [i] is past the end, a larger array is
   allocated, filled with [empty], the old contents are copied over and [a]
   is updated to point to it.
   The new size is at least [i + 1], so that any index beyond the current
   length can be written (growing to [2*len+1] only would fail as soon as
   [i >= 2*len+1]). A negative [i] still raises [Invalid_argument]. *)
let set a i v =
  let len = Array.length !a in
  if i < len then (!a).(i) <- v
  else begin
    let doubled = 2 * len + 1 in
    let new_len = if i + 1 > doubled then i + 1 else doubled in
    let b = Array.make new_len empty in
    Array.blit !a 0 b 0 len;
    b.(i) <- v;
    a := b
  end
(* [get a i] reads cell [i] of the growable array [a]; a negative index
   yields [any]. *)
let get a i =
  if i >= 0 then (!a).(i)
  else any

647
(* Used by [squaresubtype] to leave its search loop: carries the
   substitutions returned by [tallying]. *)
exception FoundSquareSub of Descr.t Var.Map.map list
648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664

(* [squaresubtype delta s t] looks for substitutions making an intersection
   of instances of [s] a subtype of [t]: it first tries [s] itself, then
   keeps intersecting it with fresh copies ([Subst.freshen delta s]) until
   [tallying] succeeds, and returns the substitutions found.
   Raises [UnSatConstr] when the very first attempt fails at the
   normalisation step. The function does not return as long as every
   attempt raises [Step2Fail].
   The global memo table [memo_norm] is cleared on entry. *)
let squaresubtype delta s t =
  NormMemoHash.clear memo_norm;
  (* cell i of ai contains the intersection of the first i+1 copies of s *)
  let ai = ref [| |] in
  let tallying i =
    try
      let s = get ai i in
      let sl = tallying delta [ (s,t) ] in
      raise (FoundSquareSub sl)
    with
      (* the label names this function, not apply_raw *)
      Step1Fail -> (assert (i = 0); raise (UnSatConstr "squaresubtype step1"))
    | Step2Fail -> () (* continue *)
  in
  let rec loop i =
    try
      let ss =
        if i = 0 then s
        else (cap (Subst.freshen delta s) (get ai (i-1)))
      in
      set ai i ss;
      tallying i;
      loop (i+1)
    with FoundSquareSub sl -> sl
  in
  loop 0

(* Boolean version of [squaresubtype]: true when it returns, false when it
   raises [UnSatConstr]. *)
let is_squaresubtype delta s t =
  try
    ignore (squaresubtype delta s t);
    true
  with UnSatConstr _ -> false

677
(* Used by [apply_raw] to leave its search loop. Carries, in this order:
   the result type, the indices i and j of the attempt that succeeded, and
   the substitutions returned by [tallying]. *)
exception FoundApply of t * int * int * Descr.t Var.Map.map list
678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701

(** find two sets of type substitutions I,J such that
    s @@ sigma_i < t @@ sigma_j for all i \in I, j \in J *)

let apply_raw delta s t =
  (* Returns a quadruple (sl, si, tj, res): [sl] are the substitutions
     found by [tallying], [si] and [tj] the expansions of [s] and [t] for
     which it succeeded, and [res] the result type.
     Raises [UnSatConstr] when the very first attempt fails at the
     normalisation step. The global memo table [memo_norm] is cleared on
     entry. *)
  DEBUG apply_raw (Format.eprintf " @[Entering apply_raw (delta:@[%a@], @[%a@], @[%a@])@\n%!"
                     Var.Set.pp delta
                     Print.pp_type s
                     Print.pp_type t
  );

  NormMemoHash.clear memo_norm;
  (* variable standing for the result of the application *)
  let vgamma = Var.mk "Gamma" in
  let gamma = var vgamma in
  let cgamma = cons gamma in
  (* cell i of ai contains /\k<=i s_k, cell j of aj contains /\k<=j t_k *)
  let ai = ref [| |]
  and aj = ref [| |] in
  (* Tries to solve ai.(i) <= aj.(j) -> gamma. On success, leaves the
     search by raising [FoundApply]; returns normally only when the
     attempt fails with [Step2Fail]. *)
  let tallying i j  =
    try
      let s = get ai i in
      let t = arrow (cons (get aj j)) cgamma in
      let sl = tallying delta [ (s,t) ] in
      (* the result type is the intersection of the images of gamma under
         all the substitutions found *)
      let new_res =
        Subst.clean_type delta (
          List.fold_left (fun tacc si ->
            cap tacc (Subst.full gamma  si)
          ) any sl
        )
      in
      raise (FoundApply(new_res,i,j,sl))
    with
      Step1Fail -> (assert (i == 0 &&  j == 0); raise (UnSatConstr "apply_raw step1"))
    | Step2Fail -> () (* continue *)
  in
  let rec loop i =
    try
      (* Format.printf "Starting expansion %i @\n@." i; *)
      let (ss,tt) =
        if i = 0 then (s,t) else
          ((cap (Subst.freshen delta s) (get ai (i-1))),
           (cap (Subst.freshen delta t) (get aj (i-1))))
      in
      set ai i ss;
      set aj i tt;
      (* try the new expansions against all the previous ones, then
         against each other *)
      for j = 0 to i-1 do
        tallying j i;
        tallying i j;
      done;
      tallying i i;
      loop (i+1)
    with FoundApply (res, i, j, sl) ->
      (* [sl] is a list of substitutions: it is printed with [pp_sl], as
         [solve] does ([ConstrSet] has no [printl] function) *)
      DEBUG apply_raw (Format.eprintf " Leaving apply_raw (delta:@[%a@], @[%a@], @[%a@]) = @[%a@], @[%a@] @]@\n%!"
                         Var.Set.pp delta
                         Print.pp_type s
                         Print.pp_type t
                         Print.pp_type res
                         pp_sl sl
      );
      (sl, get ai i, get aj j, res)
  in
  loop 0

(* Result type computed by [apply_raw], without the substitutions and the
   expansions. *)
let apply_full delta s t =
  match apply_raw delta s t with
  | (_, _, _, res) -> res

(* Substitutions and result type computed by [apply_raw]. *)
let squareapply delta s t =
  match apply_raw delta s t with
  | (sl, _, _, res) -> (sl, res)