(* mlstub.ml *)
#load "q_MLast.cmo";;

(* TODO:
   - optimizations: generate labels and atoms only once.
   - translate record to open record on positive occurrence
*)


open Mltypes
open Ident

(* Finite maps keyed by integers, ordered with the int-specialised compare. *)
module IntMap = Map.Make (struct
  type t = int
  let compare (a : int) (b : int) = compare a b
end)

(* Imperative hash tables keyed by integers; an integer is its own hash. *)
module IntHash = Hashtbl.Make (struct
  type t = int
  let equal (a : int) (b : int) = a = b
  let hash (n : int) = n
end)

(* Compute CDuce type *)

(* Bindings for type variables during translation: [Var i] in [typ_descr]
   stands for the node [(!vars).(i)].  The external-symbol registration
   code fills this array before calling [typ] and empties it afterwards. *)
let vars = ref [| |]

(* Memo table for [typ], keyed by the uid of the OCaml type. *)
let memo_typ = IntHash.create 13

(* CDuce singleton type of the atom whose ASCII name is [name]. *)
let atom name =
  let a = Atoms.V.mk_ascii name in
  Types.atom (Atoms.atom a)

(* CDuce record label in the empty namespace with local name [name]. *)
let label name =
  let local = U.mk name in
  LabelPool.mk (Ns.empty, local)

(* Union of [f x] over the elements [x] of [xs], accumulated from left to
   right starting with the empty type. *)
let bigcup f xs =
  let rec go acc = function
    | [] -> acc
    | x :: rest -> go (Types.cup acc (f x)) rest
  in
  go Types.empty xs

(* [typ t] is the CDuce type node for the OCaml type [t].  The node is
   allocated and memoized under [t.uid] before its contents are computed,
   so a type that refers back to itself yields a recursive CDuce node. *)
let rec typ t =
  try IntHash.find memo_typ t.uid
  with Not_found ->
    let fresh = Types.make () in
    IntHash.add memo_typ t.uid fresh;
    let contents = typ_descr t.def in
    Types.define fresh contents;
    fresh

(* CDuce type for an OCaml type description.  [Var i] is looked up in
   [!vars]; descriptions not handled below fail with [assert false]. *)
and typ_descr = function
  | Link t -> typ_descr t.def
  | Arrow (_, t, s) -> Types.arrow (typ t) (typ s)
  | Tuple tl -> Types.tuple (List.map typ tl)
  | PVariant cases -> bigcup pvariant cases
  | Variant (_, cases, _) -> bigcup variant cases
  | Record (_, fields, _) ->
      let fields = List.map (fun (lab, t) -> label lab, typ t) fields in
      Types.record' (false, LabelMap.from_list_disj fields)
  | Abstract name ->
      (match name with
       | "int" -> Builtin_defs.caml_int
       | "char" -> Builtin_defs.char_latin1
       | "string" -> Builtin_defs.string_latin1
       | _ -> Types.abstract (Types.Abstract.atom name))
  | Var i -> Types.descr (!vars).(i)
  | Builtin ("list", [t]) -> Types.descr (Sequence.star_node (typ t))
  | Builtin ("Pervasives.ref", [t]) -> Builtin_defs.ref_type (typ t)
  | Builtin ("CDuce_all.Value.t", []) -> Types.any
  | Builtin ("unit", []) -> Sequence.nil_type
  | _ -> assert false

(* A polymorphic-variant case: the atom for [lab] alone, or the pair of
   that atom and the argument type. *)
and pvariant (lab, arg) =
  match arg with
  | None -> atom lab
  | Some t -> Types.times (Types.cons (atom lab)) (typ t)

(* A sum-type constructor: the atom for [lab] alone, or the tuple of that
   atom followed by the argument types. *)
and variant (lab, args) =
  match args with
  | [] -> atom lab
  | _ -> Types.tuple (Types.cons (atom lab) :: List.map typ args)


(* Syntactic tools *)

(* Counter behind the fresh identifiers x1, x2, ... *)
let var_counter = ref 0

(* [mk_var _] returns a fresh identifier; its argument is ignored. *)
let mk_var _ =
  var_counter := !var_counter + 1;
  "x" ^ string_of_int !var_counter

(* One fresh identifier per list element, allocated left to right. *)
let mk_vars l = List.map mk_var l

(* Dummy source location; Camlp4 quotations expand to code that refers to
   a variable named [loc]. *)
let loc = (-1, -1)

(* [let_in pat rhs body] builds the expression [let pat = rhs in body]. *)
let let_in pat rhs body =
  <:expr< let $list:[ (pat, rhs) ]$ in $body$ >>

(* Call to Value.atom_ascii on the escaped string literal [name]. *)
let atom_ascii name =
  let lit = String.escaped name in
  <:expr< Value.atom_ascii $str:lit$ >>

(* Call to Value.label_ascii on the escaped string literal [name]. *)
let label_ascii name =
  let lit = String.escaped name in
  <:expr< Value.label_ascii $str:lit$ >>

(* The CDuce pair Value.Pair (left, right). *)
let pair left right = <:expr< Value.Pair ($left$,$right$) >>

(* [pmatch e cases] builds [match e with p1 -> e1 | ...] from a list of
   (pattern, expression) pairs; no arm carries a [when] guard. *)
let pmatch e cases =
  let arms = List.map (fun (pat, rhs) -> (pat, None, rhs)) cases in
  <:expr< match $e$ with [ $list:arms$ ] >>

(* [matches ine oute names] destructures the right-nested CDuce pair [ine]
   into the identifiers [names] (at least two) around [oute]:
     let (v1, r1) = Value.get_pair ine in
     let (v2, r2) = Value.get_pair r1 in
     ...
     let (vn-1, vn) = Value.get_pair rk in oute
   Each intermediate remainder gets a fresh variable from [mk_var]. *)
let rec matches ine oute names =
  let destruct head rest_var src body =
    let_in <:patt<($lid:head$,$lid:rest_var$)>> <:expr< Value.get_pair $src$ >> body
  in
  match names with
  | [] -> assert false
  | [v1; v2] -> destruct v1 v2 ine oute
  | v :: vl ->
      let r = mk_var () in
      destruct v r ine (matches <:expr< $lid:r$ >> oute vl)

(* [list_lit [e1; ...; en]] builds the OCaml list expression of e1 ... en. *)
let rec list_lit = function
  | [] -> <:expr< [] >>
  | hd :: tl -> <:expr< [$hd$ :: $list_lit tl$] >>

(* [protect e f] lets [f] mention [e] several times without evaluating it
   several times.  A bare identifier is handed to [f] as is; any other
   expression is bound once to a fresh variable, and [f] receives that
   variable instead. *)
let protect e f =
  match e with
  | <:expr< $lid:x$ >> -> f e
  | _ ->
      let tmp = mk_var () in
      let body = f <:expr< $lid:tmp$ >> in
      <:expr< let $lid:tmp$ = $e$ in $body$ >>
(* Registered types *)

(* CDuce types referenced from the generated stub.  Types are compared with
   Types.equal / Types.hash; each distinct type receives the next index
   0, 1, 2, ... the first time it is registered. *)
module HashTypes = Hashtbl.Make(Types)
let registered_types = HashTypes.create 13
let nb_registered_types = ref 0

(* [register_type t] returns the expression [types.(i)], where [i] is the
   registry index of [t]; [t] is registered first if it is new. *)
let register_type t =
  let lookup () =
    try HashTypes.find registered_types t
    with Not_found ->
      let fresh = !nb_registered_types in
      HashTypes.add registered_types t fresh;
      incr nb_registered_types;
      fresh
  in
  let idx = string_of_int (lookup ()) in
  <:expr< types.($int:idx$) >>

(* All registered types, stored at their registry index. *)
let get_registered_types () =
  let table = Array.make !nb_registered_types Types.empty in
  HashTypes.iter (fun ty idx -> table.(idx) <- ty) registered_types;
  table
(* OCaml -> CDuce conversions *)

(* Types whose OCaml -> CDuce conversion function still has to be emitted;
   drained by [global_transl]. *)
let to_cd_gen = ref []

(* Name of the generated OCaml -> CDuce conversion function for [t]. *)
let to_cd_fun_name t = "to_cd_" ^ string_of_int t.uid

(* [to_cd_fun t] queues [t] for function generation and returns the name. *)
let to_cd_fun t =
  let name = to_cd_fun_name t in
  to_cd_gen := t :: !to_cd_gen;
  name

(* Types whose CDuce -> OCaml conversion function still has to be emitted;
   drained by [global_transl]. *)
let to_ml_gen = ref []

(* Name of the generated CDuce -> OCaml conversion function for [t]. *)
let to_ml_fun_name t = "to_ml_" ^ string_of_int t.uid

(* [to_ml_fun t] queues [t] for function generation and returns the name. *)
let to_ml_fun t =
  let name = to_ml_fun_name t in
  to_ml_gen := t :: !to_ml_gen;
  name

(* [tuple [e1; ...; en]] builds the right-nested CDuce pair
   Value.Pair (e1, Value.Pair (..., en)); [en] alone when n = 1.
   The list must not be empty. *)
let tuple es =
  match List.rev es with
  | [] -> assert false
  | last :: rest ->
      List.fold_left (fun acc e -> <:expr< Value.Pair ($e$, $acc$) >>) last rest

(* [pat_tuple ids] builds the tuple pattern (id1, ..., idk). *)
let pat_tuple ids =
  let pats = List.map (fun id -> <:patt< $lid:id$ >>) ids in
  <:patt< ($list:pats$) >>


(* [call_lab f l x] applies [f] to [x] with OCaml label [l]: no label when
   [l] is empty, an optional argument when [l] starts with a question mark
   (which is stripped), a labelled argument otherwise. *)
let call_lab f l x =
  if l = "" then <:expr< $f$ $x$ >>
  else match l.[0] with
    | '?' ->
        let name = String.sub l 1 (String.length l - 1) in
        <:expr< $f$ (? $name$ : $x$) >>
    | _ -> <:expr< $f$ (~ $l$ : $x$) >>

(* [abstr_lab l x res] builds [fun x -> res], where the parameter [x]
   carries label [l] under the same conventions as [call_lab]. *)
let abstr_lab l x res =
  if l = "" then <:expr< fun $lid:x$ -> $res$ >>
  else match l.[0] with
    | '?' ->
        let name = String.sub l 1 (String.length l - 1) in
        <:expr< fun ? $name$ : ( $lid:x$ ) -> $res$ >>
    | _ -> <:expr< fun ~ $l$ : $lid:x$ -> $res$ >>



(* [to_cd e t] builds the Camlp4 expression converting the OCaml value
   denoted by [e] (of OCaml type [t]) into a CDuce value.  When
   [t.recurs > 0] the conversion calls the function named by [to_cd_fun t]
   (which [global_transl] emits later); otherwise it is expanded inline by
   [to_cd_descr]. *)
let rec to_cd e t =
(*  Format.fprintf Format.std_formatter "to_cd %a [uid=%i; recurs=%i]@."
    Mltypes.print t t.uid t.recurs; *)
  if t.recurs > 0 then <:expr< $lid:to_cd_fun t$ $e$ >>
  else to_cd_descr e t.def

(* Inline OCaml -> CDuce conversion of [e] for a type description.
   Unsupported descriptions fail with [assert false]. *)
and to_cd_descr e = function
  | Link t -> to_cd e t
  | Arrow (l,t,s) ->
      (* let y = <...> in
         Value.Abstraction ([t,s], fun x -> s(y ~l:(t(x)))) *)
      protect e
      (fun y ->
         let x = mk_var () in
         let arg = to_ml <:expr< $lid:x$ >> t in
         let res = to_cd (call_lab y l arg) s in
         let abs = <:expr< fun $lid:x$ -> $res$ >> in
         (* Registration order determines the indices into [types]. *)
         let tt = register_type (Types.descr (typ t)) in
         let ss = register_type (Types.descr (typ s)) in
         <:expr< Value.Abstraction ([($tt$,$ss$)],$abs$) >>
      )
  | Tuple tl ->
      (* let (x1,...,xn) = ... in Value.Pair (t1(x1), Value.Pair(...,tn(xn))) *)
      let vars = mk_vars tl in
      let_in (pat_tuple vars) e (tuple (tuple_to_cd tl vars))
  | PVariant l ->
      (* match <...> with
         | `A -> Value.atom_ascii "A"
         | `B x -> Value.Pair (Value.atom_ascii "B",t(x))
      *)
      let cases =
        List.map
          (function
             | (lab,None) -> <:patt< `$lid:lab$ >>, atom_ascii lab
             | (lab,Some t) ->
                 <:patt< `$lid:lab$ x >>,
                 pair (atom_ascii lab) (to_cd <:expr< x >> t)
          ) l in
      pmatch e cases
  | Variant (p,l,_) ->
      (* match <...> with
         | P.A -> Value.atom_ascii "A"
         | P.B (x1,x2,..) -> Value.Pair (Value.atom_ascii "B",...,Value.Pair(tn(x)))
      *)
      let cases =
        List.map
          (function
             | (lab,[]) ->
                 (* Constant constructors are qualified with [p] like the
                    non-constant ones below and like [to_ml_descr] does;
                    booleans need the revised-syntax constructors. *)
                 (match lab with (* Stupid Camlp4 *)
                    | "true" -> <:patt< True >>
                    | "false" -> <:patt< False >>
                    | lab -> <:patt< $lid:p^lab$ >>),
                 atom_ascii lab
             | (lab,tl) ->
                 let vars = mk_vars tl in
                 <:patt< $lid:p^lab$ $pat_tuple vars$ >>,
                 tuple (atom_ascii lab :: tuple_to_cd tl vars)
          ) l in
      pmatch e cases
  | Record (p,l,_) ->
      (* let x = <...> in Value.record [ l1,t1(x.P.l1); ...; ln,tn(x.P.ln) ] *)
      protect e
      (fun x ->
         let l =
           List.map
             (fun (lab,t) ->
                let e = to_cd <:expr<$x$.$lid:p^lab$>> t in
                <:expr< ($label_ascii lab$, $e$) >>)
             l
         in
         <:expr< Value.record $list_lit l$ >>)
  | Abstract "int" -> <:expr< Value.ocaml2cduce_int $e$ >>
  | Abstract "char" -> <:expr< Value.ocaml2cduce_char $e$ >>
  | Abstract "string" -> <:expr< Value.ocaml2cduce_string $e$ >>
  | Abstract s -> <:expr< Value.abstract $str:String.escaped s$ $e$ >>
  | Builtin ("list",[t]) ->
      (* Value.sequence_rev (List.rev_map fun_t <...>) *)
      <:expr< Value.sequence_rev (List.rev_map $lid:to_cd_fun t$ $e$) >>
  | Builtin ("Pervasives.ref",[t]) ->
      (* let x = <...> in
         Value.mk_ext_ref t (fun () -> t(!x)) (fun y -> x := t'(y)) *)
      protect e
      (fun e ->
         let y = mk_var () in
         let tt = register_type (Types.descr (typ t)) in
         let get_x = <:expr< $e$.val >> in
         let get = <:expr< fun () -> $to_cd get_x t$ >> in
         let tr_y = to_ml <:expr< $lid:y$ >> t in
         let set = <:expr< fun $lid:y$ -> $e$.val := $tr_y$ >> in
         <:expr< Value.mk_ext_ref $tt$ $get$ $set$ >>
      )
  | Builtin ("CDuce_all.Value.t", []) -> e
  | Builtin ("unit", []) -> <:expr< do { $e$; Value.nil } >>
  | Var _ -> e
  | _ -> assert false

(* Converts each identifier of [vars] with the corresponding type of [tl]. *)
and tuple_to_cd tl vars = List.map2 (fun t id -> to_cd <:expr< $lid:id$ >> t) tl vars

(* CDuce -> OCaml conversions *)

(* [to_ml e t] builds the Camlp4 expression converting the CDuce value
   denoted by [e] into an OCaml value of type [t]; recursive types go
   through the function named by [to_ml_fun t]. *)
and to_ml e t =
(*  Format.fprintf Format.std_formatter "to_ml %a@."
    Mltypes.print t; *)
  if t.recurs > 0 then <:expr< $lid:to_ml_fun t$ $e$ >>
  else to_ml_descr e t.def

(* Inline CDuce -> OCaml conversion of [e] for a type description.
   Private sum and record types cannot be built and raise [Failure]. *)
and to_ml_descr e = function
  | Link t -> to_ml e t
  | Arrow (l,t,s) ->
      (* let y = <...> in fun ~l:x -> s(Eval.eval_apply y (t(x))) *)
      protect e
      (fun y ->
         let x = mk_var () in
         let arg = to_cd <:expr< $lid:x$ >> t in
         let res = to_ml <:expr< Eval.eval_apply $y$ $arg$ >> s in
         abstr_lab l x res
      )
  | Tuple tl ->
      (* let (x1,r) = Value.get_pair <...> in
         let (x2,r) = Value.get_pair r in
         ...
         let (xn-1,xn) = Value.get_pair r in
         (t1(x1),...,tn(xn)) *)
      let vars = mk_vars tl in
      let el = tuple_to_ml tl vars in
      matches e <:expr< ($list:el$) >> vars
  | PVariant l ->
      (* match Value.get_variant <...> with
         | "A",None -> `A
         | "B",Some x -> `B (t(x))
         | _ -> assert false
      *)
      let cases =
        List.map
          (function
             | (lab,None) ->
                 <:patt< ($str: String.escaped lab$, None) >>,
                 <:expr< `$lid:lab$ >>
             | (lab,Some t) ->
                 let x = mk_var () in
                 let ex = <:expr< $lid:x$ >> in
                 <:patt< ($str: String.escaped lab$, Some $lid:x$) >>,
                 <:expr< `$lid:lab$ $to_ml ex t$ >>
          ) l in
      let cases = cases @ [ <:patt< _ >>, <:expr< assert False >> ] in
      pmatch <:expr< Value.get_variant $e$ >> cases
  | Variant (_,_,false) ->
      failwith "Private Sum type"
  | Variant (p,l,true) ->
      (* match Value.get_variant <...> with
         | "A",None -> P.A
         | "B",Some x -> let (x1,r) = x in ...
      *)
      let cases =
        List.map
          (function
             | (lab,[]) ->
                 <:patt< ($str: String.escaped lab$, None) >>,
                 (match lab with (* Stupid Camlp4 *)
                    | "true" -> <:expr< True >>
                    | "false" -> <:expr< False >>
                    | lab -> <:expr< $lid:p^lab$ >>)
             | (lab,[t]) ->
                 let x = mk_var () in
                 let ex = <:expr< $lid:x$ >> in
                 <:patt< ($str: String.escaped lab$, Some $lid:x$) >>,
                 <:expr< $lid:p^lab$ $to_ml ex t$ >>
             | (lab,tl) ->
                 let vars = mk_vars tl in
                 let el = tuple_to_ml tl vars in
                 let x = mk_var () in
                 <:patt< ($str: String.escaped lab$, Some $lid:x$) >>,
                 matches <:expr< $lid:x$ >>
                         <:expr< $lid:p^lab$ ($list:el$) >> vars
          ) l in
      let cases = cases @ [ <:patt< _ >>, <:expr< assert False >> ] in
      pmatch <:expr< Value.get_variant $e$ >> cases
  | Record (_,_,false) ->
      failwith "Private Record type"
  | Record (p,l,true) ->
      (* let x = <...> in
         { P.l1 = t1(Value.get_field x "l1"); ... } *)
      protect e
      (fun x ->
         let l =
           List.map
             (fun (lab,t) ->
                (<:patt< $lid:p^lab$>>,
                 to_ml
                 <:expr< Value.get_field $x$ $label_ascii lab$ >> t)) l in
         <:expr< {$list:l$} >>)
  | Abstract "int" -> <:expr< Value.cduce2ocaml_int $e$ >>
  | Abstract "char" -> <:expr< Value.cduce2ocaml_char $e$ >>
  | Abstract "string" -> <:expr< Value.cduce2ocaml_string $e$ >>
  | Abstract _ -> <:expr< Value.get_abstract $e$ >>
  | Builtin ("list",[t]) ->
      (* List.rev_map fun_t (Value.get_sequence_rev <...>) *)
      <:expr< List.rev_map $lid:to_ml_fun t$ (Value.get_sequence_rev $e$) >>
  | Builtin ("Pervasives.ref",[t]) ->
      (* ref t(Eval.eval_apply (Value.get_field <...> "get") Value.nil)  *)
      let e = <:expr< Value.get_field $e$ $label_ascii "get"$ >> in
      let e = <:expr< Eval.eval_apply $e$ Value.nil >> in
      <:expr< Pervasives.ref $to_ml e t$ >>
  | Builtin ("CDuce_all.Value.t", []) -> e
  | Builtin ("unit", []) -> <:expr< ignore $e$ >>
  | Var _ -> e
  | _ -> assert false

(* Converts each identifier of [vars] with the corresponding type of [tl]. *)
and tuple_to_ml tl vars = List.map2 (fun t id -> to_ml <:expr< $lid:id$ >> t) tl vars


(* Uids of the types whose conversion functions were already emitted. *)
let to_ml_done = IntHash.create 13
let to_cd_done = IntHash.create 13

(* [global_transl ()] returns the (pattern, expression) definitions of every
   conversion function queued in [to_cd_gen] / [to_ml_gen].  This includes
   functions queued while other functions are being generated.  Each one is
   emitted at most once, and pending OCaml -> CDuce requests are always
   served before CDuce -> OCaml ones. *)
let global_transl () =
  let defs = ref [] in
  let emit t don fun_name to_descr =
    if not (IntHash.mem don t.uid) then begin
      IntHash.add don t.uid ();
      let p = <:patt< $lid:fun_name t$ >> in
      let e = <:expr< fun x -> $to_descr <:expr< x >> t.def$ >> in
      defs := (p, e) :: !defs
    end
  in
  let rec drain () =
    match !to_cd_gen with
    | t :: rest ->
        to_cd_gen := rest;
        emit t to_cd_done to_cd_fun_name to_cd_descr;
        drain ()
    | [] ->
        (match !to_ml_gen with
         | t :: rest ->
             to_ml_gen := rest;
             emit t to_ml_done to_ml_fun_name to_ml_descr;
             drain ()
         | [] -> ())
  in
  drain ();
  !defs

(* Check type constraints and generate stub code *)

(* Formatter for error reports emitted while checking exported values. *)
let err_ppf = Format.err_formatter

(* External symbols registered so far as (name, OCaml type) pairs, most
   recent first. *)
let exts = ref []

(* [check_value ty_env c_env (s, caml_t, t)] checks that the CDuce value [s]
   has a type compatible with its OCaml interface type [t] and returns
   (pattern for s, expression C.x, definition of x), where x holds the
   converted value.  Any mismatch is reported on [err_ppf] and exits. *)
let check_value ty_env c_env (s,caml_t,t) =
  let id = Id.mk (U.mk s) in
  (* Type inferred for [s] in the CDuce module; absence is fatal. *)
  let vt =
    try Typer.find_value id ty_env
    with Not_found ->
      Format.fprintf err_ppf
        "The interface exports a value %s which is not available in the module@." s;
      exit 1
  in
  (* CDuce type expected from the OCaml interface. *)
  let et = Types.descr (typ t) in
  if not (Types.subtype vt et) then begin
    Format.fprintf err_ppf
      "The type for the value %s is invalid@\n\
       Expected Caml type:@[%a@]@\n\
       Expected CDuce type:@[%a@]@\n\
       Inferred type:@[%a@]@."
      s
      print_ocaml caml_t
      Types.Print.print et
      Types.Print.print vt;
    exit 1
  end;
  (* Stub: let x = t(Eval.get_slot cu slot), exported as C.x. *)
  let x = mk_var () in
  let slot = string_of_int (Compile.find_slot id c_env) in
  let conv = to_ml <:expr< Eval.get_slot cu $int:slot$ >> t in
  (<:patt< $uid:s$ >>, <:expr< C.$uid:x$ >>, (<:patt< $uid:x$ >>, conv))

(* [stub name ty_env c_env values] generates the OCaml glue for a compiled
   CDuce unit.  It returns (pattern, structure items, expression), intended
   to be assembled as
     let (v1,...,vn) =
       let module C = struct <items> end in (C.x1,...,C.xn)
   The items open CDuce_all, bind [types], define the recursive conversion
   functions (if any), fill the external slots, run the unit [cu] and bind
   one converted value per exported OCaml value.  [name] is unused. *)
let stub name ty_env c_env values =
  (* Order matters: checking values and translating externals queue the
     conversion functions that [global_transl] then emits. *)
  let items = List.map (check_value ty_env c_env) values in
  let exts = List.rev_map (fun (s,t) -> to_cd <:expr< $lid:s$ >> t) !exts in
  let g = global_transl () in
  let items_def = List.map (fun (_,_,d) -> d) items in
  let items_expr = List.map (fun (_,e,_) -> e) items in
  let items_pat = List.map (fun (p,_,_) -> p) items in
  let m =
    [ <:str_item< open CDuce_all >>;
      <:str_item< value types = Librarian.registered_types cu >> ] @
    (if g = [] then [] else [ <:str_item< value rec $list:g$ >> ]) @
    [ <:str_item< Librarian.set_externals cu [|$list:exts$|] >>;
      <:str_item< Librarian.run cu >> ] @
    (if items = [] then [] else [ <:str_item< value $list:items_def$ >> ]) in
  (* With no exported value, both the pattern and the expression are ()
     (previously only the expression was special-cased). *)
  match items_pat, items_expr with
  | [], [] -> <:patt< () >>, m, <:expr< () >>
  | pl, el -> <:patt< ($list:pl$) >>, m, <:expr< ($list:el$) >>


(* [register ()] installs the OCaml stub generator in [Librarian.stub_ml]
   and the external symbol resolver in [Externals.register]. *)
let register () =
  (* Stub generator for compilation unit [cu].  Mltypes.Error is re-raised
     as Location.Generic. *)
  let gen_stub cu ty_env c_env =
    try
      let modname = String.capitalize cu in
      let (prolog, values) =
        try Mltypes.read_cmi modname
        with Not_found ->
          Printf.eprintf "Warning: no caml interface\n";
          ("", [])
      in
      let code = stub cu ty_env c_env values in
      let types = get_registered_types () in
      (* NOTE(review): Obj.magic erases the type of (prolog, code) to
         whatever Librarian.stub_ml expects; confirm against Librarian. *)
      (Some (Obj.magic (prolog, code)), types)
    with Mltypes.Error msg -> raise (Location.Generic msg)
  in
  (* Resolves the external symbol [s] applied to the type arguments [args].
     It records the symbol for slot filling and returns its CDuce type.
     Unknown symbols and arity mismatches are fatal. *)
  let register_external _ s args =
    let (t, arity) =
      try Mltypes.find_value s
      with Not_found ->
        Printf.eprintf "Cannot resolve the external symbol %s\n" s;
        exit 1
    in
    let given = List.length args in
    if arity <> given then begin
      Printf.eprintf "Wrong arity for external symbol %s (real arity = %i; given = %i)\n" s arity given;
      exit 1
    end;
    exts := (s, t) :: !exts;
    (* [Var i] inside [t] stands for the i-th element of [args]. *)
    vars := Array.of_list args;
    let cdt = Types.descr (typ t) in
    vars := [| |];
    cdt
  in
  Librarian.stub_ml := gen_stub;
  Externals.register := register_external

(* Declare this interface to Config under the key "ocaml".  [register]
   installs the stub generator and the external resolver; presumably
   Config invokes it when this interface is enabled (not visible here). *)
let () = Config.register "ocaml" "OCaml interface" register