(* type_tallying.ml *)
open Types

(* [cap_t d t] intersects the descriptor [d] with [descr t]. *)
let cap_t d t =
  let dt = descr t in
  cap d dt

(* [cap_product any_left any_right l] walks the list of pairs [l] and
   intersects every left component (via [cap_t]) into an accumulator seeded
   with [any_left], and every right component into one seeded with
   [any_right]. The pair of both intersections is returned. *)
let cap_product any_left any_right l =
  let step (acc_left, acc_right) (left, right) =
    (cap_t acc_left left, cap_t acc_right right)
  in
  List.fold_left step (any_left, any_right) l

(* Raised when a constraint set turns out to be unsatisfiable; the string
   tags where the failure was detected. *)
exception UnSatConstr of string

(* Print a list of substitutions between braces, each element with
   [Subst.print]. *)
let pp_sl ppf l =
  Utils.pp_list ~delim:("{","}") Subst.print ppf l

type constr = Types.t * Types.t (* lower and upper bounds *)

(* A comparison between types that is compatible with subtyping:
   equivalent types compare equal, a strict subtype compares smaller.
   Types that are not related by subtyping are ordered by the
   implementation-defined [Types.compare], which must then be non-zero. *)
let compare_type t1 t2 =
  let below = Types.subtype t1 t2 in
  let above = Types.subtype t2 t1 in
  match below, above with
  | true, true -> 0
  | true, false -> -1
  | false, true -> 1
  | false, false ->
    let c = Types.compare t1 t2 in
    assert (c <> 0);
    c


(* A line is a conjunction of constraints: it maps each constrained type
   variable to its (lower bound, upper bound) pair. *)
module Line = struct

  type t = constr Var.Map.map

  let singleton = Var.Map.singleton
  let is_empty = Var.Map.is_empty
  let length = Var.Map.length

  (* A set of constraints m1 subsumes a set of constraints m2 (the
     solutions of m1 contain all the solutions of m2) if:
     for every i1 <= v <= s1 in m1,
     there is i2 <= v <= s2 in m2 such that i1 <= i2 <= v <= s2 <= s1.
     Variables constrained only in m2 play no role: they can only make m2
     more restrictive. *)
  let subsumes map1 map2 =
    List.for_all (fun (v, (i1, s1)) ->
      try
        let i2, s2 = Var.Map.assoc v map2 in
        subtype i1 i2 && subtype s2 s1
      with Not_found -> false
    ) (Var.Map.get map1)

  let print ppf map =
    Utils.pp_list ~delim:("{","}") (fun ppf (v, (i,s)) ->
      Format.fprintf ppf "%a <= %a <= %a" Print.pp_type i Var.print v Print.pp_type s
    ) ppf (Var.Map.get map)

  (* Compare two lines with [Var.Map.compare]; the bounds of a variable are
     compared with [compare_type], lower bounds first, then upper bounds. *)
  let compare map1 map2 =
    Var.Map.compare (fun (i1,s1) (i2,s2) ->
      let c = compare_type i1 i2 in
      if c = 0 then compare_type s1 s2
      else c) map1 map2

  (* [add v (inf, sup) map] conjoins inf <= v <= sup into [map]. If [v] is
     already bounded, the lower bounds are joined with [cup] and the upper
     bounds met with [cap]. *)
  let add v (inf, sup) map =
    let new_i, new_s =
      try
        let old_i, old_s = Var.Map.assoc v map in
        cup old_i inf,
        cap old_s sup
      with
        Not_found -> inf, sup
    in
    Var.Map.replace v (new_i, new_s) map

  (* Conjunction of two lines: every constraint of [map1] is added into
     [map2]. *)
  let join map1 map2 = Var.Map.fold add map1 map2
  let fold = Var.Map.fold
  let empty = Var.Map.empty
  let for_all f m = List.for_all (fun (k,v) -> f k v) (Var.Map.get m)
end

module ConstrSet =
struct

  (* A set of constraint-sets is a disjunction, represented as a list of
     Lines that are pairwise "non-subsumable": no element subsumes another
     one. *)
  type t = Line.t list
  let elements t = t
  let empty = []

  (* [add m l] adds the line [m] to the disjunction [l] and preserves the
     invariant:
     - if [m] subsumes elements of [l], those elements are redundant: all
       of them are dropped (not only the first one found) and [m] is added;
     - if an element of [l] subsumes [m], [m] is redundant and is not
       added;
     - otherwise [m] is added.
     Without the first rule, [add sat [x; y]] would leave [y] next to
     [sat], and [is_sat] would fail to recognise the result. *)
  let add m l =
    let rec loop l acc =
      match l with
        [] -> m :: acc
      | mm :: ll ->
         if Line.subsumes m mm then
           (* No other element can subsume m here (it would subsume mm),
              but m may subsume several of them. *)
           List.rev_append
             (List.filter (fun mm' -> not (Line.subsumes m mm')) ll)
             (m :: acc)
         else if Line.subsumes mm m then List.rev_append ll (mm :: acc)
         else loop ll (mm :: acc)
    in
    loop l []

  let unsat = empty
  let sat = [ Line.empty ]

  let is_unsat m = match m with [] -> true | _ -> false
  let is_sat m = match m with
      [ l ] when Line.is_empty l -> true
    | _ -> false

  let print ppf s = Utils.pp_list ~delim:("{","}") Line.print ppf s

  (* [fold f l acc] applies [f elt acc] to the elements of [l], from left
     to right. *)
  let fold f l acc = List.fold_left (fun acc elt -> f elt acc) acc l

  let is_empty l = match l with [] -> true | _ -> false

  (* Square union (disjunction) of two sets. *)
  let union s1 s2 =
    match s1, s2 with
      [], _ -> s2
    | _, [] -> s1
    | _ ->
       (* Invariant: the elements of s1 (resp. s2) are pairwise
          incomparable. For each element e1 of s1:
          - if e1 subsumes elements of s2, add e1 to the result and remove
            every such element from s2;
          - if an element e2 of s2 subsumes e1, discard e1 and leave e2 in
            s2, so that e2 can also absorb later elements of s1;
          - otherwise add e1 to the result.
          Once s1 is exhausted, the remaining elements of s2 are added to
          the result. The result therefore has no redundant line. *)
       let append e1 s2 result =
         let rec loop s2 accs2 =
           match s2 with
             [] -> (accs2, e1 :: result)
           | e2 :: ss2 ->
              if Line.subsumes e1 e2 then
                (List.rev_append
                   (List.filter (fun e -> not (Line.subsumes e1 e)) ss2)
                   accs2,
                 e1 :: result)
              else if Line.subsumes e2 e1 then
                (List.rev_append ss2 (e2 :: accs2), result)
              else loop ss2 (e2 :: accs2)
         in
         loop s2 []
       in
       let rec loop s1 s2 result =
         match s1 with
           [] -> List.rev_append s2 result
         | e1 :: ss1 ->
            let new_s2, new_result = append e1 s2 result in
            loop ss1 new_s2 new_result
       in
       loop s1 s2 []

  (* Square intersection (conjunction) of two sets. *)
  let inter s1 s2 =
    match s1, s2 with
      [], _ | _, [] -> []
    | _ ->
       (* Cartesian product: for each line m1 of s1 and m2 of s2, the
          conjunction Line.join m1 m2 is added to the result.
          Optimisations:
          - [add] is used, so the result never keeps a line subsumed by
            another one;
          - if m1 subsumes m2, whenever m2 holds so does m1, so m2 alone is
            their conjunction (the condition is reversed w.r.t. union). *)
       fold (fun m1 acc1 ->
         fold (fun m2 acc2 ->
           let merged =
             if Line.subsumes m1 m2 then m2
             else if Line.subsumes m2 m1 then m1
             else Line.join m1 m2
           in
           add merged acc2
         ) s2 acc1
       ) s1 []

  let filter = List.filter

  let singleton v cs = [ (Line.singleton v cs) ]

end

(* Normalisation for the atom, char, integer and abstract kinds below.
   The result is [ConstrSet.sat] when the leaf conjunction of [t] (through
   [V.proj] / [V.leafconj]) is empty, and [ConstrSet.unsat] otherwise. *)
let norm_simple (type a) (module V : VarType with type Atom.t = a) t =
  let leaf = V.leafconj (V.proj t) in
  match V.Atom.is_empty leaf with
  | true -> ConstrSet.sat
  | false -> ConstrSet.unsat

(* The first two arguments (delta and the memo table, as passed by
   [toplevel]) are not needed by these kinds. *)
let norm_atoms _delta _mem t = norm_simple (module VarAtoms) t
let norm_chars _delta _mem t = norm_simple (module VarChars) t
let norm_ints _delta _mem t = norm_simple (module VarIntervals) t
let norm_abstract _delta _mem t = norm_simple (module VarAbstracts) t

(* [single (module V) b v lpos lneg] builds the intersection of the entries
   of [lpos] with the complements of the entries of [lneg]. A [`Var] entry
   becomes [var v], and an [`Atm] entry is injected with
   [V.inj (V.atom _)]. The result is complemented when [b] is true.
   NOTE(review): the variable argument is not used by the body; it is kept
   for the callers in [toplevel]. *)
let single (type a) (module V : VarType with type Atom.t = a) b _v lpos lneg =
  let aux dir l =
    List.fold_left (fun acc va ->
      cap acc (dir (
        match va with
          `Var v -> var v
        | (`Atm _) as a -> V.(inj (atom a)))))
      any l
  in
  let id = (fun x -> x) in
  let t = cap (aux id lpos) (aux neg lneg) in
  if b then neg t else t


(* [toplevel (module V) delta norm_rec mem lpos lneg] builds the constraint
   set under which the intersection of the positive atoms/variables [lpos]
   and the negated ones [lneg] is empty. This is presumably one line of
   the DNF returned by [V.get]; see its caller in [norm].
   - If the line contains a variable, one variable is isolated and bounded
     by the rest R of the line: a positive x gives x <= not R, and a
     negated x gives R <= x. Polymorphic variables (not in [delta]) are
     isolated in preference to monomorphic ones; [Var.compare] breaks
     ties.
   - A line made of a single positive atom is delegated to [norm_rec]
     (the kind-specific normalisation).
   Any other shape is unexpected and fails with [assert false]. *)
let toplevel
    (type a) (module V : VarType with type Atom.t = a)
    delta norm_rec mem lpos lneg
    =
  (* Constraint t <= x <= s. For a monomorphic x (in delta), no constraint
     can be emitted: the bounds are checked directly. *)
  let singleton x ((t, s) as cst) =
    if Var.Set.mem delta x then
      let vx = var x in
      if subtype t vx && subtype vx s then
        ConstrSet.sat
      else ConstrSet.unsat
    else ConstrSet.singleton x cst
  in
  (* Polymorphic variables first, then Var.compare. *)
  let var_compare v1 v2 =
    let mono1 = Var.Set.mem delta v1 in
    let mono2 = Var.Set.mem delta v2 in
    match mono1, mono2 with
    | false, true -> -1
    | true, false -> 1
    | _ -> Var.compare v1 v2
  in
  match lpos, lneg with
    [], (`Var x)::neg ->
      (* not x /\ R is empty iff R <= x *)
      let t = single (module V) false x [] neg in
      singleton x (t, any)

  | (`Var x)::pos, [] ->
     (* x /\ R is empty iff x <= not R *)
     let t = single (module V) true x pos [] in
     singleton x (empty, t)

  | (`Var x)::pos, (`Var y)::neg ->
     if var_compare x y < 0 then
       let t = single (module V) true x pos lneg in
       singleton x (empty, t)
     else
       let t = single (module V) false y lpos neg in
       singleton y (t, any)

  | [`Atm _ ], (`Var x)::neg ->
     let t = single (module V) false x lpos neg in
     singleton x (t, any)

  | [ (`Atm _) as a ],  [ ] -> norm_rec delta mem V.(inj (atom a))
  | _ -> assert false


(* [big_prod f acc l] intersects [acc] with [f pos neg] for every
   (pos, neg) pair of [l], from left to right. *)
let big_prod f acc l =
  List.fold_left (fun acc (pos,neg) ->
    ConstrSet.inter acc (f pos neg)
  ) acc l

(* Hash tables keyed on a (type, delta) pair, used to memoise [norm]. *)
module NormMemoHash =
  Hashtbl.Make (Custom.Pair (Descr) (Var.Set))

(* Global cache of completed normalisations. It is cleared at the start
   of [squaresubtype] and [apply_raw]. *)
let memo_norm = NormMemoHash.create 17

(* [norm delta mem t] computes a constraint set describing when [t] is
   empty:
   - [ConstrSet.sat] if [t] is already empty;
   - [ConstrSet.unsat] if [t] contains no variable, or is a variable of
     [delta] (monomorphic);
   - a single bound if [t] is a polymorphic variable or its negation;
   - otherwise, the intersection over every kind ([Iter.fold]) of the
     constraints of each line, computed by [toplevel].
   Results in the global [memo_norm] table are reused. [mem] maps
   (type, delta) to (finished, result): an unfinished entry marks a type
   being normalised higher up in the recursion, and meeting it again
   yields [ConstrSet.sat] (coinductive assumption). *)
let rec norm delta mem t =
  DEBUG normrec (
    Format.eprintf
      " @[Entering norm rec(%a):@\n" Print.pp_type t);
  let res =
    try
      (* If we find it in the global hashtable,
         we are finished *)
      let res = NormMemoHash.find memo_norm (t, delta) in
      DEBUG normrec (Format.eprintf
                       "@[ - Result found in global table @]@\n");
      res
    with
      Not_found ->
        try
          let finished, cst = NormMemoHash.find mem (t, delta) in
          DEBUG normrec (Format.eprintf
                           "@[ - Result found in local table, finished = %b @]@\n" finished);
          (* Unfinished: t is being normalised higher up, assume it holds. *)
          if finished then cst else ConstrSet.sat
        with
          Not_found ->
            begin
              let res =
                (* base cases *)
                if is_empty t then begin
                  DEBUG normrec (Format.eprintf "@[ - Empty type case @]@\n");
                  ConstrSet.sat
                end else if no_var t then begin
                  DEBUG normrec (Format.eprintf "@[ - No var case @]@\n");
                  ConstrSet.unsat
                end else if is_var t then begin
                  let (v,p) = Variable.extract t in
                  if Var.Set.mem delta v then begin
                    DEBUG normrec (Format.eprintf "@[ - Monomorphic var case @]@\n");
                    ConstrSet.unsat (* if it is monomorphic, unsat *)
                  end else begin
                    DEBUG normrec (Format.eprintf "@[ - Polymorphic var case @]@\n");
                    (* otherwise, create a single constraint according to
                       its polarity: v <= Empty when p holds, Any <= v
                       otherwise *)
                    ConstrSet.singleton v (if p then (empty, empty) else (any, any))
                  end
                end else begin (* type is not empty and is not a variable *)
                  DEBUG normrec (Format.eprintf "@[ - Inductive case:@\n");
                  (* Mark t as "in progress" before recursing. *)
                  let mem = NormMemoHash.add mem (t,delta) (false, ConstrSet.sat); mem in
                  let t = Iter.simplify t in
                  let aux (module V : VarType) t acc =
                    let res =
                      big_prod
                        (toplevel (module V) delta (norm_dispatch V.kind) mem)
                        acc
                        V.(get (proj t))
                    in
                    DEBUG normrec
                      (Format.eprintf "@[ - After %a constraints: %a @]@\n"
                         pp_type_kind V.kind ConstrSet.print res);
                    res
                  in
                  let acc = Iter.fold aux t ConstrSet.sat in
                  DEBUG normrec (Format.eprintf "@]@\n");
                  acc
                end
              in
              NormMemoHash.replace mem (t, delta) (true,res); res
            end
  in
  DEBUG normrec (Format.eprintf
                   "Leaving norm rec(%a) = %a@]@\n%!"
                   Print.pp_type t
                   ConstrSet.print res
  );
  res

(* Select the normalisation function for atoms of kind [k]. The result is
   called as [f delta mem t] by [toplevel]. *)
and norm_dispatch k =
  match k with
    `intervals -> norm_ints
  | `chars -> norm_chars
  | `atoms -> norm_atoms
  | `abstracts -> norm_abstract
  | `times -> norm_pair (module VarTimes : VarType with type Atom.t = Pair.t)
  | `xml -> norm_pair (module VarXml : VarType with type Atom.t = Pair.t)
  | `arrow -> norm_arrow
  | `record -> norm_rec
  (* (t1,t2) = intersection of all (fst pos,snd pos) \in P
   * (s1,s2) \in N
   * [t1] v [t2] v ( [t1 \ s1] ^ [t2 \ s2] ^
   * [t1 \ s1 \ s1_1] ^ [t2 \ s2 \ s2_1 ] ^
   * [t1 \ s1 \ s1_1 \ s1_2] ^ [t2 \ s2 \ s2_1 \ s2_2 ] ^ ... )
   *
   * prod(p,n) = [t1] v [t2] v prod'(t1,t2,n)
   * prod'(t1,t2,{s1,s2} v n) = ([t1\s1] v prod'(t1\s1,t2,n)) ^
   *                            ([t2\s2] v prod'(t1,t2\s2,n))
   * *)
and norm_pair (module V : VarType with type Atom.t = Pair.t) delta mem t =
  let norm_prod pos neg =
    (* prod'(t1, t2, neg); an exhausted negation list is unsat. *)
    let rec aux t1 t2 = function
      | [] -> ConstrSet.unsat
      | (s1,s2) :: rest ->
        let z1 = diff t1 (descr s1) in
        let z2 = diff t2 (descr s2) in
        let con1 = norm delta mem z1 in
        let con2 = norm delta mem z2 in
        let con10 = aux z1 t2 rest in
        let con20 = aux t1 z2 rest in
        let con11 = ConstrSet.union con1 con10 in
        let con22 = ConstrSet.union con2 con20 in
        ConstrSet.inter con11 con22
    in
    (* cap_product returns the intersection of all (fst pos, snd pos) *)
    let (t1,t2) = cap_product any any pos in
    let con1 = norm delta mem t1 in
    let con2 = norm delta mem t2 in
    let con0 = aux t1 t2 neg in
    ConstrSet.(union (union con1 con2) con0)
  in
  let t = V.leafconj (V.proj t) in
  big_prod norm_prod ConstrSet.sat (Pair.get t)

(* Records. Mirroring the product case, a positive record L (fields
   L_1..L_n) intersected with negated records Rs is empty iff
     rec(L, Rs) = [L_1] v ... v [L_n] v rec'(L, Rs)
     rec'(L, [])      = unsat
     rec'(L, R :: Rs) = /\_i ([L_i \ R_i] v rec'(L[i := L_i \ R_i], Rs))
   The previous version used a [sat] base case and intersections instead
   of unions, so a record with no negation, e.g. {a: 'a}, was reported as
   unconditionally empty. *)
and norm_rec delta mem t =
  let norm_rec_array ((oleft,left),rights) =
    let rec aux accus seen = function
      | [ ] -> ConstrSet.unsat
      (* NOTE(review): presumably the flag is the openness of the negated
         record: a closed record removes nothing from an open one, so it
         is skipped. *)
      | (false,_) :: rest when oleft -> aux accus seen rest
      | (_b,field) :: rest ->
         let right = seen @ rest in
         snd (Array.fold_left (fun (i,acc) t ->
           let di = diff accus.(i) t in
           let accus' = Array.copy accus in accus'.(i) <- di ;
           (i+1,
            ConstrSet.inter acc
              (ConstrSet.union (norm delta mem di) (aux accus' [] right)))
         ) (0,ConstrSet.sat) field
         )
    in
    (* [L_1] v ... v [L_n]: unsat is the neutral element of union. *)
    let c =
      Array.fold_left (fun acc t -> ConstrSet.union (norm delta mem t) acc)
        ConstrSet.unsat left
    in
    ConstrSet.union (aux left [] rights) c
  in
  let t = VarRec.leafconj (VarRec.proj t) in
  List.fold_left (fun acc (_,p,n) ->
    if ConstrSet.is_unsat acc then acc else ConstrSet.inter acc (norm_rec_array (p,n))
  ) ConstrSet.sat (get_record t)

  (* arrow(p,{t1 -> t2}) = [t1] v arrow'(t1,any \ t2,p)
   * arrow'(t1,acc,{s1 -> s2} v p) =
   * ([t1\s1] v arrow'(t1\s1,acc,p)) ^
   * ([acc ^ {s2}] v arrow'(t1,acc ^ {s2},p))

   * t1 -> t2 \ s1 -> s2 =
   * [t1] v (([t1\s1] v {[]}) ^ ([t2\s2] v {[]}))

   * Bool -> Bool \ Int -> A =
   * [Bool] v (([Bool\Int] v {[]}) ^ ([Bool\A] v {[]})

   * P(Q v {a}) = {a} v P(Q) v {X v {a} | X \in P(Q) }
   *)
and norm_arrow delta mem t =
  (* Union over the negated arrows t1 -> t2 of [neg]; no negated arrow
     means the line cannot be empty (unsat). *)
  let rec norm_neg pos neg =
    match neg with
    | [] -> ConstrSet.unsat
    | (t1,t2) :: n ->
       let t1 = descr t1 and t2 = descr t2 in
       let con1 = norm delta mem t1 in (* [t1] *)
       let con2 = aux t1 (diff any t2) pos in
       let con0 = norm_neg pos n in
       ConstrSet.union (ConstrSet.union con1 con2) con0
  (* arrow'(t1, acc, p) *)
  and aux t1 acc = function
    | [] -> ConstrSet.unsat
    | (s1,s2) :: p ->
       let t1s1 = diff t1 (descr s1) in
       let acc1 = cap acc (descr s2) in
       let con1 = norm delta mem t1s1 in (* [t1 \ s1] *)
       let con2 = norm delta mem acc1 in (* [(Any \ t2) ^ s2] *)
       let con10 = aux t1s1 acc p in
       let con20 = aux t1 acc1 p in
       let con11 = ConstrSet.union con1 con10 in
       let con22 = ConstrSet.union con2 con20 in
       ConstrSet.inter con11 con22
  in
  let t = VarArrow.leafconj (VarArrow.proj t) in
  big_prod norm_neg ConstrSet.sat (Pair.get t)




(* Entry point of the normalisation: look [t] up in the global [memo_norm]
   table, or else normalise it with a fresh local table and record the
   result globally. *)
let norm delta t =
  try NormMemoHash.find memo_norm (t,delta)
  with Not_found ->
    let res = norm delta (NormMemoHash.create 17) t in
    NormMemoHash.add memo_norm (t,delta) res;
    res

  (* merge needs delta because it calls norm recursively.
     [merge delta cache m] processes the line [m]: for each constraint
     inf <= v <= sup where inf is not a subtype of sup, the constraint set
     [norm delta (inf \ sup)] is intersected with [m], and every resulting
     line is merged recursively. Types inf \ sup already recorded in
     [cache] (via [Cache.lookup]/[Cache.find]) are skipped. If nothing new
     is produced, [m] alone is returned.
     @raise UnSatConstr if some inf \ sup cannot be made empty. *)
let rec merge delta cache m =
  let res =
    Line.fold (fun _v (inf, sup) acc ->
      (* no need to add new constraints *)
      if subtype inf sup then acc
      else
        let x = diff inf sup in
        if Cache.lookup x cache != None then acc
        else
          let cache, _ = Cache.find ignore x cache in
          let n = norm delta x in
          if ConstrSet.is_unsat n then raise (UnSatConstr "merge2");
          let c1 = ConstrSet.inter (ConstrSet.(add m empty)) n in
          let c2 =
            ConstrSet.fold
              (fun m1 acc -> ConstrSet.union acc (merge delta cache m1))
              c1 ConstrSet.empty
          in
          ConstrSet.union c2 acc
    ) m ConstrSet.empty
  in
  if ConstrSet.is_unsat res then ConstrSet.(add m empty)
  else res

(* Merge the line [m], starting from the empty cache [Cache.emp]. *)
let merge delta m = merge delta Cache.emp m

(* [solve delta s] turns each line of the constraint set [s] into an
   equation system. A constraint s <= alpha <= t becomes
   alpha = (s | beta) & t, where beta is a new internal variable named
   "#alpha". A variable not in [delta] that occurs alone as a bound and has
   no equation of its own gets the trivial equation v = (Empty | beta) & Any.
   One system is produced per line. *)
let solve delta s =
  let add_eq alpha s t acc =
    let beta =
      var Var.(mk ~internal:true ("#" ^ (ident alpha)))
    in
    Var.Map.replace alpha (cap (cup s beta) t) acc
  in
  let has_eq v acc =
    try ignore (Var.Map.assoc v acc); true with Not_found -> false
  in
  let extra_var t acc =
    if is_var t then
      let v, _ = Variable.extract t in
      (* Never overwrite an equation already produced for v. If v has its
         own constraint in the line and was visited first, replacing it by
         the trivial equation would silently drop that constraint. *)
      if Var.Set.mem delta v || has_eq v acc then acc
      else add_eq v empty any acc
    else acc
  in
  let to_eq_set m =
    Line.fold (fun alpha (s,t) acc ->
      let acc = extra_var t acc in
      let acc = extra_var s acc in
      add_eq alpha s t acc
    ) m Var.Map.empty
  in
  ConstrSet.fold (fun m acc -> (to_eq_set m) :: acc) s []

(* [solve] with debug tracing of its arguments and result. *)
let solve delta s =
  let res = solve delta s in
  DEBUG solve (Format.eprintf "Calling solve (%a, %a), got %a@\n%!"
                 Var.Set.print delta ConstrSet.print s
                 pp_sl res);
  res
let () = Format.pp_set_margin Format.err_formatter 200

(* [unify eq_set] solves the equation system [eq_set]. The binding returned
   by [Var.Map.remove_min] is solved as a recursive type with
   [Subst.solve_rectype], and its solution is substituted into the
   remaining equations, which are solved recursively. Finally, the
   resulting substitution is applied back to the solution with
   [Subst.full]. *)
let unify (eq_set : Descr.t Var.Map.map) =
  let rec loop eq_set accu =
    DEBUG unify
      (Format.eprintf "@[<hov 2>  Entering unify_loop @[(%a, %a)@]@\n"
         Subst.print eq_set Subst.print accu);
    let res =
      if Var.Map.is_empty eq_set then accu else
        begin
          let (alpha, t), eq_set' = Var.Map.remove_min eq_set in
          DEBUG unify (Format.eprintf
                         "@[<hov 2>- Choosing equation @[%a = %a@]@]@\n"
                         Var.print alpha Types.Print.pp_type t);
          let x = Subst.solve_rectype t alpha in
          DEBUG unify (Format.eprintf
                         "@[<hov 2>- Solving equation @[%a = %a@]@]@\n"
                         Var.print alpha Types.Print.pp_type x);
          let eq_set' =
            Var.Map.map (fun s -> Subst.single s (alpha,x)) eq_set'
          in
          let sigma = loop eq_set' (Var.Map.replace alpha x accu) in
          let t_alpha = Subst.full x sigma in
          DEBUG unify (Format.eprintf
                         "@[<hov 2>- Applying substitution @[%a@] to @[%a@], @[<hov 4>got %a@]@]@\n"
                         Subst.print sigma Print.pp_type x Print.pp_type t_alpha);
          Var.Map.replace alpha t_alpha sigma
        end
    in
    DEBUG unify (Format.eprintf "Leaving unify_loop @[(%a, %a)@], @[<hov 4>got %a@]@]@\n"
                   Subst.print eq_set Subst.print accu Subst.print res);
    res
  in
  let res = loop eq_set Var.Map.empty in
  DEBUG unify (Format.(pp_print_flush err_formatter ()));
  res

exception Step1Fail
exception Step2Fail

(* [tallying delta l] computes substitutions, for the variables not in
   [delta], under which s <= t holds for every pair (s, t) of [l].
   Step 1 intersects the normalised constraint sets of every s \ t.
   Step 2 applies [merge] to each resulting line, turns it into equation
   systems with [solve], and solves each system with [unify].
   @raise Step1Fail if the constraints of step 1 are unsatisfiable.
   @raise Step2Fail if no line survives step 2. *)
let tallying delta l =
  let n =
    List.fold_left (fun acc (s,t) ->
      if ConstrSet.is_unsat acc then acc
      else
        let d = diff s t in
        (* An already empty difference adds no constraint: keep [acc].
           Returning [sat] here would discard the constraints gathered for
           the previous pairs, including an earlier [unsat]. *)
        if is_empty d then acc
        else if no_var d then ConstrSet.unsat
        else ConstrSet.inter acc (norm delta d)
    ) ConstrSet.sat l
  in
  if ConstrSet.is_unsat n then raise Step1Fail else
    let m =
      ConstrSet.fold (fun c acc ->
        try
          (solve delta (merge delta c)) @ acc
        with UnSatConstr _ -> acc
      ) n []
    in
    match m with
    | [] -> raise Step2Fail
    | _ -> List.map unify m


(* Symbolic substitutions: the identity [I], a concrete substitution [S],
   or a pair [A (s1, s2)] of symbolic substitutions, interpreted by
   [(@@)]. *)
type symsubst = I
              | S of Descr.t Var.Map.map
              | A of (symsubst * symsubst)

(* Set of the variables bound by a symbolic substitution. *)
let rec dom = function
  | I -> Var.Set.empty
  | S si -> Var.Map.domain si
  | A (si,sj) -> Var.Set.cup (dom si) (dom sj)
(* Composition of two sets of symbolic substitutions sigmaI, sigmaJ: their
   cartesian product, pairing every si of [sI] with every sj of [sJ] as
   [A (si, sj)], ordered by [sI] first and then [sJ]. *)
let (++) sI sJ =
  List.concat
    (List.map (fun si -> List.map (fun sj -> A (si, sj)) sJ) sI)
(* [t >> s] applies the substitution [s] to [t], one binding at a time,
   with [Subst.single]. *)
let (>>) t s =
  Var.Map.fold (fun v ty acc -> Subst.single acc (v, ty)) s t

(* apply a symbolic substitution si to a type t.
   For [A (S si, rest)]: if [t] mentions a variable bound by [si], only
   [si] is applied; otherwise [rest] is applied first, then [si].
   NOTE(review): a former case applying sj then si when t mentions a
   variable of dom(si) \ dom(sj) was only tried after t was found to share
   no variable with dom(si), so it could never fire. It has been removed,
   together with the refs that coupled the two guards. *)
let (@@) t si =
  let mentions_dom t si =
    not (Var.Set.is_empty (Var.Set.cap (all_vars t) (dom (S si))))
  in
  let rec aux t = function
    | I -> t
    | S si -> t >> si
    | A (S si, _) when mentions_dom t si -> t >> si
    | A (si, sj) -> aux (aux t sj) si
  in
  aux t si

(* Union of the domains of a list of substitutions. *)
let domain sl =
  List.fold_left (fun acc si -> Var.Set.cup acc (Var.Map.domain si))
    Var.Set.empty sl

(* Set of all the variables occurring in the types bound by a list of
   substitutions. *)
let codomain ll =
  List.fold_left (fun acc e ->
    Var.Map.fold (fun _ v acc -> Var.Set.cup (all_vars v) acc) e acc
  ) Var.Set.empty ll
(* A list of substitutions is the identity when every substitution is empty
   (in particular, the empty list is). Eta-expanded so that the function
   stays polymorphic instead of getting a weak type variable. *)
let is_identity sl = List.for_all Var.Map.is_empty sl
let identity = [Var.Map.empty]

(* Restrict every substitution of [sl] to the variables satisfying [f], and
   drop substitutions that become empty (the result is in reverse order).
   An identity list is returned unchanged. *)
let filter f sl =
  if is_identity sl then sl
  else
    List.fold_left (fun acc si ->
      let e = Var.Map.filter (fun v _ -> f v) si in
      if Var.Map.is_empty e then acc else e :: acc
    ) [] sl



(* [set a i v] stores [v] at index [i] of the growable array [!a]. When [i]
   is out of bounds, the array is reallocated and padded with [empty]. *)
let set a i v =
  let len = Array.length !a in
  if i < len then (!a).(i) <- v
  else begin
    (* Grow geometrically, but always enough to hold index i; 2*len+1 alone
       is too small when i jumps past it. *)
    let b = Array.make (max (i + 1) (2 * len + 1)) empty in
    Array.blit !a 0 b 0 len;
    b.(i) <- v;
    a := b
  end

(* [get a i] reads index [i] of [!a]; a negative index yields [any]. *)
let get a i = if i < 0 then any else (!a).(i)

exception FoundSquareSub of Descr.t Var.Map.map list

(* [squaresubtype delta s t] looks for substitutions making an expansion of
   [s] a subtype of [t]. At step i, cell i of [ai] holds [s] intersected
   with i copies of itself renamed by [Subst.freshen], and tallying against
   [t] is attempted. The search goes on while tallying fails in step 2,
   so it does not terminate if no expansion succeeds.
   @raise UnSatConstr if the first attempt fails in step 1. *)
let squaresubtype delta s t =
  NormMemoHash.clear memo_norm;
  let ai = ref [| |] in
  let tallying i =
    try
      let s = get ai i in
      let sl = tallying delta [ (s,t) ] in
      raise (FoundSquareSub sl)
    with
      Step1Fail -> (assert (i == 0); raise (UnSatConstr "apply_raw step1"))
    | Step2Fail -> () (* continue *)
  in
  let rec loop i =
    try
      let ss =
        if i = 0 then s
        else (cap (Subst.freshen delta s) (get ai (i-1)))
      in
      set ai i ss;
      tallying i;
      loop (i+1)
    with FoundSquareSub sl -> sl
  in
  loop 0

(* [is_squaresubtype delta s t] is [true] when [squaresubtype] returns a
   solution and [false] when it raises [UnSatConstr]. *)
let is_squaresubtype delta s t =
  let found =
    try
      ignore (squaresubtype delta s t);
      true
    with UnSatConstr _ -> false
  in
  found

exception FoundApply of t * int * int * Descr.t Var.Map.map list

(** find two sets of type substitutions I,J such that
    s @@ sigma_i < t @@ sigma_j for all i \in I, j \in J.

    Cell i of [ai] (resp. [aj]) holds [s] (resp. [t]) intersected with i
    copies renamed by [Subst.freshen]. At expansion level i, the pairs
    (j, i) and (i, j) for j < i, then (i, i), are tried by tallying
    [ai.(i')] <= [aj.(j')] -> Gamma for a new variable Gamma. The first
    success returns (substitutions, expanded s, expanded t, Gamma
    instantiated by every substitution and cleaned).
    @raise UnSatConstr if the very first attempt fails in step 1. *)
let apply_raw delta s t =
  DEBUG apply_raw (Format.eprintf " @[Entering apply_raw (delta:@[%a@], @[%a@], @[%a@])@\n%!"
                     Var.Set.print delta
                     Print.pp_type s
                     Print.pp_type t
  );

  NormMemoHash.clear memo_norm;
  let vgamma = Var.mk "Gamma" in
  let gamma = var vgamma in
  let cgamma = cons gamma in
  (* cell i of ai contains /\k<=i s_k, cell j of aj contains /\k<=j t_k *)
  let ai = ref [| |]
  and aj = ref [| |] in
  let tallying i j  =
    try
      let s = get ai i in
      let t = arrow (cons (get aj j)) cgamma in
      let sl = tallying delta [ (s,t) ] in
      let new_res =
        Subst.clean_type delta (
          List.fold_left (fun tacc si ->
            cap tacc (Subst.full gamma  si)
          ) any sl
        )
      in
      raise (FoundApply(new_res,i,j,sl))
    with
      Step1Fail -> (assert (i == 0 &&  j == 0); raise (UnSatConstr "apply_raw step1"))
    | Step2Fail -> () (* continue *)
  in
  let rec loop i =
    try
      (* Format.printf "Starting expansion %i @\n@." i; *)
      let (ss,tt) =
        if i = 0 then (s,t) else
          ((cap (Subst.freshen delta s) (get ai (i-1))),
           (cap (Subst.freshen delta t) (get aj (i-1))))
      in
      set ai i ss;
      set aj i tt;
      for j = 0 to i-1 do
        tallying j i;
        tallying i j;
      done;
      tallying i i;
      loop (i+1)
    with FoundApply (res, i, j, sl) ->
      DEBUG apply_raw (Format.eprintf " Leaving apply_raw (delta:@[%a@], @[%a@], @[%a@]) = @[%a@], @[%a@] @]@\n%!"
                         Var.Set.print delta
                         Print.pp_type s
                         Print.pp_type t
                         Print.pp_type res
                         pp_sl sl
      );
      (sl, get ai i, get aj j, res)
  in
  loop 0

(* Result type of applying a function of type [s] to an argument of type
   [t], as computed by [apply_raw]. *)
let apply_full delta s t =
  match apply_raw delta s t with
  | (_, _, _, result) -> result

(* Substitutions and result type computed by [apply_raw] for applying a
   function of type [s] to an argument of type [t]. *)
let squareapply delta s t =
  let (substs, _, _, result) = apply_raw delta s t in
  (substs, result)