(* type_tallying.ml *)
open Types

(* [cap_t d t] intersects [d] with [descr t]. *)
let cap_t d t =
  let td = descr t in
  cap d td

(* [cap_product any_left any_right l] walks the list of pairs [l] and
   intersects (with [cap_t]) every first component into [any_left] and
   every second component into [any_right].  The pair of accumulated
   intersections is returned. *)
let cap_product any_left any_right l =
  let meet (left, right) (t1, t2) = (cap_t left t1, cap_t right t2) in
  List.fold_left meet (any_left, any_right) l

(* Raised when a constraint set turns out to have no solution; the string
   payload names the step that failed. *)
exception UnSatConstr of string

(* [pp_sl ppf l] prints the list [l] on [ppf] between braces, using
   [Subst.print] for each element. *)
let pp_sl ppf l =
  Utils.pp_list ~delim:("{","}") Subst.print ppf l
(* A constraint attached to a variable: (lower bound, upper bound). *)
type constr = Types.t * Types.t


    (* A comparison function between types that
       is compatible with subtyping. If types are
       not in a subtyping relation, use implementation
       defined order
    *)

let compare_type t1 t2 =
  let below = Types.subtype t1 t2 in
  let above = Types.subtype t2 t1 in
  match below, above with
  | true, true -> 0      (* each is a subtype of the other *)
  | true, false -> -1    (* only t1 <= t2 holds *)
  | false, true -> 1     (* only t2 <= t1 holds *)
  | false, false ->
      (* unrelated by subtyping: fall back on [Types.compare], which is
         asserted to tell the two types apart *)
      let c = Types.compare t1 t2 in
      assert (c <> 0);
      c


(* A line is a conjunction of constraints: a map binding each constrained
   variable to its pair (lower bound, upper bound). *)
module Line = struct

  type t = constr Var.Map.map

  let singleton = Var.Map.singleton
  let is_empty = Var.Map.is_empty
  let length = Var.Map.length

  (* a set of constraints m1 subsumes a set of constraints m2,
     that is the solutions for m1 contains all the solutions for
     m2 if:
     forall i1 <= v <= s1 in m1,
     there exists i2 <= v <= s2 in m2 such that i1 <= i2 <= v <= s2 <= s1

     A variable of map1 that is not bound in map2 makes the test fail. *)
  let subsumes map1 map2 =
    List.for_all (fun (v, (i1, s1)) ->
      try
        let i2, s2 = Var.Map.assoc v map2 in
        subtype i1 i2 && subtype s2 s1
      with Not_found -> false
    ) (Var.Map.get map1)

  (* Prints every binding as "lower <= variable <= upper", between braces. *)
  let print ppf map =
    Utils.pp_list ~delim:("{","}") (fun ppf (v, (i,s)) ->
      Format.fprintf ppf "%a <= %a <= %a" Print.pp_type i Var.print v Print.pp_type s
    ) ppf (Var.Map.get map)

  (* Compares two lines with [Var.Map.compare]; bound pairs are ordered by
     lower bound first, then by upper bound, both through [compare_type]. *)
  let compare map1 map2 =
    Var.Map.compare (fun (i1,s1) (i2,s2) ->
      let c = compare_type i1 i2 in
      if c = 0 then compare_type s1 s2
      else c) map1 map2

  (* [add v (inf, sup) map] records the constraint inf <= v <= sup.
     When [v] is already bound, the bounds are tightened: the new lower
     bound is the union of both lower bounds and the new upper bound the
     intersection of both upper bounds. *)
  let add v (inf, sup) map =
    let new_i, new_s =
      try
        let old_i, old_s = Var.Map.assoc v map in
        cup old_i inf,
        cap old_s sup
      with
        Not_found -> inf, sup
    in
    Var.Map.replace v (new_i, new_s) map

  (* [join map1 map2] adds every binding of [map1] into [map2] with [add]. *)
  let join map1 map2 = Var.Map.fold add map1 map2
  let fold = Var.Map.fold
  let empty = Var.Map.empty
  let for_all f m = List.for_all (fun (k,v) -> f k v) (Var.Map.get m)
end

(* Sets of lines, read as a disjunction of conjunctions of constraints. *)
module ConstrSet =
struct

  (* A set of constraint-sets is just a list of Lines,
     that are pairwise "non-subsumable" *)
  type t = Line.t list
  let elements t = t
  let empty = []

  (* [add m l] inserts the line [m] into [l].  The list is scanned until
     an element [mm] comparable with [m] is met: if [m] subsumes [mm],
     [m] takes its place; if [mm] subsumes [m], [m] is dropped.  When no
     element is comparable, [m] is simply added.  The order of the
     elements of the result is not that of [l]. *)
  let add m l =
    let rec loop m l acc =
      match l with
        [] -> m :: acc
      | mm :: ll ->
         if Line.subsumes m mm then List.rev_append ll (m::acc)
         else if Line.subsumes mm m then List.rev_append ll (mm::acc)
         else loop m ll (mm::acc)
    in
    loop m l []

  (* No line at all: nothing satisfies the set. *)
  let unsat = empty
  (* A single line without any constraint. *)
  let sat = [ Line.empty ]

  let is_unsat m = match m with
      [] -> true
    | _ :: _ -> false

  let is_sat m = match m with
      [ l ] when Line.is_empty l -> true
    | _ -> false

  let print ppf s = Utils.pp_list ~delim:("{","}") Line.print ppf s

  (* [fold f l a] is a left fold over the lines of [l]; [f] receives the
     current line first and the accumulator second. *)
  let fold f l a = List.fold_left (fun acc line -> f line acc) a l

  let is_empty l = match l with
      [] -> true
    | _ :: _ -> false

  (* Square union : *)
  let union s1 s2 =
    match s1, s2 with
      [], _ -> s2
    | _, [] -> s1
    | _ ->
       (* Invariant: all elements of s1 (resp s2) are pairwise
          incomparable (none of them subsumes another one).
          let e1 be an element of s1:
          - if e1 subsumes no element of s2, add e1 to the result
          - if e1 subsumes an element e2 of s2, add e1 to the
          result and remove e2 from s2
          - if an element e2 of s2 subsumes e1, add e2 to the
          result and remove e2 from s2 (and discard e1)

          once we are done for all e1, add the remaining elements from
          s2 to the result.
       *)
       let append e1 s2 result =
         let rec loop s2 accs2 =
           match s2 with
             [] -> accs2, e1::result
           | e2 :: ss2 ->
              if Line.subsumes e1 e2 then List.rev_append ss2 accs2, e1::result
              else if Line.subsumes e2 e1 then List.rev_append ss2 accs2, e2::result
              else loop ss2 (e2::accs2)
         in
         loop s2 []
       in
       let rec loop s1 s2 result =
         match s1 with
           [] -> List.rev_append s2 result
         | e1 :: ss1 ->
            let new_s2, new_result = append e1 s2 result in
            loop ss1 new_s2 new_result
       in
       loop s1 s2 []

  (* Square intersection *)
  let inter s1 s2 =
    match s1,s2 with
      [], _ | _, [] -> []
    | _ ->
       (* Perform the cartesian product. For each constraint m1 in s1,
          m2 in s2, we add Line.join m1 m2 to the result.
          Optimisations:
          - we use add to ensure that we do not add something that subsumes
          a constraint set that is already in the result
          - if m1 subsumes m2, it means that whenever m2 holds, so does m1, so
          we only add m2 (note that the condition is reversed w.r.t. union).
       *)
       fold (fun m1 acc1 ->
         fold (fun m2 acc2 ->
           let merged = if Line.subsumes m1 m2 then m2
             else if Line.subsumes m2 m1 then m1
             else Line.join m1 m2
           in
           add merged acc2
         )
           s2 acc1) s1 []

  let filter = List.filter

  (* [singleton c] builds the one-line set for a single constraint:
     `pos (v, s) stands for empty <= v <= s and
     `neg (s, v) stands for s <= v <= any. *)
  let singleton c =
    let cstr = match c with
        `pos (v, s) -> Line.singleton v (Types.empty, s)
      | `neg (s, v) -> Line.singleton v (s, Types.any)
    in
    [ cstr ]
end


(* Normalisation of the basic components.  Each function ignores its first
   two arguments and answers [ConstrSet.sat] when the component [t] is
   empty, [ConstrSet.unsat] otherwise. *)
let normatoms _ _ t = if Atoms.is_empty t then ConstrSet.sat else ConstrSet.unsat
let normchars _ _ t = if Chars.is_empty t then ConstrSet.sat else ConstrSet.unsat
let normints _ _ t = if Intervals.is_empty t then ConstrSet.sat else ConstrSet.unsat
let normabstract _ _ t = if Abstracts.is_empty t then ConstrSet.sat else ConstrSet.unsat

(* [single (module V) b v lpos lneg] rebuilds a type from a list of
   positive items [lpos] and a list of negative items [lneg].  An item is
   either `Var x (turned into [var x]) or an `Atm atom (injected through
   [V.inj (V.atom _)]).  Positive items are intersected as they are,
   negative items are negated first; the whole intersection is negated
   when [b] is true.  The argument [v] is not used by the body. *)
let single (type a) (module V : VarType with type Atom.t = a) b v lpos lneg =
  let aux dir l =
    List.fold_left (fun acc va ->
      cap acc (dir (
        match va with
          `Var v -> var v
        | (`Atm _) as a -> V.(inj (atom a)))))
      any l
  in
  let id = (fun x -> x) in
  let t = cap (aux id lpos) (aux neg lneg) in
  if b then neg t else t

(* check if there exists a toplevel variable : fun (pos,neg)

   [toplevel (module V) delta norm_rec mem lpos lneg] handles one pair of
   positive / negative item lists:
   - when a `Var heads one of the lists, that variable is isolated and a
     single constraint on it is produced, its bound being rebuilt from the
     remaining items with [single];
   - when both lists start with a variable, the smaller one according to
     [_compare] is isolated;
   - when [lpos] is a single `Atm t and [lneg] is empty, the work is
     delegated to [norm_rec delta mem t];
   - any other shape is asserted to be impossible. *)
let toplevel (type a) (module V : VarType with type Atom.t = a)
    delta norm_rec mem lpos lneg
    =
  (* Variables that belong to [delta] compare greater than those that do
     not; inside each group [Var.compare] decides. *)
  let _compare delta v1 v2 =
    let monov1 = Var.Set.mem delta v1 in
    let monov2 = Var.Set.mem delta v2 in
    if monov1 == monov2 then
      Var.compare v1 v2
    else
      if monov1 then 1 else -1
  in
  (* A constraint on a variable of [delta] is unsatisfiable unless the
     bound already holds for the variable itself. *)
  let singleton c =
    match c with
      `neg (t, x) when Var.Set.mem delta x
          && not(subtype t (var x)) -> ConstrSet.unsat
    | `pos (x, t) when Var.Set.mem delta x
        && not(subtype (var x) t) -> ConstrSet.unsat

    | _ -> ConstrSet.singleton c
  in
  match lpos, lneg with
    [], (`Var x)::neg ->
      let t = single (module V) false x [] neg in
      singleton (`neg (t, x))

  | (`Var x)::pos,[] ->
     let t = single (module V) true x pos [] in
     singleton (`pos (x,t))

  | (`Var x)::pos, (`Var y)::neg ->
     if _compare delta x y < 0 then
       let t = single (module V) true x pos lneg in
       singleton (`pos (x,t))
     else
       let t = single (module V) false y lpos neg in
       singleton (`neg (t, y))

  | [`Atm _ ], (`Var x)::neg ->
     let t = single (module V) false x lpos neg in
     singleton (`neg (t, x))

  | [ `Atm t ],  [ ] -> norm_rec delta mem t
  | __ -> assert false

(* [big_prod f acc l] intersects [acc] with [f pos neg] for every pair
   (pos, neg) of [l], from left to right. *)
let big_prod f acc l =
  List.fold_left (fun acc (pos,neg) ->
    ConstrSet.inter acc (f pos neg)
  ) acc l

(* Hash tables whose keys are built by Custom.Pair(Descr)(Var.Set); in this
   file the keys are always written (type, delta). *)
module NormMemoHash = Hashtbl.Make(Custom.Pair(Descr)(Var.Set))

(* Global memo table of normalisation results, indexed by (type, delta). *)
let memo_norm = NormMemoHash.create 17

(* [norm delta mem t] computes the constraint set attached to the type [t].

   - [delta] is the set of variables treated as monomorphic;
   - [mem] is a local table mapping (type, delta) to (finished, result).

   Lookup order: the global table [memo_norm] first, then [mem]; an entry
   of [mem] that is not finished yet yields [ConstrSet.sat].  Otherwise:
   - an empty type gives [ConstrSet.sat];
   - a non empty type without variables gives [ConstrSet.unsat];
   - a single variable gives [ConstrSet.unsat] when it belongs to [delta],
     and a single constraint chosen by its polarity otherwise;
   - any other type is registered in [mem] as unfinished, simplified, and
     the constraint sets of all its components are intersected.
   The result is stored in [mem] as finished before being returned. *)
let rec norm delta mem t =
  DEBUG normrec (
    Format.eprintf
      " @[Entering norm rec(%a):@\n" Print.pp_type t);
  let res =
    try
      (* If we find it in the global hashtable,
         we are finished *)
      let res = NormMemoHash.find memo_norm (t, delta) in
      DEBUG normrec (Format.eprintf
                       "@[ - Result found in global table @]@\n");
      res
    with
      Not_found ->
        try
          let finished, cst = NormMemoHash.find mem (t, delta) in
          DEBUG normrec (Format.eprintf
                           "@[ - Result found in local table, finished = %b @]@\n" finished);
          if finished then cst else ConstrSet.sat
        with
          Not_found ->
            begin
              let res =
                  (* base cases *)
                if is_empty t then begin
                  DEBUG normrec (Format.eprintf "@[ - Empty type case @]@\n");
                  ConstrSet.sat
                end else if no_var t then begin
                  DEBUG normrec (Format.eprintf "@[ - No var case @]@\n");
                  ConstrSet.unsat
                end else if is_var t then begin
                  let (v,p) = Variable.extract t in
                  if Var.Set.mem delta v then begin
                    DEBUG normrec (Format.eprintf "@[ - Monomorphic var case @]@\n");
                    ConstrSet.unsat (* if it is monomorphic, unsat *)
                  end else begin
                    DEBUG normrec (Format.eprintf "@[ - Polymorphic var case @]@\n");
                    (* otherwise, create a single constraint according to its polarity *)
                    let s = if p then (`pos (v,empty)) else (`neg (any,v)) in
                    ConstrSet.singleton s
                  end
                end else begin (* type is not empty and is not a variable *)
                  DEBUG normrec (Format.eprintf "@[ - Inductive case:@\n");
                  (* register [t] as being computed before descending into it *)
                  let mem = NormMemoHash.add mem (t,delta) (false, ConstrSet.sat); mem in
                  let t = Iter.simplify t in
                  let aux (type a) (module V : VarType with type Atom.t = a)
                      norm_constr acc t
                      =
                    big_prod (toplevel
                                (module V)
                                delta norm_constr mem)
                      acc
                      V.(get (proj t))
                  in
                  let acc = ConstrSet.sat in
                  let acc = aux (module VarAtoms) normatoms acc t in
                  DEBUG normrec (Format.eprintf "@[ - After Atoms constraints: %a @]@\n" ConstrSet.print acc);
                  let acc = aux (module VarIntervals) normints acc t in
                  DEBUG normrec (Format.eprintf "@[ - After Ints constraints: %a @]@\n" ConstrSet.print acc);
                  let acc = aux (module VarChars) normchars acc t in
                  DEBUG normrec (Format.eprintf "@[ - After Chars constraints: %a @]@\n" ConstrSet.print acc);
                  let acc = aux (module VarTimes) normpair acc t in
                  DEBUG normrec (Format.eprintf "@[ - After Times constraints: %a @]@\n" ConstrSet.print acc);
                  let acc = aux (module VarXml) normpair acc t in
                  DEBUG normrec (Format.eprintf "@[ - After Xml constraints: %a @]@\n" ConstrSet.print acc);
                  let acc = aux (module VarArrow) normarrow acc t in
                  DEBUG normrec (Format.eprintf "@[ - After Arrow constraints: %a @]@\n" ConstrSet.print acc);
                  let acc = aux (module VarRec) normrec acc t in
                  DEBUG normrec (Format.eprintf "@[ - After Record constraints: %a @]@\n" ConstrSet.print acc);
                  let acc = aux (module VarAbstracts) normabstract acc t in
                  DEBUG normrec (Format.eprintf "@[ - After Abstract constraints: %a @]@\n" ConstrSet.print acc);
                  DEBUG normrec (Format.eprintf "@]@\n");
                  acc

                end
              in
              (* [t] is the original type here, not the simplified one *)
              NormMemoHash.replace mem (t, delta) (true,res); res
            end
  in
  DEBUG normrec (Format.eprintf
                   "Leaving norm rec(%a) = %a@]@\n%!"
                   Print.pp_type t
                   ConstrSet.print res
  );
  res

  (* (t1,t2) = intersection of all (fst pos,snd pos) \in P
   * (s1,s2) \in N
   * [t1] v [t2] v ( [t1 \ s1] ^ [t2 \ s2] ^
   * [t1 \ s1 \ s1_1] ^ [t2 \ s2 \ s2_1 ] ^
   * [t1 \ s1 \ s1_1 \ s1_2] ^ [t2 \ s2 \ s2_1 \ s2_2 ] ^ ... )
   *
   * prod(p,n) = [t1] v [t2] v prod'(t1,t2,n)
   * prod'(t1,t2,{s1,s2} v n) = ([t1\s1] v prod'(t1\s1,t2,n)) ^
   *                            ([t2\s2] v prod'(t1,t2\s2,n))
   * *)
(* [normpair delta mem t] normalises a product component: every
   (pos, neg) pair of [Pair.get t] is handled by [norm_prod] and the
   results are intersected, starting from [ConstrSet.sat]. *)
and normpair delta mem t =
  let norm_prod pos neg =
    (* [aux t1 t2 n] walks the negative pairs [n], removing each of them
       either from the left or from the right component *)
    let rec aux t1 t2 = function
      |[] -> ConstrSet.unsat
      |(s1,s2) :: rest -> begin
        let z1 = diff t1 (descr s1) in
        let z2 = diff t2 (descr s2) in
        let con1 = norm delta mem z1 in
        let con2 = norm delta mem z2 in
        let con10 = aux z1 t2 rest in
        let con20 = aux t1 z2 rest in
        let con11 = ConstrSet.union con1 con10 in
        let con22 = ConstrSet.union con2 con20 in
        ConstrSet.inter con11 con22
      end
    in
      (* cap_product return the intersection of all (fst pos,snd pos) *)
    let (t1,t2) = cap_product any any pos in
    let con1 = norm delta mem t1 in
    let con2 = norm delta mem t2 in
    let con0 = aux t1 t2 neg in
    ConstrSet.(union (union con1 con2) con0)
  in
  big_prod norm_prod ConstrSet.sat (Pair.get t)

(* [normrec delta mem t] normalises a record component: each element
   (_, p, n) of [get_record t] is handled by [norm_rec] and the results
   are intersected, stopping the accumulation as soon as it is unsat. *)
and normrec delta mem t =
  let norm_rec ((oleft,left),rights) =
    (* [accus] holds the current type of each field; every negative
       record removes its field types from a copy of [accus], one field
       at a time.  Negative entries whose flag is false are skipped when
       [oleft] holds. *)
    let rec aux accus seen = function
      | [ ] -> ConstrSet.sat
      |(false,_) :: rest when oleft -> aux accus seen rest
      |(b,field) :: rest ->
         let right = seen @ rest in
         snd (Array.fold_left (fun (i,acc) t ->
           let di = diff accus.(i) t in
           let accus' = Array.copy accus in accus'.(i) <- di ;
           (i+1,ConstrSet.inter acc (ConstrSet.inter (norm delta mem di) (aux accus' [] right)))
         ) (0,ConstrSet.sat) field
         )
    in
    (* NOTE(review): this union starts from [ConstrSet.sat], whose only
       line is empty and therefore subsumes every other line; [c] looks
       like it can never restrict the result.  Confirm whether
       [ConstrSet.unsat] was intended as the initial value. *)
    let c = Array.fold_left (fun acc t -> ConstrSet.union (norm delta mem t) acc) ConstrSet.sat left in
    ConstrSet.inter (aux left [] rights) c
  in
  List.fold_left (fun acc (_,p,n) ->
    if ConstrSet.is_unsat acc then acc else ConstrSet.inter acc (norm_rec (p,n))
  ) ConstrSet.sat (get_record t)

  (* arrow(p,{t1 -> t2}) = [t1] v arrow'(t1,any \ t2,p)
   * arrow'(t1,acc,{s1 -> s2} v p) =
   * ([t1\s1] v arrow'(t1\s1,acc,p)) ^
   * ([acc ^ {s2}] v arrow'(t1,acc ^ {s2},p))

   * t1 -> t2 \ s1 -> s2 =
   * [t1] v (([t1\s1] v {[]}) ^ ([t2\s2] v {[]}))

   * Bool -> Bool \ Int -> A =
   * [Bool] v (([Bool\Int] v {[]}) ^ ([Bool\A] v {[]})

   * P(Q v {a}) = {a} v P(Q) v {X v {a} | X \in P(Q) }
   *)
(* [normarrow delta mem t] normalises an arrow component: every
   (pos, neg) pair of [Pair.get t] is handled by [norm_arrow] and the
   results are intersected, starting from [ConstrSet.sat]. *)
and normarrow delta mem t =
  (* one negative arrow t1 -> t2 at a time; the results obtained for the
     different negative arrows are put in union *)
  let rec norm_arrow pos neg =
    match neg with
    |[] -> ConstrSet.unsat
    |(t1,t2) :: n ->
       let t1 = descr t1 and t2 = descr t2 in
       let con1 = norm delta mem t1 in (* [t1] *)
       let con2 = aux t1 (diff any t2) pos in
       let con0 = norm_arrow pos n in
       ConstrSet.union (ConstrSet.union con1 con2) con0
  (* walks the positive arrows s1 -> s2, either removing s1 from the
     domain [t1] or intersecting s2 into [acc] *)
  and aux t1 acc = function
    |[] -> ConstrSet.unsat
    |(s1,s2) :: p ->
       let t1s1 = diff t1 (descr s1) in
       let acc1 = cap acc (descr s2) in
       let con1 = norm delta mem t1s1 in (* [t1 \ s1] *)
       let con2 = norm delta mem acc1 in (* [(Any \ t2) ^ s2] *)
       let con10 = aux t1s1 acc p in
       let con20 = aux t1 acc1 p in
       let con11 = ConstrSet.union con1 con10 in
       let con22 = ConstrSet.union con2 con20 in
       ConstrSet.inter con11 con22
  in
  big_prod norm_arrow ConstrSet.sat (Pair.get t)



(* Memoised entry point of the normalisation: answers from the global
   table [memo_norm] when (t, delta) is already there; otherwise runs the
   recursive [norm] with a fresh local table and stores the result. *)
let norm delta t =
  try NormMemoHash.find memo_norm (t,delta)
  with Not_found -> begin
    let res = norm delta (NormMemoHash.create 17) t in
    NormMemoHash.add memo_norm (t,delta) res; res
  end

  (* merge needs delta because it calls norm recursively *)
(* [merge delta cache m] saturates the line [m].  For every binding
   inf <= v <= sup whose bounds are not already in the subtyping relation,
   the type inf \ sup is normalised (unless the cache already knows it),
   the outcome is intersected with [m], and every resulting line is merged
   recursively.  When no binding produced anything, the result is the set
   containing just [m].
   Raises [UnSatConstr "merge2"] when one of the normalisations is
   unsatisfiable. *)
let rec merge delta cache m =
  let res =
    Line.fold (fun v (inf, sup) acc ->
      (* no need to add new constraints *)
      if subtype inf sup  then acc
      else
        let x = diff inf sup in
        match Cache.lookup x cache with
        | Some _ -> acc
        | None ->
          let cache, _ = Cache.find ignore x cache in
          let n = norm delta x in
          if ConstrSet.is_unsat n then raise (UnSatConstr "merge2");
          let c1 = ConstrSet.inter (ConstrSet.(add m empty)) n
          in
          let c2 =
            ConstrSet.fold
              (fun m1 acc ->
                ConstrSet.union acc (merge delta cache m1))
              c1 ConstrSet.empty
          in
          ConstrSet.union c2 acc
    ) m ConstrSet.empty
  in
  if ConstrSet.is_unsat res then ConstrSet.(add m empty)
  else res

(* Entry point of [merge]: starts from the empty cache [Cache.emp]. *)
let merge delta m = merge delta Cache.emp m

(* [solve delta s] turns every line of the constraint set [s] into a set
   of equations, represented as a map from variables to types.
   A constraint s <= alpha <= t becomes alpha = (s | beta) & t where beta
   is a fresh internal variable named after alpha with a "#" prefix.
   When a bound is itself a variable that is not in [delta], an equation
   built from the bounds (empty, any) is recorded for that variable too. *)
let solve delta s =
  let add_eq alpha s t acc =
    let beta =
      var Var.(mk ~internal:true ("#" ^ (ident alpha)))
    in
    Var.Map.replace alpha (cap (cup s beta) t) acc
  in
  let extra_var t acc =
    if is_var t then
      let v, _ = Variable.extract t in
      if Var.Set.mem delta v then acc
      else add_eq v empty any acc
    else acc
  in
  let to_eq_set m =
    Line.fold (fun alpha (s,t) acc ->
      let acc = extra_var t acc in
      let acc = extra_var s acc in
      add_eq alpha s t acc
    ) m Var.Map.empty
  in
  ConstrSet.fold (fun m acc -> (to_eq_set m) :: acc) s []

(* Same as the previous [solve], with a debug trace of the arguments and
   of the result. *)
let solve delta s =
  let res = solve delta s  in
  DEBUG solve (Format.eprintf "Calling solve (%a, %a), got %a@\n%!"
                 Var.Set.print delta ConstrSet.print s
                 pp_sl res);
  res

(* Module initialisation: widen the margin of the error formatter, which
   is where the debug traces of this file are printed. *)
let () = Format.pp_set_margin Format.err_formatter 200

(* [unify eq_set] solves a set of equations (a map from variables to
   types) and returns a substitution, also as a map.
   The equations are processed in the order given by [Var.Map.remove_min]:
   the chosen equation alpha = t is solved with [Subst.solve_rectype], its
   solution is substituted into the remaining equations, the rest is
   solved recursively, and the resulting substitution is finally applied
   to the solution found for alpha. *)
let unify (eq_set : Descr.t Var.Map.map) =
  let rec loop eq_set accu =
    DEBUG unify
      (Format.eprintf "@[<hov 2>  Entering unify_loop @[(%a, %a)@]@\n"
         Subst.print eq_set Subst.print accu);
    let res =
      if Var.Map.is_empty eq_set then accu else
        begin
          let (alpha, t), eq_set' = Var.Map.remove_min eq_set in
          DEBUG unify (Format.eprintf
                         "@[<hov 2>- Choosing equation @[%a = %a@]@]@\n"
                         Var.print alpha Types.Print.pp_type t);
          let x = Subst.solve_rectype t alpha in
          DEBUG unify (Format.eprintf
                         "@[<hov 2>- Solving equation @[%a = %a@]@]@\n"
                         Var.print alpha Types.Print.pp_type x);
          let eq_set' =
            Var.Map.map (fun s -> Subst.single s (alpha,x)) eq_set'
          in
          let sigma = loop eq_set' (Var.Map.replace alpha x accu) in
          let t_alpha = Subst.full x sigma in
          DEBUG unify (Format.eprintf
                         "@[<hov 2>- Applying substitution @[%a@] to @[%a@], @[<hov 4>got %a@]@]@\n"
                         Subst.print sigma Print.pp_type x Print.pp_type t_alpha);
          Var.Map.replace alpha t_alpha sigma
        end
    in
    DEBUG unify (Format.eprintf "Leaving unify_loop @[(%a, %a)@], @[<hov 4>got %a@]@]@\n"
                   Subst.print eq_set Subst.print accu Subst.print res);
    res
  in
  let res = loop eq_set Var.Map.empty in
  DEBUG unify (Format.(pp_print_flush err_formatter ()));
  res

(* Raised by [tallying] when the normalised constraint set is
   unsatisfiable. *)
exception Step1Fail
(* Raised by [tallying] when no line survives the merge / solve step. *)
exception Step2Fail

(* [tallying delta l] looks for substitutions that make s <= t hold for
   every pair (s, t) of [l], the variables of [delta] being left alone.

   Step 1: every difference s \ t is normalised and all the results are
   intersected.  A pair whose difference is empty adds no constraint; a
   non empty difference without variables makes the whole set unsat.
   Step 2: each line is merged and solved (lines raising [UnSatConstr]
   are dropped) and every resulting equation set is unified.

   Returns the list of substitutions.
   Raises [Step1Fail] or [Step2Fail] as described on the exceptions. *)
let tallying delta l =
  let n =
    List.fold_left (fun acc (s,t) ->
      let d = diff s t in
      (* a trivially satisfied pair must keep the constraints gathered so
         far: answering [ConstrSet.sat] here would forget them *)
      if is_empty d then acc
      else if no_var d then ConstrSet.unsat
      else ConstrSet.inter acc (norm delta d)
    ) ConstrSet.sat l
  in
  if ConstrSet.is_unsat n then raise Step1Fail else
    let m =
      ConstrSet.fold (fun c acc ->
        try
          (solve delta (merge delta c)) @ acc
        with UnSatConstr _ -> acc
      ) n []
    in
    match m with
    | [] -> raise Step2Fail
    | _ :: _ -> List.map unify m


(* Symbolic substitutions: the identity [I], a concrete substitution [S],
   or a pair [A] of symbolic substitutions (built by [++]). *)
type symsubst = I
              | S of Descr.t Var.Map.map
              | A of (symsubst * symsubst)

(* [dom s] is the set of variables in the domain of the symbolic
   substitution [s]: empty for [I], the domain of the map for [S], and the
   union of both domains for [A]. *)
let rec dom = function
  |I -> Var.Set.empty
  |S si -> Var.Map.domain si
  |A (si,sj) -> Var.Set.cup (dom si) (dom sj)

(* composition of two symbolic substitution sets sigmaI, sigmaJ:
   cartesian product.  The result contains A(si, sj) for every si of [sI]
   and every sj of [sJ], the elements of [sJ] varying fastest. *)
let (++) sI sJ =
  let bind m f = List.flatten(List.map f m) in
  bind sI (fun si ->
    bind sJ (fun sj ->
      [A(si,sj)]
    )
  )

(* [t >> s] applies the bindings of the map [s] to the type [t], one
   binding after the other, each through [Subst.single]. *)
let (>>) t s =
  Var.Map.fold (fun v t acc ->
    Subst.single acc (v, t)) s t

(* apply a symbolic substitution si to a type t

   [I] leaves [t] unchanged and [S si] applies the map with [>>].  For a
   pair [A (si, sj)] the general rule applies [sj] first and [si] next;
   two guarded cases short-cut it when the left component is a concrete
   substitution.

   The guards communicate through the references [vst] and [vsi]:
   [filter] stores the variables of [t] and the domain of [si] there, and
   [filterdiff] reads them back.  This relies on the guard of the previous
   case having just been evaluated on the same value. *)
let (@@) t si =
  let vst = ref Var.Set.empty in
  let vsi = ref Var.Set.empty in
  (* true when [t] mentions a variable of the domain of [si] *)
  let filter t si =
    vsi := dom (S si);
    vst := all_vars t;
    not(Var.Set.is_empty (Var.Set.cap !vst !vsi))
  in
  (* true when [t] mentions a variable of the domain of [si] that is not
     in the domain of [sj] (uses the sets stored by [filter]) *)
  let filterdiff t si sj =
    let vsj = dom (S sj) in
    not(Var.Set.is_empty (Var.Set.cap !vst (Var.Set.diff !vsi vsj)))
  in
  let rec aux t = function
    |I -> t
    |S si -> t >> si
    |A (S si,_) when filter t si -> t >> si
    (* NOTE(review): this case is only tried after [filter t si] answered
       false, i.e. when [vst] and [vsi] are disjoint; [filterdiff] then
       seems unable to answer true.  Confirm whether the case is dead. *)
    |A (S si,S sj) when filterdiff t si sj -> (t >> sj) >> si
    |A (si,sj) -> aux (aux t sj) si
  in
  aux t si

(* [domain sl] is the union of the domains of all the substitutions of
   the list [sl]. *)
let domain sl =
  List.fold_left (fun acc si -> Var.Set.cup acc (Var.Map.domain si)) Var.Set.empty sl

(* [codomain ll] collects, with [all_vars], the variables of every type
   bound by the substitutions of the list [ll]. *)
let codomain ll =
  List.fold_left (fun acc e ->
    Var.Map.fold (fun _ v acc ->
      Var.Set.cup (all_vars v) acc
    ) e acc
  ) Var.Set.empty ll

(* A list of substitutions is the identity when all its maps are empty
   (this includes the empty list). *)
let is_identity = List.for_all (Var.Map.is_empty)
(* The list made of the single empty substitution. *)
let identity = [Var.Map.empty]

(* [filter f sl] restricts every substitution of [sl] to the variables
   satisfying [f] and drops the substitutions that become empty.  An
   identity list is returned unchanged.  The surviving substitutions come
   out in reverse order. *)
let filter f sl =
  if is_identity sl then sl else
    List.fold_left (fun acc si ->
      let e = Var.Map.filter (fun v _ -> f v) si in
      if Var.Map.is_empty e then acc else e::acc
    ) [] sl



(* [set a i v] stores [v] in cell [i] of the growable array held by the
   reference [a].  When [i] is past the end, a larger array filled with
   [empty] replaces the old one and receives a copy of its contents.
   The new size is at least [2 * len + 1] and always large enough to hold
   index [i], so writing far beyond the end no longer fails.
   A negative [i] still raises [Invalid_argument]. *)
let set a i v =
  let len = Array.length !a in
  if i < len then (!a).(i) <- v
  else begin
    let b = Array.make (max (2*len+1) (i+1)) empty in
    Array.blit !a 0 b 0 len;
    b.(i) <- v;
    a := b
  end
(* [get a i] reads cell [i] of the array held by the reference [a]; a
   negative index answers [any]. *)
let get a i =
  if i >= 0 then (!a).(i) else any

(* Used inside [squaresubtype] to leave the search loop with the
   substitutions that were found. *)
exception FoundSquareSub of Descr.t Var.Map.map list

(* [squaresubtype delta s t] looks for substitutions under which an
   intersection of instances of [s] is a subtype of [t].
   Round 0 tries [s] itself; round i tries a fresh copy of [s]
   ([Subst.freshen delta s]) intersected with the type of round i-1.
   The loop goes on while [tallying] fails at step 2 and stops with the
   first list of substitutions found; nothing in the loop bounds the
   number of rounds.
   The global table [memo_norm] is cleared first.
   Raises [UnSatConstr] when the very first round fails at step 1. *)
let squaresubtype delta s t =
  NormMemoHash.clear memo_norm;
  (* cell i of ai holds the type tried at round i *)
  let ai = ref [| |] in
  let tallying i =
    try
      let s = get ai i in
      let sl = tallying delta [ (s,t) ] in
      raise (FoundSquareSub sl)
    with
      Step1Fail -> (assert (i = 0); raise (UnSatConstr "squaresubtype step1"))
    | Step2Fail -> () (* continue *)
  in
  let rec loop i =
    try
      let ss =
        if i = 0 then s
        else (cap (Subst.freshen delta s) (get ai (i-1)))
      in
      set ai i ss;
      tallying i;
      loop (i+1)
    with FoundSquareSub sl -> sl
  in
  loop 0

(* [is_squaresubtype delta s t] is true when [squaresubtype delta s t]
   returns, and false when it raises [UnSatConstr]. *)
let is_squaresubtype delta s t =
  try
    ignore (squaresubtype delta s t);
    true
  with UnSatConstr _ -> false

(* Used inside [apply_raw] to leave the search loops; carries the result
   type, the two indices of the successful attempt and the substitutions. *)
exception FoundApply of t * int * int * Descr.t Var.Map.map list

(** find two sets of type substitutions I,J such that
    s @@ sigma_i < t @@ sigma_j for all i \in I, j \in J

    Each attempt (i, j) runs [tallying] on the pair made of cell i of
    [ai] and the arrow from cell j of [aj] to a variable Gamma.  Round 0
    uses [s] and [t]; round i intersects fresh copies of them with the
    cells of round i-1 and tries every new combination of indices.

    Returns [(sl, si, tj, res)]: the substitutions, the two cells of the
    successful attempt, and the result type, obtained by intersecting the
    images of Gamma under every substitution and cleaning the outcome with
    [Subst.clean_type].
    The global table [memo_norm] is cleared first.
    Raises [UnSatConstr] when the very first attempt fails at step 1. *)
let apply_raw delta s t =
  DEBUG apply_raw (Format.eprintf " @[Entering apply_raw (delta:@[%a@], @[%a@], @[%a@])@\n%!"
                     Var.Set.pp delta
                     Print.pp_type s
                     Print.pp_type t
  );

  NormMemoHash.clear memo_norm;
  let vgamma = Var.mk "Gamma" in
  let gamma = var vgamma in
  let cgamma = cons gamma in
  (* cell i of ai contains /\k<=i s_k, cell j of aj contains /\k<=j t_k *)
  let ai = ref [| |]
  and aj = ref [| |] in
  let tallying i j  =
    try
      let s = get ai i in
      let t = arrow (cons (get aj j)) cgamma in
      let sl = tallying delta [ (s,t) ] in
      let new_res =
        Subst.clean_type delta (
          List.fold_left (fun tacc si ->
            cap tacc (Subst.full gamma  si)
          ) any sl
        )
      in
      raise (FoundApply(new_res,i,j,sl))
    with
      Step1Fail -> (assert (i == 0 &&  j == 0); raise (UnSatConstr "apply_raw step1"))
    | Step2Fail -> () (* continue *)
  in
  let rec loop i =
    try
      (* Format.printf "Starting expansion %i @\n@." i; *)
      let (ss,tt) =
        if i = 0 then (s,t) else
          ((cap (Subst.freshen delta s) (get ai (i-1))),
           (cap (Subst.freshen delta t) (get aj (i-1))))
      in
      set ai i ss;
      set aj i tt;
      for j = 0 to i-1 do
        tallying j i;
        tallying i j;
      done;
      tallying i i;
      loop (i+1)
    with FoundApply (res, i, j, sl) ->
      (* [sl] is a list of substitutions, hence [pp_sl]; [ConstrSet] has
         no printer for it *)
      DEBUG apply_raw (Format.eprintf " Leaving apply_raw (delta:@[%a@], @[%a@], @[%a@]) = @[%a@], @[%a@] @]@\n%!"
                         Var.Set.pp delta
                         Print.pp_type s
                         Print.pp_type t
                         Print.pp_type res
                         pp_sl sl
      );
      (sl, get ai i, get aj j, res)
  in
  loop 0

(* [apply_full delta s t] runs [apply_raw] and keeps only the result
   type, i.e. the last component of the tuple. *)
let apply_full delta s t =
  match apply_raw delta s t with
  | (_, _, _, res) -> res

(* [squareapply delta s t] runs [apply_raw] and returns the substitutions
   together with the result type. *)
let squareapply delta s t =
  let (sl, _, _, res) = apply_raw delta s t in
  (sl, res)