From 03ce21b7df20dd1616f11f4d3e3f547b8ce97533 Mon Sep 17 00:00:00 2001 From: "Jonathan D.A. Jewell" <6759885+hyperpolymath@users.noreply.github.com> Date: Sat, 2 May 2026 23:40:23 +0100 Subject: [PATCH] fix(build): track 23 codegen modules referenced by lib/dune The lib/dune `(modules ...)` list on main references 23 codegen modules (bash_codegen, c_codegen, cuda_codegen, faust_codegen, gleam_codegen, js_codegen, lean_codegen, llvm_codegen, lua_codegen, metal_codegen, mlir_codegen, nickel_codegen, ocaml_codegen, onnx_codegen, onnx_proto, opencl_codegen, protobuf, rescript_codegen, rust_codegen, spirv_codegen, verilog_codegen, wgsl_codegen, why3_codegen) whose .ml files exist in the working tree but were never `git add`ed. CI's `dune build` fails on the first missing module: Error: Module Wgsl_codegen doesn't exist. Adding the 23 .ml files. No .mli companions exist; no opam dependency changes needed (each codegen uses only stdlib + already-declared deps: str unix sedlex fmt menhirLib yojson). Each file carries its own SPDX-License-Identifier and SPDX-FileCopyrightText header authored by the user; this commit only puts them under git tracking, no source modifications. Signed-off-by: Jonathan D.A. Jewell <6759885+hyperpolymath@users.noreply.github.com> Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Jonathan D.A. Jewell <6759885+hyperpolymath@users.noreply.github.com> --- lib/bash_codegen.ml | 106 +++++++++ lib/c_codegen.ml | 492 ++++++++++++++++++++++++++++++++++++++++ lib/cuda_codegen.ml | 159 +++++++++++++ lib/faust_codegen.ml | 252 ++++++++++++++++++++ lib/gleam_codegen.ml | 173 ++++++++++++++ lib/js_codegen.ml | 487 +++++++++++++++++++++++++++++++++++++++ lib/lean_codegen.ml | 114 ++++++++++ lib/llvm_codegen.ml | 258 +++++++++++++++++++++ lib/lua_codegen.ml | 233 +++++++++++++++++++ lib/metal_codegen.ml | 147 ++++++++++++ lib/mlir_codegen.ml | 191 ++++++++++++++++ lib/nickel_codegen.ml | 125 ++++++++++ lib/ocaml_codegen.ml | 250 ++++++++++++++++++++ lib/onnx_codegen.ml | 257 +++++++++++++++++++++ lib/onnx_proto.ml | 165 ++++++++++++++ lib/opencl_codegen.ml | 133 +++++++++++ lib/protobuf.ml | 142 ++++++++++++ lib/rescript_codegen.ml | 217 ++++++++++++++++++ lib/rust_codegen.ml | 273 ++++++++++++++++++++++ lib/spirv_codegen.ml | 154 +++++++++++++ lib/verilog_codegen.ml | 109 +++++++++ lib/wgsl_codegen.ml | 326 ++++++++++++++++++++++++++ lib/why3_codegen.ml | 126 ++++++++++ 23 files changed, 4889 insertions(+) create mode 100644 lib/bash_codegen.ml create mode 100644 lib/c_codegen.ml create mode 100644 lib/cuda_codegen.ml create mode 100644 lib/faust_codegen.ml create mode 100644 lib/gleam_codegen.ml create mode 100644 lib/js_codegen.ml create mode 100644 lib/lean_codegen.ml create mode 100644 lib/llvm_codegen.ml create mode 100644 lib/lua_codegen.ml create mode 100644 lib/metal_codegen.ml create mode 100644 lib/mlir_codegen.ml create mode 100644 lib/nickel_codegen.ml create mode 100644 lib/ocaml_codegen.ml create mode 100644 lib/onnx_codegen.ml create mode 100644 lib/onnx_proto.ml create mode 100644 lib/opencl_codegen.ml create mode 100644 lib/protobuf.ml create mode 100644 lib/rescript_codegen.ml create mode 100644 lib/rust_codegen.ml create mode 100644 lib/spirv_codegen.ml create mode 100644 lib/verilog_codegen.ml create mode 100644 lib/wgsl_codegen.ml create mode 100644 lib/why3_codegen.ml diff --git a/lib/bash_codegen.ml b/lib/bash_codegen.ml new file mode 100644 index 0000000..4841975 --- /dev/null +++ b/lib/bash_codegen.ml @@ -0,0 +1,106 @@ +(* SPDX-License-Identifier: PMPL-1.0-or-later *) +(* SPDX-FileCopyrightText: 2026 Jonathan D.A. Jewell *) + +(** Bash/POSIX shell emitter (heavily restricted MVP). + + Bash is genuinely limited: no first-class numbers, no closures, no + structures. We accept Int-only programs whose entry returns Int (used + as the shell exit code). Functions are emitted as Bash functions that + [echo] their result; callers read it via [$(fn args)]. *) + +open Ast + +exception Bash_unsupported of string +let unsupported m = raise (Bash_unsupported m) + +let bash_reserved = [ + "if"; "then"; "else"; "elif"; "fi"; "case"; "esac"; + "for"; "while"; "until"; "do"; "done"; "function"; "return"; + "in"; "break"; "continue"; "exit"; "true"; "false"; +] + +let mangle s = if List.mem s bash_reserved then s ^ "_" else s + +(* Bash arithmetic only: every expression must evaluate to an integer. *) +let rec gen_expr (e : expr) : string = + match e with + | ExprLit (LitInt (n, _)) -> string_of_int n + | ExprLit (LitBool (true, _)) -> "1" + | ExprLit (LitBool (false, _))-> "0" + | ExprLit _ -> unsupported "Bash backend accepts only Int/Bool literals" + | ExprVar id -> "$" ^ mangle id.name + | ExprBinary (a, op, b) -> + let s = match op with + | OpAdd -> "+" | OpSub -> "-" | OpMul -> "*" | OpDiv -> "/" | OpMod -> "%" + | OpEq -> "==" | OpNe -> "!=" + | OpLt -> "<" | OpLe -> "<=" | OpGt -> ">" | OpGe -> ">=" + | OpAnd -> "&&" | OpOr -> "||" + | OpBitAnd -> "&" | OpBitOr -> "|" | OpBitXor -> "^" + | OpShl -> "<<" | OpShr -> ">>" + | OpConcat -> unsupported "string concat not supported in Bash backend" + in + "(" ^ gen_expr a ^ " " ^ s ^ " " ^ gen_expr b ^ ")" + | ExprUnary (OpNeg, x) -> "(-" ^ gen_expr x ^ ")" + | ExprUnary (OpNot, x) -> "(! " ^ gen_expr x ^ ")" + | ExprUnary _ -> unsupported "unary op not supported in Bash backend" + | ExprApp (callee, args) -> + let name = match callee with + | ExprVar id -> id.name + | _ -> unsupported "indirect call not supported in Bash backend" + in + let arg_strs = List.map gen_expr args in + (* Call a Bash function and capture its echo: $(fn 1 2) *) + Printf.sprintf "$(%s %s)" (mangle name) (String.concat " " arg_strs) + | ExprIf { ei_cond; ei_then; ei_else } -> + let c = gen_expr ei_cond in + let t = gen_expr ei_then in + let f = match ei_else with Some e -> gen_expr e | None -> "0" in + (* Bash arithmetic ternary: $((c ? t : f)) — works in all $((...)) *) + Printf.sprintf "(%s ? %s : %s)" c t f + | ExprLet { el_pat = PatVar _; _ } -> + unsupported "expression-position let not supported in Bash backend" + | ExprSpan (inner, _) -> gen_expr inner + | ExprBlock _ -> unsupported "block expressions not supported in Bash backend" + | _ -> unsupported "expression form not supported in Bash backend" + +let gen_function (fd : fn_decl) : string = + let name = mangle fd.fd_name.name in + let params = fd.fd_params in + let body_expr = match fd.fd_body with + | FnExpr e -> e + | FnBlock { blk_stmts = []; blk_expr = Some e } -> e + | FnBlock _ -> + unsupported "Bash backend accepts only single-expression function bodies" + in + let body_str = gen_expr body_expr in + let buf = Buffer.create 128 in + Buffer.add_string buf (name ^ "() {\n"); + List.iteri (fun i (p : param) -> + Buffer.add_string buf + (Printf.sprintf " local %s=$%d\n" (mangle p.p_name.name) (i + 1)) + ) params; + Buffer.add_string buf (Printf.sprintf " echo $((%s))\n}\n\n" body_str); + Buffer.contents buf + +let generate (program : program) (_symbols : Symbol.t) : string = + let buf = Buffer.create 1024 in + Buffer.add_string buf "#!/usr/bin/env bash\n"; + Buffer.add_string buf "# Generated by AffineScript compiler\n"; + Buffer.add_string buf "# SPDX-License-Identifier: PMPL-1.0-or-later\n"; + Buffer.add_string buf "set -euo pipefail\n\n"; + List.iter (function + | TopFn fd -> Buffer.add_string buf (gen_function fd) + | _ -> () + ) program.prog_decls; + let has_main = List.exists (function + | TopFn fd -> fd.fd_name.name = "main" + | _ -> false) program.prog_decls in + if has_main then Buffer.add_string buf "exit $(main)\n"; + Buffer.contents buf + +let codegen_bash (program : program) (symbols : Symbol.t) : (string, string) result = + try Ok (generate program symbols) + with + | Bash_unsupported msg -> Error ("Bash backend: " ^ msg) + | Failure msg -> Error ("Bash codegen error: " ^ msg) + | e -> Error ("Bash codegen error: " ^ Printexc.to_string e) diff --git a/lib/c_codegen.ml b/lib/c_codegen.ml new file mode 100644 index 0000000..fcfa9ea --- /dev/null +++ b/lib/c_codegen.ml @@ -0,0 +1,492 @@ +(* SPDX-License-Identifier: PMPL-1.0-or-later *) +(* SPDX-FileCopyrightText: 2026 Jonathan D.A. Jewell *) + +(** C Code Generator (MVP). + + Translates a typed AffineScript program to a single self-contained C99 + source file. The generated file links nothing beyond libc and a small + inline runtime emitted at the top of every output, so the round-trip is + + affinescript compile foo.affine -o foo.c + cc foo.c -o foo && ./foo + + Phase 1 (this file): functions, primitive arithmetic, control flow, let, + string println, simple match. Tuples / records / variants / ownership are + not lowered — they emit an explicit error stub so a regression is loud + rather than silent. + + Compatibility: relies on GCC/Clang "statement expressions" ({ ... }) so + block expressions can appear inside larger expressions. This is the same + trade the WASM backend implicitly makes (its blocks lower to wasm blocks). + Both gcc and clang accept it; tcc does too. msvc does not. +*) + +open Ast + +(* ============================================================================ + Code Generation Context + ============================================================================ *) + +type codegen_ctx = { + output : Buffer.t; + indent : int; + symbols : Symbol.t; + fwd_decls : Buffer.t; (* Forward declarations, written before bodies. *) +} + +let create_ctx symbols = { + output = Buffer.create 1024; + indent = 0; + symbols; + fwd_decls = Buffer.create 256; +} + +let emit ctx str = Buffer.add_string ctx.output str + +let emit_line ctx str = + let spaces = String.make (ctx.indent * 4) ' ' in + Buffer.add_string ctx.output spaces; + Buffer.add_string ctx.output str; + Buffer.add_char ctx.output '\n' + +let increase_indent ctx = { ctx with indent = ctx.indent + 1 } +let decrease_indent ctx = { ctx with indent = max 0 (ctx.indent - 1) } + +(* ============================================================================ + Runtime prelude + + Inlined into every output so generated code links against libc only. + ============================================================================ *) + +let prelude = {|/* ---- AffineScript C runtime (MVP) ---- */ +#include +#include +#include +#include + +typedef long as_int_t; +typedef double as_float_t; +typedef int as_bool_t; +typedef const char *as_str_t; + +static inline void print(as_str_t s) { fputs(s, stdout); } +static inline void println(as_str_t s) { puts(s); } +/* ---- end runtime ---- */ + +|} + +(* ============================================================================ + Identifier sanitisation + ============================================================================ *) + +let c_reserved = [ + "auto"; "break"; "case"; "char"; "const"; "continue"; "default"; "do"; + "double"; "else"; "enum"; "extern"; "float"; "for"; "goto"; "if"; "inline"; + "int"; "long"; "register"; "restrict"; "return"; "short"; "signed"; + "sizeof"; "static"; "struct"; "switch"; "typedef"; "union"; "unsigned"; + "void"; "volatile"; "while"; "_Bool"; "_Complex"; "_Imaginary"; + (* runtime collisions *) + "main"; "exit"; "abort"; "free"; "malloc"; "calloc"; "realloc"; + "printf"; "puts"; "fputs"; "stdin"; "stdout"; "stderr"; +] + +let mangle (name : string) : string = + if List.mem name c_reserved then name ^ "_" else name + +(* ============================================================================ + Type lowering + + AS type -> concrete C type. Anything unknown becomes `void *` so the + generated code at least parses; the WASM backend remains the source of + truth for full type-driven lowering. + ============================================================================ *) + +let rec c_type_of_ty (te : type_expr) : string = + match te with + | TyCon name when name.name = "Int" -> "as_int_t" + | TyCon name when name.name = "Float" -> "as_float_t" + | TyCon name when name.name = "Bool" -> "as_bool_t" + | TyCon name when name.name = "String" -> "as_str_t" + | TyCon name when name.name = "Unit" -> "void" + | TyCon name -> mangle name.name (* user-declared typedef *) + | TyApp _ -> "void *" + | TyArrow (_, _, _, _) -> "void *" + | TyTuple _ | TyRecord _ -> "void *" + | TyOwn t | TyRef t | TyMut t -> c_type_of_ty t + | TyVar _ | TyHole -> "void *" + +let c_type_of_ret = function + | None -> "void" + | Some ty -> c_type_of_ty ty + +(* ============================================================================ + Expression Code Generation + + Returned strings are valid C expressions. Statement-shaped constructs use + GCC statement expressions: ({ stmt; stmt; expr; }). + ============================================================================ *) + +let rec gen_expr ctx (expr : expr) : string = + match expr with + | ExprLit lit -> gen_literal lit + | ExprVar name -> mangle name.name + | ExprApp (func, args) -> + let func_str = gen_expr ctx func in + let arg_strs = List.map (gen_expr ctx) args in + func_str ^ "(" ^ String.concat ", " arg_strs ^ ")" + | ExprBinary (e1, op, e2) -> + let op_str = match op with + | OpAdd -> "+" | OpSub -> "-" | OpMul -> "*" | OpDiv -> "/" + | OpMod -> "%" + | OpEq -> "==" | OpNe -> "!=" + | OpLt -> "<" | OpLe -> "<=" | OpGt -> ">" | OpGe -> ">=" + | OpAnd -> "&&" | OpOr -> "||" + | OpBitAnd -> "&" | OpBitOr -> "|" | OpBitXor -> "^" + | OpShl -> "<<" | OpShr -> ">>" + | OpConcat -> + (* C has no string-concat operator; defer to a runtime call. + The runtime stub is intentionally absent so a use-site shows + up at link time rather than producing wrong output. *) + "@CONCAT@" + in + if op = OpConcat then + Printf.sprintf "as_concat(%s, %s)" (gen_expr ctx e1) (gen_expr ctx e2) + else + "(" ^ gen_expr ctx e1 ^ " " ^ op_str ^ " " ^ gen_expr ctx e2 ^ ")" + | ExprUnary (op, e) -> + (match op with + | OpNeg -> "(-" ^ gen_expr ctx e ^ ")" + | OpNot -> "(!" ^ gen_expr ctx e ^ ")" + | OpBitNot -> "(~" ^ gen_expr ctx e ^ ")" + | OpRef -> "(&" ^ gen_expr ctx e ^ ")" + | OpDeref -> "(*" ^ gen_expr ctx e ^ ")") + | ExprIf { ei_cond; ei_then; ei_else } -> + let cond_str = gen_expr ctx ei_cond in + let then_str = gen_expr ctx ei_then in + let else_str = match ei_else with + | Some e -> gen_expr ctx e + | None -> "((void)0)" + in + "(" ^ cond_str ^ " ? " ^ then_str ^ " : " ^ else_str ^ ")" + | ExprLet { el_pat; el_value; el_body; el_mut = _; el_quantity = _; el_ty } -> + let var = match el_pat with + | PatVar id -> mangle id.name + | PatWildcard _ -> "_unused" + | _ -> "_unsupported_pat" + in + let ty_str = match el_ty with + | Some t -> c_type_of_ty t + | None -> "long" (* MVP: untyped binders default to long *) + in + let val_str = gen_expr ctx el_value in + (match el_body with + | Some body -> + let body_str = gen_expr ctx body in + Printf.sprintf "({ %s %s = %s; %s; })" ty_str var val_str body_str + | None -> + Printf.sprintf "({ %s %s = %s; (void)0; })" ty_str var val_str) + | ExprBlock block -> gen_block_expr ctx block + | ExprReturn (Some e) -> + Printf.sprintf "({ return %s; })" (gen_expr ctx e) + | ExprReturn None -> + "({ return; })" + | ExprMatch { em_scrutinee; em_arms } -> + gen_match ctx em_scrutinee em_arms + | ExprField (record, field) -> + gen_expr ctx record ^ "." ^ mangle field.name + | ExprTupleIndex (e, n) -> + Printf.sprintf "(%s).f%d" (gen_expr ctx e) n + | ExprIndex (arr, idx) -> + Printf.sprintf "(%s)[%s]" (gen_expr ctx arr) (gen_expr ctx idx) + | ExprSpan (inner, _) -> gen_expr ctx inner + | ExprHandle { eh_body; eh_handlers = _ } -> gen_expr ctx eh_body + | ExprResume (Some e) -> gen_expr ctx e + | ExprResume None -> "((void)0)" + | ExprRecord { er_fields; _ } -> + (* Emit just the designated-initializer brace block; the surrounding + context (gen_let_with_hint) supplies the [(Type)] cast. *) + let fs = List.map (fun (id, e_opt) -> + let v = match e_opt with Some e -> gen_expr ctx e | None -> mangle id.name in + Printf.sprintf ".%s = %s" (mangle id.name) v + ) er_fields in + "{ " ^ String.concat ", " fs ^ " }" + | ExprVariant (_ty, ctor) -> mangle ctor.name (* refs a generated constant or fn *) + | ExprTuple _ | ExprArray _ | ExprLambda _ | ExprTry _ + | ExprRowRestrict _ | ExprUnsafe _ -> + "(__as_unsupported_expr_for_c_backend())" + +and gen_literal (lit : literal) : string = + match lit with + | LitInt (n, _) -> "((as_int_t)" ^ string_of_int n ^ ")" + | LitFloat (f, _) -> + let s = string_of_float f in + if String.length s > 0 && s.[String.length s - 1] = '.' then s ^ "0" else s + | LitBool (true, _) -> "1" + | LitBool (false, _) -> "0" + | LitString (s, _) -> "\"" ^ String.escaped s ^ "\"" + | LitChar (c, _) -> "'" ^ Char.escaped c ^ "'" + | LitUnit _ -> "((void)0)" + +and gen_block_expr ctx block = + (* Emit ({ stmt; stmt; expr; }) — GCC statement expression. *) + let buf = Buffer.create 64 in + List.iter (fun s -> + Buffer.add_string buf (gen_stmt ctx s); + Buffer.add_char buf ' ' + ) block.blk_stmts; + let tail = match block.blk_expr with + | Some e -> gen_expr ctx e ^ ";" + | None -> "((void)0);" + in + "({ " ^ Buffer.contents buf ^ tail ^ " })" + +and gen_match ctx scrutinee arms = + (* Lowered to a statement-expression: bind the scrutinee, then walk arms. + Each arm becomes a guarded block that, on tag/literal match, binds + pattern-variables from the union member and yields the body value. *) + let scrut_str = gen_expr ctx scrutinee in + let arm_strs = List.map (fun arm -> + match arm.ma_pat with + | PatWildcard _ | PatVar _ -> + Printf.sprintf "{ __as_match_result = (%s); break; }" (gen_expr ctx arm.ma_body) + | PatLit lit -> + Printf.sprintf "if (__scrut == %s) { __as_match_result = (%s); break; }" + (gen_literal lit) (gen_expr ctx arm.ma_body) + | PatCon (id, args) -> + let cond = Printf.sprintf "__scrut.tag == TAG_%s" (mangle id.name) in + let bindings = List.mapi (fun i p -> + match p with + | PatVar pid -> + Printf.sprintf "%s %s = __scrut.u.%s.f%d;" + "long" (mangle pid.name) (mangle id.name) i + | _ -> "" + ) args |> String.concat " " in + Printf.sprintf "if (%s) { %s __as_match_result = (%s); break; }" + cond bindings (gen_expr ctx arm.ma_body) + | _ -> + Printf.sprintf "{ __as_match_result = (%s); break; }" (gen_expr ctx arm.ma_body) + ) arms in + Printf.sprintf + "({ __typeof__(%s) __scrut = %s; long __as_match_result = 0; do { %s } while (0); __as_match_result; })" + scrut_str scrut_str (String.concat " " arm_strs) + +and gen_stmt ctx (stmt : stmt) : string = + match stmt with + | StmtLet { sl_pat; sl_value; sl_mut = _; sl_quantity = _; sl_ty } -> + let var = match sl_pat with + | PatVar id -> mangle id.name + | PatWildcard _ -> "_unused" + | _ -> "_unsupported_pat" + in + let ty_str = match sl_ty with + | Some t -> c_type_of_ty t + | None -> "long" + in + (* Record literals need a `(Type)` cast in C. When the let has a type + annotation pointing at a typedef, prepend it so designated braces + parse as a compound literal. *) + let value_str = + match sl_value, sl_ty with + | ExprRecord _, Some (TyCon id) -> + Printf.sprintf "(%s)%s" (mangle id.name) (gen_expr ctx sl_value) + | _ -> gen_expr ctx sl_value + in + Printf.sprintf "%s %s = %s;" ty_str var value_str + | StmtExpr e -> + gen_expr ctx e ^ ";" + | StmtAssign (lhs, op, rhs) -> + let op_str = match op with + | AssignEq -> "=" | AssignAdd -> "+=" + | AssignSub -> "-=" | AssignMul -> "*=" + | AssignDiv -> "/=" + in + Printf.sprintf "%s %s %s;" (gen_expr ctx lhs) op_str (gen_expr ctx rhs) + | StmtWhile (cond, body) -> + let body_strs = List.map (gen_stmt ctx) body.blk_stmts in + let tail = match body.blk_expr with + | Some e -> gen_expr ctx e ^ ";" + | None -> "" + in + Printf.sprintf "while (%s) { %s %s }" + (gen_expr ctx cond) (String.concat " " body_strs) tail + | StmtFor (_pat, _iter, _body) -> + (* MVP: AS for-in over iterators has no direct C analogue. Emit a + placeholder so the build link-fails cleanly. *) + "{ __as_unsupported_for_loop(); }" + +(* ============================================================================ + Top-Level Declaration Code Generation + ============================================================================ *) + +let gen_function ctx (fd : fn_decl) : unit = + let name = mangle fd.fd_name.name in + let ret_ty = c_type_of_ret fd.fd_ret_ty in + let params = List.map (fun (p : param) -> + Printf.sprintf "%s %s" (c_type_of_ty p.p_ty) (mangle p.p_name.name) + ) fd.fd_params in + let params_str = + if params = [] then "void" else String.concat ", " params + in + let signature = Printf.sprintf "%s %s(%s)" ret_ty name params_str in + + (* Forward declaration so any-order calls work. *) + Buffer.add_string ctx.fwd_decls (signature ^ ";\n"); + + emit_line ctx (signature ^ " {"); + let inner = increase_indent ctx in + (match fd.fd_body with + | FnExpr body_expr -> + if ret_ty = "void" then + emit_line inner ((gen_expr inner body_expr) ^ ";") + else + emit_line inner ("return " ^ gen_expr inner body_expr ^ ";") + | FnBlock block -> + List.iter (fun s -> emit_line inner (gen_stmt inner s)) block.blk_stmts; + (match block.blk_expr with + | Some e -> + if ret_ty = "void" then + emit_line inner (gen_expr inner e ^ ";") + else + emit_line inner ("return " ^ gen_expr inner e ^ ";") + | None -> ())); + emit_line ctx "}"; + emit ctx "\n" + +let emit_struct_decl ctx (name : string) (fields : (string * type_expr) list) : unit = + let lines = List.map (fun (n, ty) -> + Printf.sprintf " %s %s;" (c_type_of_ty ty) (mangle n)) fields in + emit_line ctx (Printf.sprintf "typedef struct {\n%s\n} %s;" + (String.concat "\n" lines) name); + emit ctx "\n" + +let emit_enum_decl ctx (name : string) (variants : variant_decl list) : unit = + (* tag enum *) + let tags = List.map (fun (vd : variant_decl) -> + "TAG_" ^ mangle vd.vd_name.name) variants in + emit_line ctx (Printf.sprintf "typedef enum { %s } %s_tag;" + (String.concat ", " tags) name); + (* tagged union *) + let union_members = List.map (fun (vd : variant_decl) -> + let payload = + if vd.vd_fields = [] then "char _unit;" + else + String.concat " " + (List.mapi (fun i ty -> + Printf.sprintf "%s f%d;" (c_type_of_ty ty) i) vd.vd_fields) + in + Printf.sprintf " struct { %s } %s;" payload (mangle vd.vd_name.name) + ) variants in + emit_line ctx (Printf.sprintf "typedef struct {"); + emit_line ctx (Printf.sprintf " %s_tag tag;" name); + emit_line ctx (Printf.sprintf " union {"); + List.iter (emit_line ctx) union_members; + emit_line ctx (Printf.sprintf " } u;"); + emit_line ctx (Printf.sprintf "} %s;" name); + (* constructor functions / constants *) + List.iter (fun (vd : variant_decl) -> + let cname = mangle vd.vd_name.name in + let arity = List.length vd.vd_fields in + if arity = 0 then + emit_line ctx + (Printf.sprintf "static const %s %s = (%s){ .tag = TAG_%s };" + name cname name cname) + else begin + let params = List.mapi (fun i ty -> + Printf.sprintf "%s f%d" (c_type_of_ty ty) i) vd.vd_fields in + let inits = List.mapi (fun i _ -> Printf.sprintf ".f%d = f%d" i i) vd.vd_fields in + emit_line ctx + (Printf.sprintf + "static inline %s %s(%s) { return (%s){ .tag = TAG_%s, .u.%s = { %s } }; }" + name cname (String.concat ", " params) + name cname cname (String.concat ", " inits)) + end + ) variants; + emit ctx "\n" + +let gen_type_decl_c ctx (td : type_decl) : unit = + let name = mangle td.td_name.name in + match td.td_body with + | TyAlias (TyRecord (fields, _)) -> + let pairs = List.map (fun (rf : row_field) -> (rf.rf_name.name, rf.rf_ty)) fields in + emit_struct_decl ctx name pairs + | TyAlias t -> + emit_line ctx (Printf.sprintf "typedef %s %s;" (c_type_of_ty t) name) + | TyStruct fields -> + let pairs = List.map (fun (sf : struct_field) -> (sf.sf_name.name, sf.sf_ty)) fields in + emit_struct_decl ctx name pairs + | TyEnum variants -> + emit_enum_decl ctx name variants + +let gen_top_level ctx (top : top_level) : unit = + match top with + | TopFn fd -> gen_function ctx fd + | TopType td -> gen_type_decl_c ctx td + | TopConst { tc_name; tc_ty; tc_value; _ } -> + emit_line ctx + (Printf.sprintf "static const %s %s = %s;" + (c_type_of_ty tc_ty) + (mangle tc_name.name) + (gen_expr ctx tc_value)) + | TopEffect _ -> emit_line ctx "/* effect declaration (erased) */" + | TopTrait _ -> emit_line ctx "/* trait declaration (erased) */" + | TopImpl _ -> emit_line ctx "/* impl block (erased) */" + +(* ============================================================================ + Driver + + AffineScript's `main` returns Int but C's `main` returns `int`. If the + program defines `main`, emit a C `main` that calls it and propagates the + exit code. If `main` returns Unit, exit 0. + ============================================================================ *) + +let main_entry_for (program : program) : string = + let main_fn = List.find_map (function + | TopFn fd when fd.fd_name.name = "main" -> Some fd + | _ -> None + ) program.prog_decls in + match main_fn with + | None -> "" + | Some fd -> + let ret_ty = c_type_of_ret fd.fd_ret_ty in + if ret_ty = "void" then + "int main(void) { main_(); return 0; }\n" + else + Printf.sprintf "int main(void) { return (int)main_(); }\n" + +let generate (program : program) (symbols : Symbol.t) : string = + let ctx = create_ctx symbols in + emit_line ctx "/* Generated by AffineScript compiler */"; + emit_line ctx "/* SPDX-License-Identifier: PMPL-1.0-or-later */"; + emit ctx prelude; + + (* Three-pass emission so forward declarations and body code see all + types: (1) type decls (typedefs, tagged unions, ctor inlines) into + types_buf; (2) function bodies into bodies_buf, accumulating fn + forward decls into fwd_decls. Final layout is types → fwd → bodies. *) + let types_buf = Buffer.create 512 in + let bodies_buf = Buffer.create 1024 in + let types_ctx = { ctx with output = types_buf } in + let body_ctx = { ctx with output = bodies_buf } in + List.iter (function + | TopType td -> gen_type_decl_c types_ctx td + | _ -> () + ) program.prog_decls; + List.iter (function + | TopType _ -> () + | other -> gen_top_level body_ctx other + ) program.prog_decls; + + Buffer.add_buffer ctx.output types_buf; + Buffer.add_char ctx.output '\n'; + Buffer.add_buffer ctx.output ctx.fwd_decls; + Buffer.add_char ctx.output '\n'; + Buffer.add_buffer ctx.output bodies_buf; + Buffer.add_string ctx.output (main_entry_for program); + + Buffer.contents ctx.output + +let codegen_c (program : program) (symbols : Symbol.t) : (string, string) result = + try Ok (generate program symbols) + with + | Failure msg -> Error ("C codegen error: " ^ msg) + | e -> Error ("C codegen error: " ^ Printexc.to_string e) diff --git a/lib/cuda_codegen.ml b/lib/cuda_codegen.ml new file mode 100644 index 0000000..68fc727 --- /dev/null +++ b/lib/cuda_codegen.ml @@ -0,0 +1,159 @@ +(* SPDX-License-Identifier: PMPL-1.0-or-later *) +(* SPDX-FileCopyrightText: 2026 Jonathan D.A. Jewell *) + +(** CUDA C++ kernel sublanguage emitter (MVP). + + Same kernel shape as the WGSL backend: first param is the global index, + remaining params are buffers. Lowers to a [__global__ void] function + plus a host wrapper that the user can call from C++. *) + +open Ast + +exception Cuda_unsupported of string +let unsupported m = raise (Cuda_unsupported m) + +let mangle s = s + +let scalar_of_type_name = function + | "Int" -> "int" + | "Float" -> "float" + | "Bool" -> "bool" + | n -> unsupported ("type not allowed in CUDA kernel: " ^ n) + +let rec scalar_of (te : type_expr) : string = + match te with + | TyCon id -> scalar_of_type_name id.name + | TyOwn t | TyRef t | TyMut t -> scalar_of t + | _ -> unsupported "complex type not allowed in CUDA kernel" + +let array_element (te : type_expr) : string = + let rec strip = function + | TyOwn t | TyRef t | TyMut t -> strip t + | t -> t + in + match strip te with + | TyApp (id, [TyArg inner]) when id.name = "Array" -> scalar_of inner + | _ -> unsupported "expected Array[Int|Float] for kernel buffer" + +let const_qual = function + | Some Mut -> "" + | _ -> "const " + +let rec gen_expr (e : expr) : string = + match e with + | ExprLit lit -> gen_lit lit + | ExprVar id -> mangle id.name + | ExprBinary (a, op, b) -> + let s = match op with + | OpAdd -> "+" | OpSub -> "-" | OpMul -> "*" | OpDiv -> "/" | OpMod -> "%" + | OpEq -> "==" | OpNe -> "!=" + | OpLt -> "<" | OpLe -> "<=" | OpGt -> ">" | OpGe -> ">=" + | OpAnd -> "&&" | OpOr -> "||" + | OpBitAnd -> "&" | OpBitOr -> "|" | OpBitXor -> "^" + | OpShl -> "<<" | OpShr -> ">>" + | OpConcat -> unsupported "concat not supported in CUDA" + in + "(" ^ gen_expr a ^ " " ^ s ^ " " ^ gen_expr b ^ ")" + | ExprUnary (OpNeg, x) -> "(-" ^ gen_expr x ^ ")" + | ExprUnary (OpNot, x) -> "(!" ^ gen_expr x ^ ")" + | ExprUnary (OpBitNot, x) -> "(~" ^ gen_expr x ^ ")" + | ExprUnary _ -> unsupported "unary op not supported in CUDA kernel" + | ExprIf { ei_cond; ei_then; ei_else } -> + let f = match ei_else with Some e -> gen_expr e | None -> "0" in + Printf.sprintf "(%s ? %s : %s)" (gen_expr ei_cond) (gen_expr ei_then) f + | ExprIndex (a, i) -> Printf.sprintf "%s[%s]" (gen_expr a) (gen_expr i) + | ExprApp (callee, args) -> + let name = match callee with + | ExprVar id -> id.name + | _ -> unsupported "indirect call" + in + let known = ["sin"; "cos"; "tan"; "sqrt"; "exp"; "log"; "pow"; + "fabs"; "floor"; "ceil"; "min"; "max"; "tanh"] in + if not (List.mem name known) then + unsupported ("call to non-builtin in CUDA kernel: " ^ name); + Printf.sprintf "%s(%s)" name + (String.concat ", " (List.map gen_expr args)) + | ExprSpan (inner, _) -> gen_expr inner + | _ -> unsupported "expression form not supported in CUDA kernel" + +and gen_lit = function + | LitInt (n, _) -> string_of_int n + | LitFloat (f, _) -> + let s = string_of_float f in + let s = if String.length s > 0 && s.[String.length s - 1] = '.' then s ^ "0" else s in + s ^ "f" + | LitBool (true, _) -> "true" + | LitBool (false, _) -> "false" + | _ -> unsupported "literal form not supported in CUDA kernel" + +let rec gen_stmt (s : stmt) : string = + match s with + | StmtLet { sl_pat = PatVar id; sl_value; sl_ty; _ } -> + let ty = match sl_ty with Some t -> scalar_of t | None -> "int" in + Printf.sprintf "%s %s = %s;" ty (mangle id.name) (gen_expr sl_value) + | StmtLet _ -> unsupported "destructuring let not supported in CUDA" + | StmtAssign (lhs, op, rhs) -> + let s = match op with + | AssignEq -> "=" | AssignAdd -> "+=" | AssignSub -> "-=" + | AssignMul -> "*=" | AssignDiv -> "/=" in + Printf.sprintf "%s %s %s;" (gen_expr lhs) s (gen_expr rhs) + | StmtExpr e -> gen_expr e ^ ";" + | StmtWhile (c, b) -> + Printf.sprintf "while (%s) { %s }" (gen_expr c) + (String.concat " " (List.map gen_stmt b.blk_stmts)) + | StmtFor _ -> unsupported "for-in not supported in CUDA kernel" + +let pick_kernel (program : program) : fn_decl = + let fns = List.filter_map (function TopFn fd -> Some fd | _ -> None) program.prog_decls in + match List.find_opt (fun fd -> fd.fd_name.name = "kernel") fns with + | Some fd -> fd + | None -> match fns with + | fd :: _ -> fd + | [] -> unsupported "no function found" + +let validate_kernel (fd : fn_decl) : unit = + match fd.fd_params with + | [] -> unsupported "kernel must take an Int index parameter" + | first :: _ -> + match first.p_ty with + | TyCon id when id.name = "Int" -> () + | _ -> unsupported "first param must be Int" + +let generate (program : program) (_symbols : Symbol.t) : string = + let buf = Buffer.create 1024 in + Buffer.add_string buf "// Generated by AffineScript compiler (CUDA C++)\n"; + Buffer.add_string buf "// SPDX-License-Identifier: PMPL-1.0-or-later\n\n"; + Buffer.add_string buf "#include \n\n"; + let fd = pick_kernel program in + validate_kernel fd; + let idx = match fd.fd_params with first :: _ -> first.p_name.name | _ -> "i" in + let bufs = match fd.fd_params with _ :: rest -> rest | [] -> [] in + let buf_decls = List.map (fun (p : param) -> + Printf.sprintf "%s%s *%s" + (const_qual p.p_ownership) (array_element p.p_ty) p.p_name.name + ) bufs in + Buffer.add_string buf "__global__\n"; + Buffer.add_string buf + (Printf.sprintf "void %s(%s) {\n" (mangle fd.fd_name.name) + (String.concat ", " buf_decls)); + Buffer.add_string buf + (Printf.sprintf " int %s = blockIdx.x * blockDim.x + threadIdx.x;\n" idx); + (match fd.fd_body with + | FnExpr e -> + Buffer.add_string buf (Printf.sprintf " (void)(%s);\n" (gen_expr e)) + | FnBlock b -> + List.iter (fun s -> + Buffer.add_string buf (" " ^ gen_stmt s ^ "\n") + ) b.blk_stmts; + (match b.blk_expr with + | Some e -> Buffer.add_string buf (Printf.sprintf " (void)(%s);\n" (gen_expr e)) + | None -> ())); + Buffer.add_string buf "}\n"; + Buffer.contents buf + +let codegen_cuda (program : program) (symbols : Symbol.t) : (string, string) result = + try Ok (generate program symbols) + with + | Cuda_unsupported m -> Error ("CUDA backend: " ^ m) + | Failure m -> Error ("CUDA codegen error: " ^ m) + | e -> Error ("CUDA codegen error: " ^ Printexc.to_string e) diff --git a/lib/faust_codegen.ml b/lib/faust_codegen.ml new file mode 100644 index 0000000..ad7e27b --- /dev/null +++ b/lib/faust_codegen.ml @@ -0,0 +1,252 @@ +(* SPDX-License-Identifier: PMPL-1.0-or-later *) +(* SPDX-FileCopyrightText: 2026 Jonathan D.A. Jewell *) + +(** Faust DSP Sublanguage Emitter (MVP). + + Lowers a strict subset of AffineScript to the Faust audio-DSP language. + + Source shape (the kernel sublanguage): + - one or more [fn] declarations, all over [Float] + - one of them is the entry point — by convention named [process] or + [main]; otherwise the first [fn] in the file + - parameters are scalar [Float] (audio samples or controls) + - return type is [Float] + - body uses arithmetic, comparison, [if]/[else] (lowered to [select2]), + [let] bindings (folded into Faust [with { ... }] locals), and calls to + a whitelist of Faust built-ins ([sin], [cos], [tanh], ...) + + Output is a single [.dsp] file consumable by the [faust] compiler, which + can then be re-targeted to C++ / WebAudio / JUCE / VST / Csound / + LV2 etc. — i.e. this MVP buys all of those targets at once. +*) + +open Ast + +exception Faust_unsupported of string +let unsupported msg = raise (Faust_unsupported msg) + +(* ============================================================================ + Identifier sanitisation + ============================================================================ *) + +(* Actual Faust reserved keywords. [process] / [main] are NOT keywords — + they are entry-point conventions and the entry function is always emitted + under the name [process] regardless of what the source called it. *) +let faust_reserved = [ + "with"; "letrec"; "case"; "import"; "library"; "declare"; + "environment"; "component"; "ffunction"; "fconstant"; "fvariable"; + "where"; "of"; +] + +let mangle s = + if List.mem s faust_reserved then "as_" ^ s else s + +(* ============================================================================ + Type validation + + Faust is essentially monomorphic Float (with int/float subtype distinction + that the compiler manages). We accept Float and Int (which Faust auto- + coerces) and reject everything else. + ============================================================================ *) + +let rec scalar_ok (te : type_expr) : unit = + match te with + | TyCon id when id.name = "Float" || id.name = "Int" -> () + | TyOwn t | TyRef t | TyMut t -> scalar_ok t + | _ -> unsupported "Faust kernels accept only Int/Float scalars" + +(* ============================================================================ + Expressions + + Faust expressions are evaluated as signal flow but written infix. + We emit them as parenthesised trees identical in shape to the AST. + ============================================================================ *) + +let faust_builtins = [ + "sin"; "cos"; "tan"; "asin"; "acos"; "atan"; "atan2"; + "exp"; "log"; "log10"; "sqrt"; "pow"; "floor"; "ceil"; "round"; + "abs"; "min"; "max"; "fmod"; "tanh"; "sinh"; "cosh"; + "int"; "float"; (* type coercions in Faust *) +] + +(* User-defined function names visible at codegen time. Populated once per + [generate] call before any [gen_expr] is invoked, so inter-fn calls are + recognised. Module-scoped because gen_expr is otherwise a pure tree + transformation; threading a context through every node would inflate the + diff without buying anything in single-shot codegen. *) +let user_fns : string list ref = ref [] + +let rec gen_expr (e : expr) : string = + match e with + | ExprLit lit -> gen_lit lit + | ExprVar id -> mangle id.name + | ExprBinary (a, op, b) -> + let s = match op with + | OpAdd -> "+" | OpSub -> "-" | OpMul -> "*" | OpDiv -> "/" | OpMod -> "%" + | OpEq -> "==" | OpNe -> "!=" + | OpLt -> "<" | OpLe -> "<=" | OpGt -> ">" | OpGe -> ">=" + | OpAnd -> "&" | OpOr -> "|" (* Faust uses bitwise tokens for both *) + | OpBitAnd -> "&" | OpBitOr -> "|" | OpBitXor -> "xor" + | OpShl -> "<<" | OpShr -> ">>" + | OpConcat -> unsupported "string/array concat not supported in Faust" + in + "(" ^ gen_expr a ^ " " ^ s ^ " " ^ gen_expr b ^ ")" + | ExprUnary (op, x) -> + (match op with + | OpNeg -> "(0 - " ^ gen_expr x ^ ")" (* Faust has no unary minus *) + | OpNot -> "(1 - " ^ gen_expr x ^ ")" (* boolean as 0/1 *) + | OpBitNot -> unsupported "bitwise not not supported in Faust" + | OpRef | OpDeref -> unsupported "ref/deref not supported in Faust") + | ExprIf { ei_cond; ei_then; ei_else } -> + let c = gen_expr ei_cond in + let t = gen_expr ei_then in + let f = match ei_else with + | Some e -> gen_expr e + | None -> unsupported "if without else cannot lower to Faust select2" + in + Printf.sprintf "select2(%s, %s, %s)" c f t + | ExprApp (callee, args) -> + let name = match callee with + | ExprVar id -> id.name + | _ -> unsupported "indirect calls not supported in Faust" + in + let emit_name = + if List.mem name faust_builtins then name + else if List.mem name !user_fns then mangle name + else unsupported ("call to non-builtin in Faust kernel: " ^ name) + in + Printf.sprintf "%s(%s)" emit_name + (String.concat ", " (List.map gen_expr args)) + | ExprLet { el_pat; el_value; el_body; el_mut = _; el_quantity = _; el_ty = _ } -> + (* Faust's [with { var = expr; }] clause attaches local definitions to + a parent expression. We emit it inline so: + let v = e1 in e2 ↦ (e2) with { v = e1; } *) + let var = match el_pat with + | PatVar id -> mangle id.name + | _ -> unsupported "non-variable let binding not supported in Faust" + in + let body = match el_body with + | Some e -> gen_expr e + | None -> unsupported "statement-position let cannot stand alone in Faust" + in + Printf.sprintf "(%s) with { %s = %s; }" body var (gen_expr el_value) + | ExprBlock blk -> gen_block blk + | ExprSpan (inner, _) -> gen_expr inner + | ExprReturn (Some e) -> gen_expr e + | ExprMatch _ -> unsupported "match not supported in Faust kernel" + | ExprLambda _ -> unsupported "lambdas not supported in Faust" + | ExprTuple _ | ExprArray _ | ExprRecord _ + | ExprField _ | ExprTupleIndex _ | ExprIndex _ | ExprRowRestrict _ -> + unsupported "compound values not supported in Faust kernel" + | ExprReturn None | ExprTry _ | ExprHandle _ + | ExprResume _ | ExprUnsafe _ | ExprVariant _ -> + unsupported "control-flow construct not supported in Faust kernel" + +and gen_lit (lit : literal) : string = + match lit with + | LitInt (n, _) -> string_of_int n + | LitFloat (f, _) -> + let s = string_of_float f in + if String.length s > 0 && s.[String.length s - 1] = '.' then s ^ "0" else s + | LitBool (true, _) -> "1" + | LitBool (false, _) -> "0" + | LitChar _ -> unsupported "char literals not supported in Faust" + | LitString _ -> unsupported "string literals not supported in Faust" + | LitUnit _ -> unsupported "unit literal not supported in Faust" + +and gen_block (blk : block) : string = + (* A block becomes a chain of [with { ... }] locals followed by the trailing + expression. Empty blocks aren't meaningful — Faust requires every + definition to produce a value. *) + let result = match blk.blk_expr with + | Some e -> e + | None -> unsupported "block without trailing expression cannot be lowered" + in + (* Statements turn into local definitions, applied in reverse so + earlier bindings are visible to later ones via Faust's scope rules. *) + let withs = List.filter_map (fun s -> + match s with + | StmtLet { sl_pat = PatVar id; sl_value; _ } -> + Some (Printf.sprintf "%s = %s;" (mangle id.name) (gen_expr sl_value)) + | StmtLet _ -> unsupported "destructuring let not supported in Faust" + | StmtExpr _ -> unsupported "statement-form expression not supported in Faust" + | StmtAssign _ -> unsupported "assignment not supported in Faust (signals are immutable)" + | StmtWhile _ | StmtFor _ -> + unsupported "imperative loops not supported in Faust" + ) blk.blk_stmts in + if withs = [] then gen_expr result + else + Printf.sprintf "(%s) with { %s }" (gen_expr result) (String.concat " " withs) + +(* ============================================================================ + Top-level + ============================================================================ *) + +let validate_kernel_fn (fd : fn_decl) : unit = + (match fd.fd_ret_ty with + | None -> unsupported "kernel function must return Float" + | Some t -> scalar_ok t); + List.iter (fun (p : param) -> scalar_ok p.p_ty) fd.fd_params + +(* Emit a Faust function definition. [as_entry] forces the emitted name to + be [process] (the Faust entry point) regardless of the source name. *) +let gen_function ?(as_entry = false) (fd : fn_decl) : string = + validate_kernel_fn fd; + let name = if as_entry then "process" else mangle fd.fd_name.name in + let params = List.map (fun (p : param) -> mangle p.p_name.name) fd.fd_params in + let body = match fd.fd_body with + | FnExpr e -> gen_expr e + | FnBlock b -> gen_block b + in + if params = [] then + Printf.sprintf "%s = %s;\n" name body + else + Printf.sprintf "%s(%s) = %s;\n" name (String.concat ", " params) body + +(* ============================================================================ + Driver + ============================================================================ *) + +let pick_entry (program : program) : fn_decl = + let fns = List.filter_map (function TopFn fd -> Some fd | _ -> None) + program.prog_decls in + let by_name n = List.find_opt (fun fd -> fd.fd_name.name = n) fns in + match by_name "process" with + | Some fd -> fd + | None -> + match by_name "main" with + | Some fd -> fd + | None -> + match fns with + | fd :: _ -> fd + | [] -> unsupported "no function found to lower as Faust process" + +let generate (program : program) (_symbols : Symbol.t) : string = + let buf = Buffer.create 1024 in + Buffer.add_string buf + "// Generated by AffineScript compiler (Faust DSP sublanguage)\n"; + Buffer.add_string buf "// SPDX-License-Identifier: PMPL-1.0-or-later\n\n"; + let entry = pick_entry program in + let entry_name = entry.fd_name.name in + user_fns := List.filter_map (function + | TopFn fd -> Some fd.fd_name.name + | _ -> None + ) program.prog_decls; + let other_fns = List.filter_map (function + | TopFn fd when fd.fd_name.name <> entry_name -> Some fd + | _ -> None + ) program.prog_decls in + let _ = entry_name in + (* Emit auxiliary functions first so Faust can resolve forward references. *) + List.iter (fun fd -> Buffer.add_string buf (gen_function fd)) other_fns; + if other_fns <> [] then Buffer.add_char buf '\n'; + (* Emit the entry as `process` — Faust's required entry-point name. *) + Buffer.add_string buf (gen_function ~as_entry:true entry); + Buffer.contents buf + +let codegen_faust (program : program) (symbols : Symbol.t) : (string, string) result = + try Ok (generate program symbols) + with + | Faust_unsupported msg -> Error ("Faust backend: " ^ msg) + | Failure msg -> Error ("Faust codegen error: " ^ msg) + | e -> Error ("Faust codegen error: " ^ Printexc.to_string e) diff --git a/lib/gleam_codegen.ml b/lib/gleam_codegen.ml new file mode 100644 index 0000000..072602b --- /dev/null +++ b/lib/gleam_codegen.ml @@ -0,0 +1,173 @@ +(* SPDX-License-Identifier: PMPL-1.0-or-later *) +(* SPDX-FileCopyrightText: 2026 Jonathan D.A. Jewell *) + +(** Gleam emitter (BEAM target). + + Lowers a subset of AffineScript to Gleam, which is a Hindley-Milner + typed language compiling to Erlang (BEAM) and JavaScript. *) + +open Ast + +let gleam_reserved = [ + "as"; "assert"; "case"; "const"; "external"; "fn"; "if"; "import"; + "let"; "opaque"; "panic"; "pub"; "todo"; "type"; "use"; +] + +let mangle s = if List.mem s gleam_reserved then s ^ "_" else s + +let rec gleam_type = function + | TyCon id when id.name = "Int" -> "Int" + | TyCon id when id.name = "Float" -> "Float" + | TyCon id when id.name = "Bool" -> "Bool" + | TyCon id when id.name = "String" -> "String" + | TyCon id when id.name = "Unit" -> "Nil" + | TyCon id -> mangle id.name + | TyTuple [] -> "Nil" + | TyTuple ts -> "#(" ^ String.concat ", " (List.map gleam_type ts) ^ ")" + | TyOwn t | TyRef t | TyMut t -> gleam_type t + | _ -> "Dynamic" + +let ret_type = function None -> "Nil" | Some t -> gleam_type t + +let rec gen_expr (e : expr) : string = + match e with + | ExprLit lit -> gen_lit lit + | ExprVar id -> mangle id.name + | ExprApp (callee, args) -> + gen_expr callee ^ "(" ^ String.concat ", " (List.map gen_expr args) ^ ")" + | ExprBinary (a, op, b) -> + (* Gleam has separate Int/Float operators: +, +., -, -., *, *., /, /. *) + let s = match op with + | OpAdd -> "+" | OpSub -> "-" | OpMul -> "*" | OpDiv -> "/" | OpMod -> "%" + | OpEq -> "==" | OpNe -> "!=" + | OpLt -> "<" | OpLe -> "<=" | OpGt -> ">" | OpGe -> ">=" + | OpAnd -> "&&" | OpOr -> "||" + | OpConcat -> "<>" + | OpBitAnd | OpBitOr | OpBitXor | OpShl | OpShr -> + failwith "Gleam backend: bitwise ops need int_bitwise stdlib" + in + "{ " ^ gen_expr a ^ " " ^ s ^ " " ^ gen_expr b ^ " }" + | ExprUnary (OpNeg, x) -> "{ 0 - " ^ gen_expr x ^ " }" + | ExprUnary (OpNot, x) -> "!" ^ gen_expr x + | ExprUnary _ -> "panic as \"Gleam backend: unsupported unary\"" + | ExprIf { ei_cond; ei_then; ei_else } -> + let c = gen_expr ei_cond in + let t = gen_expr ei_then in + let f = match ei_else with Some e -> gen_expr e | None -> "Nil" in + Printf.sprintf "case %s { True -> %s False -> %s }" c t f + | ExprLet { el_pat; el_value; el_body; _ } -> + let var = match el_pat with PatVar id -> mangle id.name | _ -> "_" in + let v = gen_expr el_value in + let body = match el_body with Some e -> gen_expr e | None -> "Nil" in + Printf.sprintf "{ let %s = %s\n %s }" var v body + | ExprBlock blk -> gen_block blk + | ExprTuple es -> "#(" ^ String.concat ", " (List.map gen_expr es) ^ ")" + | ExprTupleIndex (e, n) -> gen_expr e ^ "." ^ string_of_int n + | ExprRecord { er_fields; _ } -> + let fs = List.map (fun (id, e_opt) -> + let v = match e_opt with Some e -> gen_expr e | None -> mangle id.name in + Printf.sprintf "%s: %s" (mangle id.name) v + ) er_fields in + "(" ^ String.concat ", " fs ^ ")" (* placeholder — real form needs ctor name *) + | ExprField (record, field) -> gen_expr record ^ "." ^ mangle field.name + | ExprVariant (_ty, ctor) -> ctor.name (* Gleam variants keep TitleCase *) + | ExprMatch { em_scrutinee; em_arms } -> + let arms = List.map (fun arm -> + Printf.sprintf "%s -> %s" (gen_pattern arm.ma_pat) (gen_expr arm.ma_body) + ) em_arms in + Printf.sprintf "case %s { %s }" (gen_expr em_scrutinee) (String.concat " " arms) + | ExprSpan (inner, _) -> gen_expr inner + | ExprReturn (Some e) -> gen_expr e + | _ -> "panic as \"Gleam backend: unsupported expression\"" + +and gen_pattern (p : pattern) : string = + match p with + | PatWildcard _ -> "_" + | PatVar id -> mangle id.name + | PatLit lit -> gen_lit lit + | PatCon (id, args) -> + if args = [] then id.name + else id.name ^ "(" ^ String.concat ", " (List.map gen_pattern args) ^ ")" + | PatTuple ps -> "#(" ^ String.concat ", " (List.map gen_pattern ps) ^ ")" + | PatRecord _ -> "_" (* Gleam record patterns need the constructor name *) + | PatAs (id, _) -> mangle id.name + | PatOr (p, _) -> gen_pattern p + +and gen_lit = function + | LitInt (n, _) -> string_of_int n + | LitFloat (f, _) -> + let s = string_of_float f in + if String.length s > 0 && s.[String.length s - 1] = '.' then s ^ "0" else s + | LitBool (true, _) -> "True" + | LitBool (false, _) -> "False" + | LitString (s, _) -> "\"" ^ String.escaped s ^ "\"" + | LitChar (c, _) -> "\"" ^ Char.escaped c ^ "\"" + | LitUnit _ -> "Nil" + +and gen_block (blk : block) : string = + let stmts = List.map gen_stmt blk.blk_stmts in + let tail = match blk.blk_expr with Some e -> gen_expr e | None -> "Nil" in + "{\n " ^ String.concat "\n " stmts ^ "\n " ^ tail ^ "\n}" + +and gen_stmt = function + | StmtLet { sl_pat = PatVar id; sl_value; _ } -> + Printf.sprintf "let %s = %s" (mangle id.name) (gen_expr sl_value) + | StmtLet _ -> "" + | StmtExpr e -> gen_expr e + | _ -> "" + +let gen_function (fd : fn_decl) : string = + let name = mangle fd.fd_name.name in + let params = String.concat ", " + (List.map (fun (p : param) -> + Printf.sprintf "%s: %s" (mangle p.p_name.name) (gleam_type p.p_ty)) + fd.fd_params) in + let ret = ret_type fd.fd_ret_ty in + let body = match fd.fd_body with + | FnExpr e -> gen_expr e + | FnBlock b -> gen_block b + in + Printf.sprintf "pub fn %s(%s) -> %s {\n %s\n}\n\n" name params ret body + +let gen_type_decl (td : type_decl) : string = + let name = td.td_name.name in + match td.td_body with + | TyAlias (TyRecord (fields, _)) -> + (* Gleam's record-shape: `pub type Point { Point(x: Int, y: Int) }`. *) + let fs = List.map (fun (rf : row_field) -> + Printf.sprintf "%s: %s" (mangle rf.rf_name.name) (gleam_type rf.rf_ty) + ) fields in + Printf.sprintf "pub type %s { %s(%s) }\n\n" name name (String.concat ", " fs) + | TyAlias t -> Printf.sprintf "pub type %s = %s\n\n" name (gleam_type t) + | TyStruct fields -> + let fs = List.map (fun (sf : struct_field) -> + Printf.sprintf "%s: %s" (mangle sf.sf_name.name) (gleam_type sf.sf_ty) + ) fields in + Printf.sprintf "pub type %s { %s(%s) }\n\n" name name (String.concat ", " fs) + | TyEnum variants -> + let vs = List.map (fun (vd : variant_decl) -> + let tys = List.map gleam_type vd.vd_fields in + let body = if tys = [] then "" else "(" ^ String.concat ", " tys ^ ")" in + Printf.sprintf " %s%s" vd.vd_name.name body + ) variants in + Printf.sprintf "pub type %s {\n%s\n}\n\n" name (String.concat "\n" vs) + +let generate (program : program) (_symbols : Symbol.t) : string = + let buf = Buffer.create 1024 in + Buffer.add_string buf "// Generated by AffineScript compiler\n"; + Buffer.add_string buf "// SPDX-License-Identifier: PMPL-1.0-or-later\n\n"; + List.iter (function + | TopType td -> Buffer.add_string buf (gen_type_decl td) + | _ -> () + ) program.prog_decls; + List.iter (function + | TopFn fd -> Buffer.add_string buf (gen_function fd) + | _ -> () + ) program.prog_decls; + Buffer.contents buf + +let codegen_gleam (program : program) (symbols : Symbol.t) : (string, string) result = + try Ok (generate program symbols) + with + | Failure msg -> Error ("Gleam codegen error: " ^ msg) + | e -> Error ("Gleam codegen error: " ^ Printexc.to_string e) diff --git a/lib/js_codegen.ml b/lib/js_codegen.ml new file mode 100644 index 0000000..a8ceb9a --- /dev/null +++ b/lib/js_codegen.ml @@ -0,0 +1,487 @@ +(* SPDX-License-Identifier: PMPL-1.0-or-later *) +(* SPDX-FileCopyrightText: 2026 Jonathan D.A. Jewell *) + +(** JavaScript Code Generator (MVP). + + Translates AffineScript AST to ES2020 JavaScript source code. + Targets Deno/Node — no Web Audio, no DOM. Output is a single self-contained + file that includes a minimal runtime prelude (Some/None/Ok/Err builders, + println). Effects are erased at this layer; IO operations call the prelude. + + Phase 1 (this file): functions, arithmetic, control flow, let, tuples, + records, simple match, Option/Result constructors, try/catch. Sufficient + for the conformance subset that does not depend on ownership at runtime. +*) + +open Ast + +(* ============================================================================ + Code Generation Context + ============================================================================ *) + +type codegen_ctx = { + output : Buffer.t; + indent : int; + symbols : Symbol.t; + in_function : bool; +} + +let create_ctx symbols = { + output = Buffer.create 1024; + indent = 0; + symbols; + in_function = false; +} + +let emit ctx str = + Buffer.add_string ctx.output str + +let emit_line ctx str = + let spaces = String.make (ctx.indent * 2) ' ' in + Buffer.add_string ctx.output spaces; + Buffer.add_string ctx.output str; + Buffer.add_char ctx.output '\n' + +let increase_indent ctx = { ctx with indent = ctx.indent + 1 } +let decrease_indent ctx = { ctx with indent = max 0 (ctx.indent - 1) } + +(* ============================================================================ + Runtime prelude + + Emitted once at the top of every output file. Keeps generated code free of + library dependencies — `deno run foo.js` or `node foo.js` is enough. + ============================================================================ *) + +let prelude = {|// ---- AffineScript JS runtime (MVP) ---- +const Some = (value) => ({ tag: "Some", value }); +const None = { tag: "None" }; +const Ok = (value) => ({ tag: "Ok", value }); +const Err = (error) => ({ tag: "Err", error }); +const Unit = null; +const print = (s) => { (typeof Deno !== "undefined" ? Deno.stdout.writeSync(new TextEncoder().encode(String(s))) : process.stdout.write(String(s))); }; +const println = (s) => { console.log(String(s)); }; +// ---- end runtime ---- + +|} + +(* ============================================================================ + Identifier sanitisation + + AffineScript identifiers are mostly JS-safe. Reserved keywords are renamed + with a trailing underscore so generated code parses. + ============================================================================ *) + +let js_reserved = [ + "abstract"; "arguments"; "await"; "boolean"; "break"; "byte"; "case"; + "catch"; "char"; "class"; "const"; "continue"; "debugger"; "default"; + "delete"; "do"; "double"; "else"; "enum"; "eval"; "export"; "extends"; + "false"; "final"; "finally"; "float"; "for"; "function"; "goto"; "if"; + "implements"; "import"; "in"; "instanceof"; "int"; "interface"; "let"; + "long"; "native"; "new"; "null"; "package"; "private"; "protected"; + "public"; "return"; "short"; "static"; "super"; "switch"; "synchronized"; + "this"; "throw"; "throws"; "transient"; "true"; "try"; "typeof"; "var"; + "void"; "volatile"; "while"; "with"; "yield"; +] + +let mangle (name : string) : string = + if List.mem name js_reserved then name ^ "_" + else name + +(* ============================================================================ + Expression Code Generation + ============================================================================ *) + +let rec gen_expr ctx (expr : expr) : string = + match expr with + | ExprLit lit -> gen_literal lit + | ExprVar name -> mangle name.name + | ExprApp (func, args) -> + let func_str = gen_expr ctx func in + let arg_strs = List.map (gen_expr ctx) args in + func_str ^ "(" ^ String.concat ", " arg_strs ^ ")" + | ExprBinary (e1, op, e2) -> + let op_str = match op with + | OpAdd -> "+" + | OpSub -> "-" + | OpMul -> "*" + | OpDiv -> "/" + | OpMod -> "%" + | OpEq -> "===" + | OpNe -> "!==" + | OpLt -> "<" + | OpLe -> "<=" + | OpGt -> ">" + | OpGe -> ">=" + | OpAnd -> "&&" + | OpOr -> "||" + | OpBitAnd -> "&" + | OpBitOr -> "|" + | OpBitXor -> "^" + | OpShl -> "<<" + | OpShr -> ">>" + | OpConcat -> "+" (* JS string/array overload *) + in + "(" ^ gen_expr ctx e1 ^ " " ^ op_str ^ " " ^ gen_expr ctx e2 ^ ")" + | ExprUnary (op, e) -> + (match op with + | OpNeg -> "(-" ^ gen_expr ctx e ^ ")" + | OpNot -> "(!" ^ gen_expr ctx e ^ ")" + | OpBitNot -> "(~" ^ gen_expr ctx e ^ ")" + | OpRef -> "({ get: () => " ^ gen_expr ctx e ^ ", set: (_) => {} })" + | OpDeref -> "(" ^ gen_expr ctx e ^ ".get())") + | ExprIf { ei_cond; ei_then; ei_else } -> + let cond_str = gen_expr ctx ei_cond in + let then_str = gen_expr ctx ei_then in + let else_str = match ei_else with + | Some e -> gen_expr ctx e + | None -> "Unit" + in + "(" ^ cond_str ^ " ? " ^ then_str ^ " : " ^ else_str ^ ")" + | ExprLet { el_pat; el_value; el_body; el_mut; el_quantity = _; el_ty = _ } -> + let pat_str = gen_pattern ctx el_pat in + let val_str = gen_expr ctx el_value in + let kw = if el_mut then "let" else "const" in + (match el_body with + | Some body -> + let body_str = gen_expr ctx body in + "((() => { " ^ kw ^ " " ^ pat_str ^ " = " ^ val_str ^ "; return " ^ + body_str ^ "; })())" + | None -> + (* Statement-position let folded into expression: emit IIFE returning Unit. *) + "((() => { " ^ kw ^ " " ^ pat_str ^ " = " ^ val_str ^ "; return Unit; })())") + | ExprTuple exprs -> + let expr_strs = List.map (gen_expr ctx) exprs in + "[" ^ String.concat ", " expr_strs ^ "]" + | ExprArray exprs -> + let expr_strs = List.map (gen_expr ctx) exprs in + "[" ^ String.concat ", " expr_strs ^ "]" + | ExprIndex (arr, idx) -> + gen_expr ctx arr ^ "[" ^ gen_expr ctx idx ^ "]" + | ExprTupleIndex (e, n) -> + gen_expr ctx e ^ "[" ^ string_of_int n ^ "]" + | ExprRecord { er_fields; er_spread } -> + let field_strs = List.map (fun (name, e_opt) -> + let val_str = match e_opt with + | Some e -> gen_expr ctx e + | None -> mangle name.name (* punning: { x } -> { x: x } *) + in + mangle name.name ^ ": " ^ val_str + ) er_fields in + let spread_str = match er_spread with + | Some e -> "...(" ^ gen_expr ctx e ^ "), " + | None -> "" + in + "({ " ^ spread_str ^ String.concat ", " field_strs ^ " })" + | ExprField (record, field) -> + gen_expr ctx record ^ "." ^ mangle field.name + | ExprMatch { em_scrutinee; em_arms } -> + gen_match ctx em_scrutinee em_arms + | ExprBlock block -> + gen_block_expr ctx block + | ExprReturn (Some e) -> + "(() => { return " ^ gen_expr ctx e ^ "; })()" + | ExprReturn None -> + "(() => { return Unit; })()" + | ExprLambda { elam_params; elam_body; elam_ret_ty = _ } -> + let param_strs = List.map (fun (p : param) -> mangle p.p_name.name) elam_params in + "((" ^ String.concat ", " param_strs ^ ") => " ^ gen_expr ctx elam_body ^ ")" + | ExprTry { et_body; et_catch; et_finally } -> + gen_try ctx et_body et_catch et_finally + | ExprVariant (ty, ctor) -> + (* `Type::Variant` — emit a tagged object factory. Special-cases for + the prelude builders so `Option::None` / `Result::Ok` map directly. *) + (match ty.name, ctor.name with + | _, "None" -> "None" + | _, "Some" -> "Some" + | _, "Ok" -> "Ok" + | _, "Err" -> "Err" + | _, name -> Printf.sprintf "({ tag: %S })" name) + | ExprSpan (inner, _) -> gen_expr ctx inner + | ExprRowRestrict (e, _field) -> gen_expr ctx e (* runtime no-op *) + | ExprHandle { eh_body; eh_handlers = _ } -> + (* Effect handlers are erased at MVP — IO collapses to direct calls. *) + gen_expr ctx eh_body + | ExprResume (Some e) -> gen_expr ctx e + | ExprResume None -> "Unit" + | ExprUnsafe _ -> + "(() => { throw new Error('unsafe op not supported in JS backend'); })()" + +and gen_literal (lit : literal) : string = + match lit with + | LitInt (n, _) -> string_of_int n + | LitFloat (f, _) -> + (* OCaml's string_of_float can produce "1." — patch to "1.0" so JS parses + identically and the output is stable. *) + let s = string_of_float f in + if String.length s > 0 && s.[String.length s - 1] = '.' then s ^ "0" else s + | LitBool (true, _) -> "true" + | LitBool (false, _) -> "false" + | LitString (s, _) -> "\"" ^ String.escaped s ^ "\"" + | LitChar (c, _) -> "\"" ^ Char.escaped c ^ "\"" + | LitUnit _ -> "Unit" + +and gen_pattern ctx (pat : pattern) : string = + (* Used in binder positions: let x = ..., function params, for x in ... *) + match pat with + | PatWildcard _ -> "_" + | PatVar name -> mangle name.name + | PatLit _ -> "_" (* Literal patterns can't bind; only meaningful in match *) + | PatTuple pats -> + let pat_strs = List.map (gen_pattern ctx) pats in + "[" ^ String.concat ", " pat_strs ^ "]" + | PatRecord (fields, _) -> + let strs = List.map (fun (n, sub) -> + match sub with + | None -> mangle n.name + | Some sub -> mangle n.name ^ ": " ^ gen_pattern ctx sub + ) fields in + "{ " ^ String.concat ", " strs ^ " }" + | PatAs (id, _) -> mangle id.name + | PatCon (id, _) -> mangle id.name (* approximate *) + | PatOr (p, _) -> gen_pattern ctx p + +and gen_match ctx scrutinee arms = + (* Lower `match` to an IIFE that destructures the scrutinee once and runs an + if-else cascade. Each arm gets a freshly-scoped block so binders don't + collide. *) + let scrutinee_str = gen_expr ctx scrutinee in + let scrut_var = "__scrut" in + let rec gen_arms = function + | [] -> + "throw new Error(\"non-exhaustive match\");" + | arm :: rest -> + let cond = gen_pattern_test scrut_var arm.ma_pat in + let bindings = gen_pattern_bindings scrut_var arm.ma_pat in + let guard = match arm.ma_guard with + | Some g -> " && (" ^ gen_expr ctx g ^ ")" + | None -> "" + in + let body = gen_expr ctx arm.ma_body in + let inner_bindings = + if bindings = "" then "" + else bindings ^ " " + in + let prefix = + (* When we have bindings, emit them inside the branch so they are + scoped to that arm and visible to the guard check. *) + if bindings = "" then + "if (" ^ cond ^ guard ^ ") { return " ^ body ^ "; }" + else if arm.ma_guard = None then + "if (" ^ cond ^ ") { " ^ inner_bindings ^ "return " ^ body ^ "; }" + else + "if (" ^ cond ^ ") { " ^ inner_bindings ^ + "if (" ^ (match arm.ma_guard with Some g -> gen_expr ctx g | None -> "true") ^ + ") { return " ^ body ^ "; } }" + in + prefix ^ " " ^ gen_arms rest + in + "((" ^ scrut_var ^ ") => { " ^ gen_arms arms ^ " })(" ^ scrutinee_str ^ ")" + +and gen_pattern_test scrut pat = + match pat with + | PatWildcard _ | PatVar _ -> "true" + | PatLit lit -> scrut ^ " === " ^ gen_literal lit + | PatCon (id, _) -> + (* Tagged-union variant: { tag: "Some", value: ... } *) + scrut ^ ".tag === " ^ Printf.sprintf "%S" id.name + | PatTuple pats -> + let conds = List.mapi (fun i p -> + gen_pattern_test (scrut ^ "[" ^ string_of_int i ^ "]") p + ) pats in + String.concat " && " (("Array.isArray(" ^ scrut ^ ")") :: conds) + | PatRecord (fields, _) -> + let conds = List.map (fun (n, sub) -> + match sub with + | None -> "true" + | Some sub -> gen_pattern_test (scrut ^ "." ^ mangle n.name) sub + ) fields in + String.concat " && " conds + | PatAs (_, p) -> gen_pattern_test scrut p + | PatOr (p1, p2) -> + "((" ^ gen_pattern_test scrut p1 ^ ") || (" ^ gen_pattern_test scrut p2 ^ "))" + +and gen_pattern_bindings scrut pat = + (* Emit `const x = scrut.;` declarations for every binder reachable + from this pattern, skipping the wildcard and literal cases. *) + let buf = Buffer.create 64 in + let rec walk path = function + | PatWildcard _ | PatLit _ -> () + | PatVar id -> + Buffer.add_string buf + ("const " ^ mangle id.name ^ " = " ^ path ^ "; ") + | PatTuple pats -> + List.iteri (fun i p -> + walk (path ^ "[" ^ string_of_int i ^ "]") p + ) pats + | PatRecord (fields, _) -> + List.iter (fun (n, sub) -> + let sub_path = path ^ "." ^ mangle n.name in + match sub with + | None -> + Buffer.add_string buf + ("const " ^ mangle n.name ^ " = " ^ sub_path ^ "; ") + | Some sub -> walk sub_path sub + ) fields + | PatCon (_, args) -> + (* Convention: variant payload lives at .value (single-arg, like Some) + or .values[i] (multi-arg). Use .value for arity 1 to match prelude. *) + (match args with + | [] -> () + | [single] -> walk (path ^ ".value") single + | many -> + List.iteri (fun i p -> + walk (path ^ ".values[" ^ string_of_int i ^ "]") p + ) many) + | PatAs (id, sub) -> + Buffer.add_string buf ("const " ^ mangle id.name ^ " = " ^ path ^ "; "); + walk path sub + | PatOr (p, _) -> walk path p + in + walk scrut pat; + Buffer.contents buf + +and gen_block_expr ctx block = + (* JS has no block-as-expression. Emit an IIFE; statements run in order, and + the trailing expression (if any) becomes the return value. *) + let body = Buffer.create 64 in + List.iter (fun s -> + Buffer.add_string body (gen_stmt ctx s); + Buffer.add_string body " " + ) block.blk_stmts; + let result = match block.blk_expr with + | Some e -> "return " ^ gen_expr ctx e ^ ";" + | None -> "return Unit;" + in + "(() => { " ^ Buffer.contents body ^ result ^ " })()" + +and gen_try ctx body catch finally = + let body_str = gen_block_expr ctx body in + let catch_str = match catch with + | None | Some [] -> "catch (__e) { throw __e; }" + | Some (arm :: _) -> + let bind = match arm.ma_pat with + | PatVar id -> "const " ^ mangle id.name ^ " = __e; " + | _ -> "" + in + "catch (__e) { " ^ bind ^ "return " ^ gen_expr ctx arm.ma_body ^ "; }" + in + let finally_str = match finally with + | None -> "" + | Some blk -> " finally { " ^ gen_block_expr ctx blk ^ "; }" + in + "(() => { try { return " ^ body_str ^ "; } " ^ catch_str ^ finally_str ^ " })()" + +and gen_stmt ctx (stmt : stmt) : string = + match stmt with + | StmtLet { sl_pat; sl_value; sl_mut; sl_quantity = _; sl_ty = _ } -> + let pat_str = gen_pattern ctx sl_pat in + let val_str = gen_expr ctx sl_value in + let kw = if sl_mut then "let" else "const" in + kw ^ " " ^ pat_str ^ " = " ^ val_str ^ ";" + | StmtExpr e -> + gen_expr ctx e ^ ";" + | StmtAssign (lhs, op, rhs) -> + let op_str = match op with + | AssignEq -> "=" + | AssignAdd -> "+=" + | AssignSub -> "-=" + | AssignMul -> "*=" + | AssignDiv -> "/=" + in + gen_expr ctx lhs ^ " " ^ op_str ^ " " ^ gen_expr ctx rhs ^ ";" + | StmtWhile (cond, body) -> + "while (" ^ gen_expr ctx cond ^ ") { " ^ + String.concat " " (List.map (gen_stmt ctx) body.blk_stmts) ^ + (match body.blk_expr with + | Some e -> " " ^ gen_expr ctx e ^ ";" + | None -> "") ^ + " }" + | StmtFor (pat, iter, body) -> + "for (const " ^ gen_pattern ctx pat ^ " of " ^ gen_expr ctx iter ^ ") { " ^ + String.concat " " (List.map (gen_stmt ctx) body.blk_stmts) ^ + (match body.blk_expr with + | Some e -> " " ^ gen_expr ctx e ^ ";" + | None -> "") ^ + " }" + +(* ============================================================================ + Top-Level Declaration Code Generation + ============================================================================ *) + +let gen_function ctx (fd : fn_decl) : unit = + let name = mangle fd.fd_name.name in + let param_strs = List.map (fun (p : param) -> mangle p.p_name.name) fd.fd_params in + let header = Printf.sprintf "function %s(%s) {" name (String.concat ", " param_strs) in + emit_line ctx header; + let ctx_body = increase_indent { ctx with in_function = true } in + (match fd.fd_body with + | FnExpr body_expr -> + emit_line ctx_body ("return " ^ gen_expr ctx_body body_expr ^ ";") + | FnBlock block -> + List.iter (fun s -> emit_line ctx_body (gen_stmt ctx_body s)) block.blk_stmts; + (match block.blk_expr with + | Some e -> emit_line ctx_body ("return " ^ gen_expr ctx_body e ^ ";") + | None -> ())); + emit_line (decrease_indent ctx_body) "}"; + emit ctx "\n" + +let gen_type_decl ctx (td : type_decl) : unit = + (* Phase 1: emit constructor factories for enum variants so pattern matches + and `Type::Variant` references both work. Structs and aliases are erased. *) + match td.td_body with + | TyEnum variants -> + List.iter (fun (vd : variant_decl) -> + let name = mangle vd.vd_name.name in + let arity = List.length vd.vd_fields in + if arity = 0 then + emit_line ctx (Printf.sprintf "const %s = { tag: \"%s\" };" name vd.vd_name.name) + else if arity = 1 then + emit_line ctx + (Printf.sprintf "const %s = (value) => ({ tag: \"%s\", value });" + name vd.vd_name.name) + else + let params = List.init arity (fun i -> "v" ^ string_of_int i) in + emit_line ctx + (Printf.sprintf "const %s = (%s) => ({ tag: \"%s\", values: [%s] });" + name (String.concat ", " params) + vd.vd_name.name (String.concat ", " params)) + ) variants; + emit ctx "\n" + | TyStruct _ | TyAlias _ -> + emit_line ctx (Printf.sprintf "// type %s (erased)" td.td_name.name) + +let gen_top_level ctx (top : top_level) : unit = + match top with + | TopFn fd -> gen_function ctx fd + | TopType td -> gen_type_decl ctx td + | TopConst { tc_name; tc_value; _ } -> + emit_line ctx + (Printf.sprintf "const %s = %s;" (mangle tc_name.name) + (gen_expr ctx tc_value)) + | TopEffect _ -> emit_line ctx "// effect declaration (erased)" + | TopTrait _ -> emit_line ctx "// trait declaration (erased)" + | TopImpl _ -> emit_line ctx "// impl block (erased)" + +(* ============================================================================ + Main Code Generation Entry Point + ============================================================================ *) + +let generate (program : program) (symbols : Symbol.t) : string = + let ctx = create_ctx symbols in + emit_line ctx "// Generated by AffineScript compiler"; + emit_line ctx "// SPDX-License-Identifier: PMPL-1.0-or-later"; + emit ctx prelude; + List.iter (gen_top_level ctx) program.prog_decls; + (* If a `main` function exists, invoke it so `node foo.js` actually runs. *) + let has_main = List.exists (function + | TopFn fd -> fd.fd_name.name = "main" + | _ -> false + ) program.prog_decls in + if has_main then emit_line ctx "main();"; + Buffer.contents ctx.output + +let codegen_js (program : program) (symbols : Symbol.t) : (string, string) result = + try Ok (generate program symbols) + with + | Failure msg -> Error ("JS codegen error: " ^ msg) + | e -> Error ("JS codegen error: " ^ Printexc.to_string e) diff --git a/lib/lean_codegen.ml b/lib/lean_codegen.ml new file mode 100644 index 0000000..02b545c --- /dev/null +++ b/lib/lean_codegen.ml @@ -0,0 +1,114 @@ +(* SPDX-License-Identifier: PMPL-1.0-or-later *) +(* SPDX-FileCopyrightText: 2026 Jonathan D.A. Jewell *) + +(** Lean 4 emitter (MVP, dependent-type proof target). + + Lowers Int/Bool/Float functions to Lean 4 [def] declarations. *) + +open Ast + +let lean_reserved = [ + "abbrev"; "axiom"; "by"; "calc"; "class"; "def"; "deriving"; "do"; "elab"; + "else"; "end"; "example"; "fun"; "have"; "if"; "import"; "in"; "inductive"; + "instance"; "let"; "macro"; "match"; "mut"; "namespace"; "open"; "section"; + "set_option"; "show"; "structure"; "syntax"; "then"; "theorem"; "where"; + "with"; "true"; "false"; +] + +let mangle s = if List.mem s lean_reserved then s ^ "_" else s + +let rec lean_type = function + | TyCon id when id.name = "Int" -> "Int" + | TyCon id when id.name = "Bool" -> "Bool" + | TyCon id when id.name = "Float" -> "Float" + | TyCon id when id.name = "String" -> "String" + | TyCon id when id.name = "Unit" -> "Unit" + | TyCon id -> mangle id.name + | TyOwn t | TyRef t | TyMut t -> lean_type t + | _ -> "Int" + +let ret_type = function None -> "Unit" | Some t -> lean_type t + +let rec gen_expr (e : expr) : string = + match e with + | ExprLit lit -> gen_lit lit + | ExprVar id -> mangle id.name + | ExprApp (callee, args) -> + let f = gen_expr callee in + let xs = List.map (fun a -> "(" ^ gen_expr a ^ ")") args in + f ^ " " ^ String.concat " " xs + | ExprBinary (a, op, b) -> + let s = match op with + | OpAdd -> "+" | OpSub -> "-" | OpMul -> "*" | OpDiv -> "/" | OpMod -> "%" + | OpEq -> "==" | OpNe -> "!=" + | OpLt -> "<" | OpLe -> "<=" | OpGt -> ">" | OpGe -> ">=" + | OpAnd -> "&&" | OpOr -> "||" + | OpConcat -> "++" + | _ -> "+" + in + "(" ^ gen_expr a ^ " " ^ s ^ " " ^ gen_expr b ^ ")" + | ExprUnary (OpNeg, x) -> "(- " ^ gen_expr x ^ ")" + | ExprUnary (OpNot, x) -> "(! " ^ gen_expr x ^ ")" + | ExprUnary _ -> "0" + | ExprIf { ei_cond; ei_then; ei_else } -> + let f = match ei_else with Some e -> gen_expr e | None -> "()" in + Printf.sprintf "(if %s then %s else %s)" (gen_expr ei_cond) (gen_expr ei_then) f + | ExprLet { el_pat; el_value; el_body; _ } -> + let var = match el_pat with PatVar id -> mangle id.name | _ -> "_" in + let body = match el_body with Some e -> gen_expr e | None -> "()" in + Printf.sprintf "(let %s := %s; %s)" var (gen_expr el_value) body + | ExprBlock blk -> gen_block blk + | ExprSpan (inner, _) -> gen_expr inner + | ExprReturn (Some e) -> gen_expr e + | _ -> "0" + +and gen_lit = function + | LitInt (n, _) -> string_of_int n + | LitFloat (f, _) -> + let s = string_of_float f in + if String.length s > 0 && s.[String.length s - 1] = '.' then s ^ "0" else s + | LitBool (true, _) -> "true" + | LitBool (false, _) -> "false" + | LitString (s, _) -> "\"" ^ String.escaped s ^ "\"" + | LitChar (c, _) -> "'" ^ Char.escaped c ^ "'" + | LitUnit _ -> "()" + +and gen_block (blk : block) : string = + let rec fold = function + | [] -> (match blk.blk_expr with Some e -> gen_expr e | None -> "()") + | StmtLet { sl_pat = PatVar id; sl_value; _ } :: rest -> + Printf.sprintf "(let %s := %s; %s)" (mangle id.name) (gen_expr sl_value) (fold rest) + | _ :: rest -> fold rest + in + fold blk.blk_stmts + +let gen_function (fd : fn_decl) : string = + let name = mangle fd.fd_name.name in + let params = match fd.fd_params with + | [] -> "" + | _ -> " " ^ String.concat " " (List.map (fun (p : param) -> + Printf.sprintf "(%s : %s)" (mangle p.p_name.name) (lean_type p.p_ty)) + fd.fd_params) + in + let ret = ret_type fd.fd_ret_ty in + let body = match fd.fd_body with + | FnExpr e -> gen_expr e + | FnBlock b -> gen_block b + in + Printf.sprintf "def %s%s : %s :=\n %s\n\n" name params ret body + +let generate (program : program) (_symbols : Symbol.t) : string = + let buf = Buffer.create 1024 in + Buffer.add_string buf "-- Generated by AffineScript compiler (Lean 4)\n"; + Buffer.add_string buf "-- SPDX-License-Identifier: PMPL-1.0-or-later\n\n"; + List.iter (function + | TopFn fd -> Buffer.add_string buf (gen_function fd) + | _ -> () + ) program.prog_decls; + Buffer.contents buf + +let codegen_lean (program : program) (symbols : Symbol.t) : (string, string) result = + try Ok (generate program symbols) + with + | Failure m -> Error ("Lean codegen error: " ^ m) + | e -> Error ("Lean codegen error: " ^ Printexc.to_string e) diff --git a/lib/llvm_codegen.ml b/lib/llvm_codegen.ml new file mode 100644 index 0000000..7b7d780 --- /dev/null +++ b/lib/llvm_codegen.ml @@ -0,0 +1,258 @@ +(* SPDX-License-Identifier: PMPL-1.0-or-later *) +(* SPDX-FileCopyrightText: 2026 Jonathan D.A. Jewell *) + +(** LLVM IR text emitter (MVP). + + Produces LLVM IR in the textual [.ll] form that [llc] reads. We emit + SSA via a fresh-name counter; control flow (if/else) becomes basic + blocks with branches; let-bindings become named SSA values. + + Scope: Int (i64) and Float (double) only. No tuples, records, strings, + or heap allocation. Each function body is a single FnExpr or FnBlock + that returns a scalar. This is enough to take the LLVM toolchain to + x86-64, ARM64, RISC-V, AVR, PTX, AMDGPU via [llc -march=...]. *) + +open Ast + +exception Llvm_unsupported of string +let unsupported m = raise (Llvm_unsupported m) + +(* ============================================================================ + Per-function emit state — SSA / block counter. + ============================================================================ *) + +type fstate = { + mutable next_ssa : int; + mutable next_block : int; + body : Buffer.t; + mutable env : (string * (string * string)) list; + mutable current_blk : string; (* label of the block currently being filled *) +} + +let new_fstate () = { + next_ssa = 0; next_block = 0; body = Buffer.create 256; env = []; + current_blk = "entry"; +} + +let fresh_ssa st = + let n = st.next_ssa in st.next_ssa <- n + 1; + Printf.sprintf "%%v%d" n + +let fresh_block st label = + let n = st.next_block in st.next_block <- n + 1; + Printf.sprintf "%s%d" label n + +let emit_line st s = + Buffer.add_string st.body s; + Buffer.add_char st.body '\n' + +let bind st name ssa ty = + st.env <- (name, (ssa, ty)) :: st.env + +let lookup st name = + try List.assoc name st.env + with Not_found -> unsupported ("unbound: " ^ name) + +let llvm_type = function + | TyCon id when id.name = "Int" -> "i64" + | TyCon id when id.name = "Float" -> "double" + | TyCon id when id.name = "Bool" -> "i1" + | TyCon id when id.name = "Unit" -> "void" + | TyOwn t | TyRef t | TyMut t -> + (* Recurse rather than calling llvm_type recursively here — pattern + matches don't reduce; we need a separate function. *) + (match t with + | TyCon id when id.name = "Int" -> "i64" + | TyCon id when id.name = "Float" -> "double" + | TyCon id when id.name = "Bool" -> "i1" + | _ -> unsupported "complex type not supported in LLVM backend") + | _ -> unsupported "type not supported in LLVM backend" + +let ret_type = function None -> "void" | Some t -> llvm_type t + +(* ============================================================================ + Expression compilation: returns (ssa_name, llvm_type). + ============================================================================ *) + +let rec gen_expr (st : fstate) (e : expr) : string * string = + match e with + | ExprLit (LitInt (n, _)) -> (string_of_int n, "i64") + | ExprLit (LitBool (true, _)) -> ("1", "i1") + | ExprLit (LitBool (false, _))-> ("0", "i1") + | ExprLit (LitFloat (f, _)) -> + let s = string_of_float f in + let s = if String.length s > 0 && s.[String.length s - 1] = '.' then s ^ "0" else s in + (s, "double") + | ExprLit _ -> unsupported "non-numeric literal" + | ExprVar id -> + let (ssa, ty) = lookup st id.name in + (ssa, ty) + | ExprBinary (a, op, b) -> + let (av, ty) = gen_expr st a in + let (bv, _) = gen_expr st b in + let dst = fresh_ssa st in + let opcode = match op, ty with + | OpAdd, "i64" -> "add" | OpAdd, "double" -> "fadd" + | OpSub, "i64" -> "sub" | OpSub, "double" -> "fsub" + | OpMul, "i64" -> "mul" | OpMul, "double" -> "fmul" + | OpDiv, "i64" -> "sdiv" | OpDiv, "double" -> "fdiv" + | OpMod, "i64" -> "srem" | OpMod, "double" -> "frem" + | OpBitAnd, _ -> "and" + | OpBitOr, _ -> "or" + | OpBitXor, _ -> "xor" + | OpShl, _ -> "shl" + | OpShr, _ -> "ashr" + | _ -> unsupported "comparison / logical op needs different lowering" + in + let result_ty = ty in + emit_line st (Printf.sprintf " %s = %s %s %s, %s" dst opcode result_ty av bv); + (dst, result_ty) + | ExprUnary (OpNeg, x) -> + let (xv, ty) = gen_expr st x in + let dst = fresh_ssa st in + (match ty with + | "i64" -> emit_line st (Printf.sprintf " %s = sub i64 0, %s" dst xv) + | "double" -> emit_line st (Printf.sprintf " %s = fneg double %s" dst xv) + | _ -> unsupported "OpNeg on non-numeric"); + (dst, ty) + | ExprIf { ei_cond; ei_then; ei_else } -> + let (cv, _) = gen_expr_bool st ei_cond in + let then_lbl = fresh_block st "then" in + let else_lbl = fresh_block st "else" in + let cont_lbl = fresh_block st "cont" in + emit_line st + (Printf.sprintf " br i1 %s, label %%%s, label %%%s" cv then_lbl else_lbl); + emit_line st (then_lbl ^ ":"); + st.current_blk <- then_lbl; + let (tv, ty) = gen_expr st ei_then in + let then_end = st.current_blk in (* may differ from then_lbl if body emitted blocks *) + emit_line st (Printf.sprintf " br label %%%s" cont_lbl); + emit_line st (else_lbl ^ ":"); + st.current_blk <- else_lbl; + let (ev, _) = match ei_else with + | Some e -> gen_expr st e + | None -> unsupported "if without else has no value in LLVM backend" + in + let else_end = st.current_blk in + emit_line st (Printf.sprintf " br label %%%s" cont_lbl); + emit_line st (cont_lbl ^ ":"); + st.current_blk <- cont_lbl; + let dst = fresh_ssa st in + emit_line st + (Printf.sprintf " %s = phi %s [ %s, %%%s ], [ %s, %%%s ]" + dst ty tv then_end ev else_end); + (dst, ty) + | ExprLet { el_pat = PatVar id; el_value; el_body = Some body; _ } -> + let (v, ty) = gen_expr st el_value in + bind st id.name v ty; + gen_expr st body + | ExprBlock blk -> + List.iter (fun s -> + match s with + | StmtLet { sl_pat = PatVar id; sl_value; _ } -> + let (v, ty) = gen_expr st sl_value in + bind st id.name v ty + | StmtExpr e -> ignore (gen_expr st e) + | _ -> unsupported "stmt form not supported in LLVM block" + ) blk.blk_stmts; + (match blk.blk_expr with + | Some e -> gen_expr st e + | None -> ("0", "i64")) + | ExprApp (callee, args) -> + let name = match callee with + | ExprVar id -> id.name + | _ -> unsupported "indirect call" + in + let arg_pairs = List.map (gen_expr st) args in + let dst = fresh_ssa st in + let call_args = String.concat ", " + (List.map (fun (v, ty) -> Printf.sprintf "%s %s" ty v) arg_pairs) in + let ret_ty = "i64" in (* assumption; refined when we support typed lookups *) + emit_line st + (Printf.sprintf " %s = call %s @%s(%s)" dst ret_ty name call_args); + (dst, ret_ty) + | ExprSpan (inner, _) -> gen_expr st inner + | ExprReturn (Some e) -> + let (v, ty) = gen_expr st e in + emit_line st (Printf.sprintf " ret %s %s" ty v); + (v, ty) + | _ -> unsupported "expression form not supported in LLVM backend MVP" + +and gen_expr_bool (st : fstate) (e : expr) : string * string = + (* Comparison ops produce i1; numeric ops would not. Emit comparisons with + icmp/fcmp; treat boolean literals normally. *) + match e with + | ExprBinary (a, (OpEq|OpNe|OpLt|OpLe|OpGt|OpGe as op), b) -> + let (av, ty) = gen_expr st a in + let (bv, _) = gen_expr st b in + let pred = match op, ty with + | OpEq, "i64" -> "eq" | OpNe, "i64" -> "ne" + | OpLt, "i64" -> "slt" | OpLe, "i64" -> "sle" + | OpGt, "i64" -> "sgt" | OpGe, "i64" -> "sge" + | OpEq, "double" -> "oeq" | OpNe, "double" -> "one" + | OpLt, "double" -> "olt" | OpLe, "double" -> "ole" + | OpGt, "double" -> "ogt" | OpGe, "double" -> "oge" + | _ -> unsupported "comparison on non-numeric" + in + let dst = fresh_ssa st in + let cmp = if ty = "double" then "fcmp" else "icmp" in + emit_line st (Printf.sprintf " %s = %s %s %s %s, %s" dst cmp pred ty av bv); + (dst, "i1") + | ExprLit (LitBool (b, _)) -> + ((if b then "1" else "0"), "i1") + | _ -> + let (v, ty) = gen_expr st e in + if ty = "i1" then (v, "i1") + else + let dst = fresh_ssa st in + emit_line st (Printf.sprintf " %s = icmp ne %s %s, 0" dst ty v); + (dst, "i1") + +(* ============================================================================ + Function emission + ============================================================================ *) + +let gen_function (buf : Buffer.t) (fd : fn_decl) : unit = + let st = new_fstate () in + (* Bind parameters by their declared name. LLVM accepts named params if + we declare them so in the signature; the body then references [%name] + directly. *) + List.iter (fun (p : param) -> + let ty = llvm_type p.p_ty in + bind st p.p_name.name ("%" ^ p.p_name.name) ty + ) fd.fd_params; + emit_line st "entry:"; + let body_expr = match fd.fd_body with + | FnExpr e -> e + | FnBlock b -> ExprBlock b + in + let ret_ty = ret_type fd.fd_ret_ty in + let (rv, _) = gen_expr st body_expr in + if ret_ty = "void" then emit_line st " ret void" + else emit_line st (Printf.sprintf " ret %s %s" ret_ty rv); + let params_str = String.concat ", " + (List.map (fun (p : param) -> + Printf.sprintf "%s %%%s" (llvm_type p.p_ty) p.p_name.name) + fd.fd_params) in + Buffer.add_string buf + (Printf.sprintf "define %s @%s(%s) {\n" ret_ty fd.fd_name.name params_str); + Buffer.add_buffer buf st.body; + Buffer.add_string buf "}\n\n" + +let generate (program : program) (_symbols : Symbol.t) : string = + let buf = Buffer.create 1024 in + Buffer.add_string buf "; Generated by AffineScript compiler\n"; + Buffer.add_string buf "; SPDX-License-Identifier: PMPL-1.0-or-later\n"; + Buffer.add_string buf "target triple = \"x86_64-unknown-linux-gnu\"\n\n"; + List.iter (function + | TopFn fd -> gen_function buf fd + | _ -> () + ) program.prog_decls; + Buffer.contents buf + +let codegen_llvm (program : program) (symbols : Symbol.t) : (string, string) result = + try Ok (generate program symbols) + with + | Llvm_unsupported m -> Error ("LLVM backend: " ^ m) + | Failure m -> Error ("LLVM codegen error: " ^ m) + | e -> Error ("LLVM codegen error: " ^ Printexc.to_string e) diff --git a/lib/lua_codegen.ml b/lib/lua_codegen.ml new file mode 100644 index 0000000..2f16eca --- /dev/null +++ b/lib/lua_codegen.ml @@ -0,0 +1,233 @@ +(* SPDX-License-Identifier: PMPL-1.0-or-later *) +(* SPDX-FileCopyrightText: 2026 Jonathan D.A. Jewell *) + +(** Lua 5.x Emitter (MVP). + + Lowers a subset of AffineScript to Lua source. Numbers are unified as + Lua's native [number] type — Int and Float both map to it. *) + +open Ast + +let lua_reserved = [ + "and"; "break"; "do"; "else"; "elseif"; "end"; "false"; "for"; + "function"; "goto"; "if"; "in"; "local"; "nil"; "not"; "or"; "repeat"; + "return"; "then"; "true"; "until"; "while"; +] + +let mangle s = if List.mem s lua_reserved then s ^ "_" else s + +let prelude = {|-- AffineScript Lua runtime (MVP) +local function println(s) print(tostring(s)) end +local function as_print(s) io.write(tostring(s)) end +local Some = function(v) return { tag = "Some", value = v } end +local None = { tag = "None" } +local Ok = function(v) return { tag = "Ok", value = v } end +local Err = function(e) return { tag = "Err", error = e } end + +|} + +let rec gen_expr (e : expr) : string = + match e with + | ExprLit lit -> gen_lit lit + | ExprVar id -> mangle id.name + | ExprApp (callee, args) -> + gen_expr callee ^ "(" ^ String.concat ", " (List.map gen_expr args) ^ ")" + | ExprBinary (a, op, b) -> + let s = match op with + | OpAdd -> "+" | OpSub -> "-" | OpMul -> "*" | OpDiv -> "/" | OpMod -> "%" + | OpEq -> "==" | OpNe -> "~=" + | OpLt -> "<" | OpLe -> "<=" | OpGt -> ">" | OpGe -> ">=" + | OpAnd -> "and" | OpOr -> "or" + | OpBitAnd -> "&" | OpBitOr -> "|" | OpBitXor -> "~" (* Lua 5.3+ *) + | OpShl -> "<<" | OpShr -> ">>" + | OpConcat -> ".." + in + "(" ^ gen_expr a ^ " " ^ s ^ " " ^ gen_expr b ^ ")" + | ExprUnary (op, x) -> + (match op with + | OpNeg -> "(-" ^ gen_expr x ^ ")" + | OpNot -> "(not " ^ gen_expr x ^ ")" + | OpBitNot -> "(~" ^ gen_expr x ^ ")" + | OpRef -> gen_expr x + | OpDeref -> gen_expr x) + | ExprIf { ei_cond; ei_then; ei_else } -> + (* Lua has no expression-form if; emit a do-block-as-expression via + IIFE. *) + let c = gen_expr ei_cond in + let t = gen_expr ei_then in + let f = match ei_else with Some e -> gen_expr e | None -> "nil" in + Printf.sprintf "((function() if %s then return %s else return %s end end)())" c t f + | ExprLet { el_pat; el_value; el_body; _ } -> + let var = match el_pat with PatVar id -> mangle id.name | _ -> "_" in + let v = gen_expr el_value in + let body = match el_body with Some e -> gen_expr e | None -> "nil" in + Printf.sprintf "((function() local %s = %s; return %s end)())" var v body + | ExprBlock blk -> gen_block_expr blk + | ExprTuple es -> "{" ^ String.concat ", " (List.map gen_expr es) ^ "}" + | ExprTupleIndex (e, n) -> + (* Lua tables are 1-indexed; AS tuple indices are 0-based. *) + Printf.sprintf "%s[%d]" (gen_expr e) (n + 1) + | ExprArray es -> "{" ^ String.concat ", " (List.map gen_expr es) ^ "}" + | ExprIndex (a, i) -> gen_expr a ^ "[" ^ gen_expr i ^ "]" + | ExprRecord { er_fields; _ } -> + let fs = List.map (fun (id, e_opt) -> + let v = match e_opt with Some e -> gen_expr e | None -> mangle id.name in + Printf.sprintf "%s = %s" (mangle id.name) v + ) er_fields in + "{ " ^ String.concat ", " fs ^ " }" + | ExprField (record, field) -> + gen_expr record ^ "." ^ mangle field.name + | ExprVariant (_ty, ctor) -> mangle ctor.name + | ExprMatch { em_scrutinee; em_arms } -> + let scrut_str = gen_expr em_scrutinee in + let rec arms = function + | [] -> "error(\"non-exhaustive match\")" + | arm :: rest -> + let cond = gen_pattern_test "__scrut" arm.ma_pat in + let bindings = gen_pattern_bindings "__scrut" arm.ma_pat in + let body = gen_expr arm.ma_body in + Printf.sprintf "if %s then %s return %s else %s end" + cond bindings body (arms rest) + in + Printf.sprintf "((function(__scrut) %s end)(%s))" (arms em_arms) scrut_str + | ExprSpan (inner, _) -> gen_expr inner + | ExprReturn (Some e) -> gen_expr e + | _ -> "error(\"Lua backend: unsupported expression\")" + +and gen_pattern_test scrut pat = + match pat with + | PatWildcard _ | PatVar _ -> "true" + | PatLit lit -> Printf.sprintf "%s == %s" scrut (gen_lit lit) + | PatCon (id, _) -> Printf.sprintf "%s.tag == %S" scrut id.name + | PatTuple _ -> "true" (* arity match by structure, not tag *) + | PatRecord _ -> "true" + | PatAs (_, p) -> gen_pattern_test scrut p + | PatOr (p, _) -> gen_pattern_test scrut p + +and gen_pattern_bindings scrut pat = + let buf = Buffer.create 64 in + let rec walk path = function + | PatWildcard _ | PatLit _ -> () + | PatVar id -> + Buffer.add_string buf + (Printf.sprintf "local %s = %s; " (mangle id.name) path) + | PatTuple ps -> + List.iteri (fun i p -> walk (Printf.sprintf "%s[%d]" path (i + 1)) p) ps + | PatRecord (fields, _) -> + List.iter (fun (id, sub) -> + let p = path ^ "." ^ mangle id.name in + match sub with + | None -> Buffer.add_string buf (Printf.sprintf "local %s = %s; " (mangle id.name) p) + | Some sub -> walk p sub + ) fields + | PatCon (_, args) -> + (match args with + | [] -> () + | [single] -> walk (path ^ ".value") single + | many -> + List.iteri (fun i p -> + walk (Printf.sprintf "%s.values[%d]" path (i + 1)) p + ) many) + | PatAs (id, sub) -> + Buffer.add_string buf (Printf.sprintf "local %s = %s; " (mangle id.name) path); + walk path sub + | PatOr (p, _) -> walk path p + in + walk scrut pat; + Buffer.contents buf + +and gen_lit = function + | LitInt (n, _) -> string_of_int n + | LitFloat (f, _) -> + let s = string_of_float f in + if String.length s > 0 && s.[String.length s - 1] = '.' then s ^ "0" else s + | LitBool (true, _) -> "true" + | LitBool (false, _) -> "false" + | LitString (s, _) -> "\"" ^ String.escaped s ^ "\"" + | LitChar (c, _) -> "\"" ^ Char.escaped c ^ "\"" + | LitUnit _ -> "nil" + +and gen_block_expr (blk : block) : string = + let stmts = String.concat " " (List.map gen_stmt blk.blk_stmts) in + let tail = match blk.blk_expr with + | Some e -> "return " ^ gen_expr e + | None -> "return nil" + in + Printf.sprintf "((function() %s %s end)())" stmts tail + +and gen_stmt = function + | StmtLet { sl_pat = PatVar id; sl_value; _ } -> + Printf.sprintf "local %s = %s;" (mangle id.name) (gen_expr sl_value) + | StmtLet _ -> "" + | StmtExpr e -> gen_expr e ^ ";" + | StmtAssign (lhs, _, rhs) -> + Printf.sprintf "%s = %s;" (gen_expr lhs) (gen_expr rhs) + | StmtWhile (cond, body) -> + Printf.sprintf "while %s do %s end" + (gen_expr cond) + (String.concat " " (List.map gen_stmt body.blk_stmts)) + | StmtFor _ -> "" + +let gen_function (fd : fn_decl) : string = + let name = mangle fd.fd_name.name in + let params = String.concat ", " + (List.map (fun (p : param) -> mangle p.p_name.name) fd.fd_params) in + let body = match fd.fd_body with + | FnExpr e -> "return " ^ gen_expr e + | FnBlock b -> + let stmts = String.concat " " (List.map gen_stmt b.blk_stmts) in + let tail = match b.blk_expr with + | Some e -> "return " ^ gen_expr e + | None -> "" + in + stmts ^ " " ^ tail + in + Printf.sprintf "function %s(%s)\n %s\nend\n" name params body + +let gen_type_decl (td : type_decl) : string = + match td.td_body with + | TyEnum variants -> + let buf = Buffer.create 64 in + List.iter (fun (vd : variant_decl) -> + let arity = List.length vd.vd_fields in + let name = mangle vd.vd_name.name in + if arity = 0 then + Buffer.add_string buf (Printf.sprintf "%s = { tag = %S }\n" name vd.vd_name.name) + else if arity = 1 then + Buffer.add_string buf + (Printf.sprintf "%s = function(v) return { tag = %S, value = v } end\n" + name vd.vd_name.name) + else + let ps = List.init arity (fun i -> "v" ^ string_of_int i) in + Buffer.add_string buf + (Printf.sprintf "%s = function(%s) return { tag = %S, values = {%s} } end\n" + name (String.concat ", " ps) vd.vd_name.name (String.concat ", " ps)) + ) variants; + Buffer.add_char buf '\n'; + Buffer.contents buf + | _ -> "" (* records erased — Lua tables are duck-typed *) + +let generate (program : program) (_symbols : Symbol.t) : string = + let buf = Buffer.create 1024 in + Buffer.add_string buf "-- Generated by AffineScript compiler\n"; + Buffer.add_string buf "-- SPDX-License-Identifier: PMPL-1.0-or-later\n"; + Buffer.add_string buf prelude; + List.iter (function + | TopType td -> Buffer.add_string buf (gen_type_decl td) + | _ -> () + ) program.prog_decls; + List.iter (function + | TopFn fd -> Buffer.add_string buf (gen_function fd); Buffer.add_char buf '\n' + | _ -> () + ) program.prog_decls; + let has_main = List.exists (function + | TopFn fd -> fd.fd_name.name = "main" + | _ -> false) program.prog_decls in + if has_main then Buffer.add_string buf "print(main())\n"; + Buffer.contents buf + +let codegen_lua (program : program) (symbols : Symbol.t) : (string, string) result = + try Ok (generate program symbols) + with + | Failure msg -> Error ("Lua codegen error: " ^ msg) + | e -> Error ("Lua codegen error: " ^ Printexc.to_string e) diff --git a/lib/metal_codegen.ml b/lib/metal_codegen.ml new file mode 100644 index 0000000..b3c2a52 --- /dev/null +++ b/lib/metal_codegen.ml @@ -0,0 +1,147 @@ +(* SPDX-License-Identifier: PMPL-1.0-or-later *) +(* SPDX-FileCopyrightText: 2026 Jonathan D.A. Jewell *) + +(** Metal Shading Language emitter (MVP). + + Same kernel-sublanguage shape as the WGSL backend; output validates + with Apple's [xcrun metal] toolchain on macOS (not exercised here). *) + +open Ast + +exception Metal_unsupported of string +let unsupported m = raise (Metal_unsupported m) + +let scalar_of_type_name = function + | "Int" -> "int" + | "Float" -> "float" + | "Bool" -> "bool" + | n -> unsupported ("type not allowed in Metal kernel: " ^ n) + +let rec scalar_of (te : type_expr) : string = + match te with + | TyCon id -> scalar_of_type_name id.name + | TyOwn t | TyRef t | TyMut t -> scalar_of t + | _ -> unsupported "complex type not allowed in Metal kernel" + +let array_element (te : type_expr) : string = + let rec strip = function + | TyOwn t | TyRef t | TyMut t -> strip t + | t -> t + in + match strip te with + | TyApp (id, [TyArg inner]) when id.name = "Array" -> scalar_of inner + | _ -> unsupported "expected Array[Int|Float] for kernel buffer" + +let access_qual = function + | Some Mut -> "device" (* read+write storage *) + | _ -> "constant" (* read-only *) + +let rec gen_expr (e : expr) : string = + match e with + | ExprLit lit -> gen_lit lit + | ExprVar id -> id.name + | ExprBinary (a, op, b) -> + let s = match op with + | OpAdd -> "+" | OpSub -> "-" | OpMul -> "*" | OpDiv -> "/" | OpMod -> "%" + | OpEq -> "==" | OpNe -> "!=" + | OpLt -> "<" | OpLe -> "<=" | OpGt -> ">" | OpGe -> ">=" + | OpAnd -> "&&" | OpOr -> "||" + | OpBitAnd -> "&" | OpBitOr -> "|" | OpBitXor -> "^" + | OpShl -> "<<" | OpShr -> ">>" + | OpConcat -> unsupported "concat not supported in Metal" + in + "(" ^ gen_expr a ^ " " ^ s ^ " " ^ gen_expr b ^ ")" + | ExprUnary (OpNeg, x) -> "(-" ^ gen_expr x ^ ")" + | ExprUnary (OpNot, x) -> "(!" ^ gen_expr x ^ ")" + | ExprUnary (OpBitNot, x) -> "(~" ^ gen_expr x ^ ")" + | ExprUnary _ -> unsupported "unary op not supported in Metal kernel" + | ExprIf { ei_cond; ei_then; ei_else } -> + let f = match ei_else with Some e -> gen_expr e | None -> "0" in + Printf.sprintf "(%s ? %s : %s)" (gen_expr ei_cond) (gen_expr ei_then) f + | ExprIndex (a, i) -> Printf.sprintf "%s[%s]" (gen_expr a) (gen_expr i) + | ExprApp (callee, args) -> + let name = match callee with + | ExprVar id -> id.name + | _ -> unsupported "indirect call" + in + let known = ["sin"; "cos"; "tan"; "sqrt"; "exp"; "log"; + "abs"; "floor"; "ceil"; "min"; "max"; "tanh"] in + if not (List.mem name known) then + unsupported ("call to non-builtin in Metal kernel: " ^ name); + Printf.sprintf "metal::%s(%s)" name + (String.concat ", " (List.map gen_expr args)) + | ExprSpan (inner, _) -> gen_expr inner + | _ -> unsupported "expression form not supported in Metal kernel" + +and gen_lit = function + | LitInt (n, _) -> string_of_int n + | LitFloat (f, _) -> + let s = string_of_float f in + let s = if String.length s > 0 && s.[String.length s - 1] = '.' then s ^ "0" else s in + s ^ "f" + | LitBool (true, _) -> "true" + | LitBool (false, _) -> "false" + | _ -> unsupported "literal form not supported" + +let rec gen_stmt (s : stmt) : string = + match s with + | StmtLet { sl_pat = PatVar id; sl_value; sl_ty; _ } -> + let ty = match sl_ty with Some t -> scalar_of t | None -> "int" in + Printf.sprintf "%s %s = %s;" ty id.name (gen_expr sl_value) + | StmtLet _ -> unsupported "destructuring let" + | StmtAssign (lhs, op, rhs) -> + let s = match op with + | AssignEq -> "=" | AssignAdd -> "+=" | AssignSub -> "-=" + | AssignMul -> "*=" | AssignDiv -> "/=" in + Printf.sprintf "%s %s %s;" (gen_expr lhs) s (gen_expr rhs) + | StmtExpr e -> gen_expr e ^ ";" + | StmtWhile (c, b) -> + Printf.sprintf "while (%s) { %s }" (gen_expr c) + (String.concat " " (List.map gen_stmt b.blk_stmts)) + | StmtFor _ -> unsupported "for-in" + +let pick_kernel (program : program) : fn_decl = + let fns = List.filter_map (function TopFn fd -> Some fd | _ -> None) program.prog_decls in + match List.find_opt (fun fd -> fd.fd_name.name = "kernel") fns with + | Some fd -> fd + | None -> match fns with + | fd :: _ -> fd + | [] -> unsupported "no function found" + +let generate (program : program) (_symbols : Symbol.t) : string = + let buf = Buffer.create 1024 in + Buffer.add_string buf "// Generated by AffineScript compiler (Metal MSL)\n"; + Buffer.add_string buf "// SPDX-License-Identifier: PMPL-1.0-or-later\n\n"; + Buffer.add_string buf "#include \nusing namespace metal;\n\n"; + let fd = pick_kernel program in + let idx = match fd.fd_params with first :: _ -> first.p_name.name | _ -> "i" in + let bufs = match fd.fd_params with _ :: rest -> rest | [] -> [] in + let buf_decls = List.mapi (fun i (p : param) -> + Printf.sprintf "%s %s *%s [[buffer(%d)]]" + (access_qual p.p_ownership) (array_element p.p_ty) p.p_name.name i + ) bufs in + let all_params = + buf_decls @ [Printf.sprintf "uint __gid [[thread_position_in_grid]]"] + in + Buffer.add_string buf + (Printf.sprintf "kernel void %s(\n %s\n) {\n" fd.fd_name.name + (String.concat ",\n " all_params)); + Buffer.add_string buf + (Printf.sprintf " int %s = (int)__gid;\n" idx); + (match fd.fd_body with + | FnExpr e -> + Buffer.add_string buf (Printf.sprintf " (void)(%s);\n" (gen_expr e)) + | FnBlock b -> + List.iter (fun s -> Buffer.add_string buf (" " ^ gen_stmt s ^ "\n")) b.blk_stmts; + (match b.blk_expr with + | Some e -> Buffer.add_string buf (Printf.sprintf " (void)(%s);\n" (gen_expr e)) + | None -> ())); + Buffer.add_string buf "}\n"; + Buffer.contents buf + +let codegen_metal (program : program) (symbols : Symbol.t) : (string, string) result = + try Ok (generate program symbols) + with + | Metal_unsupported m -> Error ("Metal backend: " ^ m) + | Failure m -> Error ("Metal codegen error: " ^ m) + | e -> Error ("Metal codegen error: " ^ Printexc.to_string e) diff --git a/lib/mlir_codegen.ml b/lib/mlir_codegen.ml new file mode 100644 index 0000000..6db4ade --- /dev/null +++ b/lib/mlir_codegen.ml @@ -0,0 +1,191 @@ +(* SPDX-License-Identifier: PMPL-1.0-or-later *) +(* SPDX-FileCopyrightText: 2026 Jonathan D.A. Jewell *) + +(** MLIR text emitter (MVP, [func]/[arith] dialects). + + Lowers Int/Float scalar functions to MLIR's [func.func] + + [arith.{addi,subi,muli,addf,subf,mulf,...}]. Output is consumable by + [mlir-opt] and downstream pipelines (StableHLO, IREE, the Linalg + stack). + + Scope: same as LLVM IR — Int (i64) and Float (f64) only, single-fn + bodies with arithmetic + branches. Tensor lowering is Phase 2. *) + +open Ast + +exception Mlir_unsupported of string +let unsupported m = raise (Mlir_unsupported m) + +type fstate = { + body : Buffer.t; + mutable next_ssa : int; + mutable env : (string * (string * string)) list; (* var -> (ssa, ty) *) +} + +let new_fstate () = { body = Buffer.create 256; next_ssa = 0; env = [] } +let fresh st = let n = st.next_ssa in st.next_ssa <- n + 1; Printf.sprintf "%%%d" n +let emit st s = Buffer.add_string st.body s; Buffer.add_char st.body '\n' +let bind st name ssa ty = st.env <- (name, (ssa, ty)) :: st.env +let lookup st name = + try List.assoc name st.env + with Not_found -> unsupported ("unbound: " ^ name) + +let mlir_type = function + | TyCon id when id.name = "Int" -> "i64" + | TyCon id when id.name = "Float" -> "f64" + | TyCon id when id.name = "Bool" -> "i1" + | _ -> unsupported "type not supported in MLIR backend" + +let ret_type = function None -> "()" | Some t -> mlir_type t + +let rec gen_expr (st : fstate) (e : expr) : string * string = + match e with + | ExprLit (LitInt (n, _)) -> + let dst = fresh st in + emit st (Printf.sprintf " %s = arith.constant %d : i64" dst n); + (dst, "i64") + | ExprLit (LitFloat (f, _)) -> + let s = string_of_float f in + let s = if String.length s > 0 && s.[String.length s - 1] = '.' then s ^ "0" else s in + let dst = fresh st in + emit st (Printf.sprintf " %s = arith.constant %s : f64" dst s); + (dst, "f64") + | ExprLit (LitBool (b, _)) -> + let dst = fresh st in + emit st (Printf.sprintf " %s = arith.constant %d : i1" dst (if b then 1 else 0)); + (dst, "i1") + | ExprLit _ -> unsupported "non-numeric literal" + | ExprVar id -> lookup st id.name + | ExprBinary (a, op, b) -> + let (av, ty) = gen_expr st a in + let (bv, _) = gen_expr st b in + let opname = match op, ty with + | OpAdd, "i64" -> "arith.addi" | OpAdd, "f64" -> "arith.addf" + | OpSub, "i64" -> "arith.subi" | OpSub, "f64" -> "arith.subf" + | OpMul, "i64" -> "arith.muli" | OpMul, "f64" -> "arith.mulf" + | OpDiv, "i64" -> "arith.divsi"| OpDiv, "f64" -> "arith.divf" + | OpMod, "i64" -> "arith.remsi"| OpMod, "f64" -> "arith.remf" + | OpBitAnd, _ -> "arith.andi" + | OpBitOr, _ -> "arith.ori" + | OpBitXor, _ -> "arith.xori" + | OpShl, _ -> "arith.shli" + | OpShr, _ -> "arith.shrsi" + | _ -> unsupported "binop / comparison needs different lowering" + in + let dst = fresh st in + emit st (Printf.sprintf " %s = %s %s, %s : %s" dst opname av bv ty); + (dst, ty) + | ExprUnary (OpNeg, x) -> + let (xv, ty) = gen_expr st x in + let dst = fresh st in + (match ty with + | "i64" -> + let zero = fresh st in + emit st (Printf.sprintf " %s = arith.constant 0 : i64" zero); + emit st (Printf.sprintf " %s = arith.subi %s, %s : i64" dst zero xv) + | "f64" -> + emit st (Printf.sprintf " %s = arith.negf %s : f64" dst xv) + | _ -> unsupported "negate on non-numeric"); + (dst, ty) + | ExprIf { ei_cond; ei_then; ei_else } -> + let (cv, _) = gen_cmp st ei_cond in + let (tv, ty) = gen_expr st ei_then in + let (ev, _) = match ei_else with + | Some e -> gen_expr st e + | None -> unsupported "if without else has no value" + in + let dst = fresh st in + emit st (Printf.sprintf " %s = arith.select %s, %s, %s : %s" dst cv tv ev ty); + (dst, ty) + | ExprBlock blk -> + List.iter (fun s -> + match s with + | StmtLet { sl_pat = PatVar id; sl_value; _ } -> + let (v, ty) = gen_expr st sl_value in bind st id.name v ty + | StmtExpr e -> ignore (gen_expr st e) + | _ -> unsupported "stmt form not supported in MLIR block" + ) blk.blk_stmts; + (match blk.blk_expr with + | Some e -> gen_expr st e + | None -> unsupported "block must end with an expression") + | ExprLet { el_pat = PatVar id; el_value; el_body = Some body; _ } -> + let (v, ty) = gen_expr st el_value in + bind st id.name v ty; + gen_expr st body + | ExprApp (callee, args) -> + let name = match callee with + | ExprVar id -> id.name + | _ -> unsupported "indirect call" + in + let arg_pairs = List.map (gen_expr st) args in + let arg_str = String.concat ", " (List.map fst arg_pairs) in + let arg_tys = String.concat ", " (List.map snd arg_pairs) in + let dst = fresh st in + let ret_ty = "i64" in (* MVP: conservative *) + emit st (Printf.sprintf " %s = func.call @%s(%s) : (%s) -> %s" + dst name arg_str arg_tys ret_ty); + (dst, ret_ty) + | ExprSpan (inner, _) -> gen_expr st inner + | _ -> unsupported "expression form not supported in MLIR backend" + +and gen_cmp (st : fstate) (e : expr) : string * string = + match e with + | ExprBinary (a, (OpEq|OpNe|OpLt|OpLe|OpGt|OpGe as op), b) -> + let (av, ty) = gen_expr st a in + let (bv, _) = gen_expr st b in + let pred, opname = match op, ty with + | OpEq, "i64" -> "eq", "arith.cmpi" | OpNe, "i64" -> "ne", "arith.cmpi" + | OpLt, "i64" -> "slt", "arith.cmpi" | OpLe, "i64" -> "sle", "arith.cmpi" + | OpGt, "i64" -> "sgt", "arith.cmpi" | OpGe, "i64" -> "sge", "arith.cmpi" + | OpEq, "f64" -> "oeq", "arith.cmpf" | OpNe, "f64" -> "one", "arith.cmpf" + | OpLt, "f64" -> "olt", "arith.cmpf" | OpLe, "f64" -> "ole", "arith.cmpf" + | OpGt, "f64" -> "ogt", "arith.cmpf" | OpGe, "f64" -> "oge", "arith.cmpf" + | _ -> unsupported "comparison on non-numeric" + in + let dst = fresh st in + emit st (Printf.sprintf " %s = %s %s, %s, %s : %s" dst opname pred av bv ty); + (dst, "i1") + | _ -> gen_expr st e + +let gen_function (buf : Buffer.t) (fd : fn_decl) : unit = + let st = new_fstate () in + List.iter (fun (p : param) -> + let ty = mlir_type p.p_ty in + bind st p.p_name.name ("%" ^ p.p_name.name) ty; + st.next_ssa <- max st.next_ssa 0 + ) fd.fd_params; + let body_expr = match fd.fd_body with + | FnExpr e -> e + | FnBlock b -> ExprBlock b + in + let ret_ty = ret_type fd.fd_ret_ty in + let (rv, _) = gen_expr st body_expr in + if ret_ty = "()" then emit st " func.return" + else emit st (Printf.sprintf " func.return %s : %s" rv ret_ty); + let params_str = String.concat ", " + (List.map (fun (p : param) -> + Printf.sprintf "%%%s: %s" p.p_name.name (mlir_type p.p_ty)) + fd.fd_params) in + Buffer.add_string buf + (Printf.sprintf "func.func @%s(%s) -> %s {\n" fd.fd_name.name params_str ret_ty); + Buffer.add_buffer buf st.body; + Buffer.add_string buf "}\n\n" + +let generate (program : program) (_symbols : Symbol.t) : string = + let buf = Buffer.create 1024 in + Buffer.add_string buf "// Generated by AffineScript compiler (MLIR func/arith)\n"; + Buffer.add_string buf "// SPDX-License-Identifier: PMPL-1.0-or-later\n"; + Buffer.add_string buf "module {\n"; + List.iter (function + | TopFn fd -> gen_function buf fd + | _ -> () + ) program.prog_decls; + Buffer.add_string buf "}\n"; + Buffer.contents buf + +let codegen_mlir (program : program) (symbols : Symbol.t) : (string, string) result = + try Ok (generate program symbols) + with + | Mlir_unsupported m -> Error ("MLIR backend: " ^ m) + | Failure m -> Error ("MLIR codegen error: " ^ m) + | e -> Error ("MLIR codegen error: " ^ Printexc.to_string e) diff --git a/lib/nickel_codegen.ml b/lib/nickel_codegen.ml new file mode 100644 index 0000000..06735ec --- /dev/null +++ b/lib/nickel_codegen.ml @@ -0,0 +1,125 @@ +(* SPDX-License-Identifier: PMPL-1.0-or-later *) +(* SPDX-FileCopyrightText: 2026 Jonathan D.A. Jewell *) + +(** Nickel emitter (config-language target). + + Lowers a subset of AffineScript to Nickel, which is a typed + configuration language with first-class records and merging. + + Nickel is *config*, not a general-purpose target — control flow is + expressions only, no loops, no mutation. This emitter handles the + natural overlap: pure functions over Int/Float/Bool/String/Record. *) + +open Ast + +let nickel_reserved = [ + "if"; "then"; "else"; "let"; "in"; "fun"; "match"; "true"; "false"; "null"; + "import"; "merge"; "default"; "force"; "doc"; "Number"; "String"; "Bool"; + "Array"; "Dyn"; "switch"; +] + +let mangle s = if List.mem s nickel_reserved then s ^ "_" else s + +let rec gen_expr (e : expr) : string = + match e with + | ExprLit lit -> gen_lit lit + | ExprVar id -> mangle id.name + | ExprApp (callee, args) -> + List.fold_left (fun acc a -> "(" ^ acc ^ " " ^ gen_expr a ^ ")") + (gen_expr callee) args + | ExprBinary (a, op, b) -> + let s = match op with + | OpAdd -> "+" | OpSub -> "-" | OpMul -> "*" | OpDiv -> "/" | OpMod -> "%" + | OpEq -> "==" | OpNe -> "!=" + | OpLt -> "<" | OpLe -> "<=" | OpGt -> ">" | OpGe -> ">=" + | OpAnd -> "&&" | OpOr -> "||" + | OpConcat -> "++" + | OpBitAnd | OpBitOr | OpBitXor | OpShl | OpShr -> + failwith "Nickel backend: bitwise ops not supported" + in + "(" ^ gen_expr a ^ " " ^ s ^ " " ^ gen_expr b ^ ")" + | ExprUnary (OpNeg, x) -> "(-" ^ gen_expr x ^ ")" + | ExprUnary (OpNot, x) -> "(!" ^ gen_expr x ^ ")" + | ExprUnary _ -> failwith "Nickel backend: unsupported unary op" + | ExprIf { ei_cond; ei_then; ei_else } -> + let c = gen_expr ei_cond in + let t = gen_expr ei_then in + let f = match ei_else with Some e -> gen_expr e | None -> "null" in + Printf.sprintf "(if %s then %s else %s)" c t f + | ExprLet { el_pat; el_value; el_body; _ } -> + let var = match el_pat with PatVar id -> mangle id.name | _ -> "_" in + let v = gen_expr el_value in + let body = match el_body with Some e -> gen_expr e | None -> "null" in + Printf.sprintf "(let %s = %s in %s)" var v body + | ExprBlock blk -> gen_block blk + | ExprRecord { er_fields; _ } -> + let f = List.map (fun (id, e_opt) -> + let v = match e_opt with Some e -> gen_expr e | None -> mangle id.name in + Printf.sprintf "%s = %s" (mangle id.name) v + ) er_fields in + "{ " ^ String.concat ", " f ^ " }" + | ExprField (r, f) -> gen_expr r ^ "." ^ mangle f.name + | ExprSpan (inner, _) -> gen_expr inner + | ExprReturn (Some e) -> gen_expr e + | _ -> "(error \"Nickel backend: unsupported expression\")" + +and gen_lit = function + | LitInt (n, _) -> string_of_int n + | LitFloat (f, _) -> + let s = string_of_float f in + if String.length s > 0 && s.[String.length s - 1] = '.' then s ^ "0" else s + | LitBool (true, _) -> "true" + | LitBool (false, _) -> "false" + | LitString (s, _) -> "\"" ^ String.escaped s ^ "\"" + | LitChar (c, _) -> "\"" ^ Char.escaped c ^ "\"" + | LitUnit _ -> "null" + +and gen_block (blk : block) : string = + (* Nickel has no statements — only let-chains. Fold StmtLet into nested + let..in expressions; everything else is unsupported. *) + let rec fold = function + | [] -> + (match blk.blk_expr with + | Some e -> gen_expr e + | None -> "null") + | StmtLet { sl_pat = PatVar id; sl_value; _ } :: rest -> + Printf.sprintf "(let %s = %s in %s)" (mangle id.name) + (gen_expr sl_value) (fold rest) + | _ :: _ -> "(error \"Nickel: only let-bindings allowed in a block\")" + in + fold blk.blk_stmts + +let gen_function (fd : fn_decl) : string = + let name = mangle fd.fd_name.name in + let body = match fd.fd_body with + | FnExpr e -> gen_expr e + | FnBlock b -> gen_block b + in + match fd.fd_params with + | [] -> Printf.sprintf " %s = %s,\n" name body + | _ -> + let params = String.concat " => " + (List.map (fun (p : param) -> "fun " ^ mangle p.p_name.name) fd.fd_params) + ^ " => " in + Printf.sprintf " %s = %s%s,\n" name params body + +let generate (program : program) (_symbols : Symbol.t) : string = + let buf = Buffer.create 1024 in + Buffer.add_string buf "# Generated by AffineScript compiler\n"; + Buffer.add_string buf "# SPDX-License-Identifier: PMPL-1.0-or-later\n"; + Buffer.add_string buf "{\n"; + List.iter (function + | TopFn fd -> Buffer.add_string buf (gen_function fd) + | TopConst { tc_name; tc_value; _ } -> + Buffer.add_string buf + (Printf.sprintf " %s = %s,\n" (mangle tc_name.name) (gen_expr tc_value)) + | _ -> () + ) program.prog_decls; + Buffer.add_string buf "}\n"; + Buffer.contents buf + +let codegen_nickel (program : program) (symbols : Symbol.t) : (string, string) result = + try Ok (generate program symbols) + with + | Failure msg -> Error ("Nickel codegen error: " ^ msg) + | e -> Error ("Nickel codegen error: " ^ Printexc.to_string e) diff --git a/lib/ocaml_codegen.ml b/lib/ocaml_codegen.ml new file mode 100644 index 0000000..8a645a7 --- /dev/null +++ b/lib/ocaml_codegen.ml @@ -0,0 +1,250 @@ +(* SPDX-License-Identifier: PMPL-1.0-or-later *) +(* SPDX-FileCopyrightText: 2026 Jonathan D.A. Jewell *) + +(** OCaml Self-Target Emitter (MVP). + + Lowers a subset of AffineScript to compilable OCaml source. Useful as + a self-validation target — we *are* OCaml, so this is effectively a + pretty-printer with renaming. *) + +open Ast + +let ocaml_reserved = [ + "and"; "as"; "assert"; "begin"; "class"; "constraint"; "do"; "done"; + "downto"; "else"; "end"; "exception"; "external"; "false"; "for"; "fun"; + "function"; "functor"; "if"; "in"; "include"; "inherit"; "initializer"; + "lazy"; "let"; "match"; "method"; "module"; "mutable"; "new"; "nonrec"; + "object"; "of"; "open"; "or"; "private"; "rec"; "sig"; "struct"; "then"; + "to"; "true"; "try"; "type"; "val"; "virtual"; "when"; "while"; "with"; +] + +let mangle s = + if List.mem s ocaml_reserved then s ^ "_" else s + +(* OCaml type names must start lowercase; AS uses TitleCase by convention. + Lowercase the first letter and dodge reserved words by prefixing. *) +let mangle_ty s = + if String.length s = 0 then s + else + let lowered = String.uncapitalize_ascii s in + if List.mem lowered ocaml_reserved then "_" ^ lowered else lowered + +let rec ml_type = function + | TyCon id when id.name = "Int" -> "int" + | TyCon id when id.name = "Float" -> "float" + | TyCon id when id.name = "Bool" -> "bool" + | TyCon id when id.name = "String" -> "string" + | TyCon id when id.name = "Unit" -> "unit" + | TyCon id -> mangle_ty id.name + | TyTuple [] -> "unit" + | TyTuple ts -> "(" ^ String.concat " * " (List.map ml_type ts) ^ ")" + | TyOwn t | TyRef t | TyMut t -> ml_type t + | _ -> "_" + +let ret_type = function + | None -> "unit" + | Some t -> ml_type t + +let rec gen_expr (e : expr) : string = + match e with + | ExprLit lit -> gen_lit lit + | ExprVar id -> mangle id.name + | ExprApp (callee, args) -> + let f = gen_expr callee in + let xs = List.map (fun a -> "(" ^ gen_expr a ^ ")") args in + f ^ " " ^ String.concat " " xs + | ExprBinary (a, op, b) -> + let opstr = match op with + | OpAdd -> "+" | OpSub -> "-" | OpMul -> "*" | OpDiv -> "/" + | OpMod -> "mod" | OpEq -> "=" | OpNe -> "<>" + | OpLt -> "<" | OpLe -> "<=" | OpGt -> ">" | OpGe -> ">=" + | OpAnd -> "&&" | OpOr -> "||" + | OpBitAnd -> "land" | OpBitOr -> "lor" | OpBitXor -> "lxor" + | OpShl -> "lsl" | OpShr -> "lsr" + | OpConcat -> "^" + in + "(" ^ gen_expr a ^ " " ^ opstr ^ " " ^ gen_expr b ^ ")" + | ExprUnary (op, x) -> + (match op with + | OpNeg -> "(- " ^ gen_expr x ^ ")" + | OpNot -> "(not " ^ gen_expr x ^ ")" + | OpBitNot -> "(lnot " ^ gen_expr x ^ ")" + | OpRef -> "(ref " ^ gen_expr x ^ ")" + | OpDeref -> "(!" ^ gen_expr x ^ ")") + | ExprIf { ei_cond; ei_then; ei_else } -> + let c = gen_expr ei_cond in + let t = gen_expr ei_then in + let f = match ei_else with Some e -> gen_expr e | None -> "()" in + Printf.sprintf "(if %s then %s else %s)" c t f + | ExprLet { el_pat; el_value; el_body; _ } -> + let var = match el_pat with PatVar id -> mangle id.name | _ -> "_" in + let v = gen_expr el_value in + let body = match el_body with Some e -> gen_expr e | None -> "()" in + Printf.sprintf "(let %s = %s in %s)" var v body + | ExprBlock blk -> gen_block blk + | ExprTuple es -> "(" ^ String.concat ", " (List.map gen_expr es) ^ ")" + | ExprTupleIndex (e, n) -> + (* OCaml has no .N tuple projection. Use fst/snd for the binary case + (the common one); for higher arities, the codegen would need the + tuple's static type to know how many wildcards to emit. We don't + track it through the AST yet, so error loudly on n > 1. *) + (match n with + | 0 -> Printf.sprintf "(fst %s)" (gen_expr e) + | 1 -> Printf.sprintf "(snd %s)" (gen_expr e) + | _ -> + Printf.sprintf "(failwith \"OCaml backend: tuple index %d (need static arity)\")" n) + | ExprArray es -> "[|" ^ String.concat "; " (List.map gen_expr es) ^ "|]" + | ExprIndex (a, i) -> + Printf.sprintf "%s.(%s)" (gen_expr a) (gen_expr i) + | ExprRecord { er_fields; _ } -> + let fs = List.map (fun (id, e_opt) -> + let v = match e_opt with Some e -> gen_expr e | None -> mangle id.name in + Printf.sprintf "%s = %s" (mangle id.name) v + ) er_fields in + "{ " ^ String.concat "; " fs ^ " }" + | ExprField (record, field) -> + gen_expr record ^ "." ^ mangle field.name + | ExprVariant (_ty, ctor) -> mangle ctor.name + | ExprMatch { em_scrutinee; em_arms } -> + let arms = List.map (fun arm -> + Printf.sprintf "| %s -> %s" (gen_pattern arm.ma_pat) (gen_expr arm.ma_body) + ) em_arms in + Printf.sprintf "(match %s with %s)" (gen_expr em_scrutinee) + (String.concat " " arms) + | ExprSpan (inner, _) -> gen_expr inner + | ExprReturn (Some e) -> gen_expr e + | _ -> "(failwith \"OCaml backend: unsupported expression\")" + +and gen_pattern (p : pattern) : string = + match p with + | PatWildcard _ -> "_" + | PatVar id -> mangle id.name + | PatLit lit -> gen_literal lit + | PatCon (id, args) -> + if args = [] then mangle id.name + else + let aps = List.map gen_pattern args in + if List.length aps = 1 then mangle id.name ^ " " ^ List.hd aps + else mangle id.name ^ " (" ^ String.concat ", " aps ^ ")" + | PatTuple ps -> "(" ^ String.concat ", " (List.map gen_pattern ps) ^ ")" + | PatRecord (fields, _) -> + let fs = List.map (fun (id, sub) -> + match sub with + | None -> mangle id.name + | Some sub -> Printf.sprintf "%s = %s" (mangle id.name) (gen_pattern sub) + ) fields in + "{ " ^ String.concat "; " fs ^ " }" + | PatAs (id, _) -> mangle id.name + | PatOr (p, _) -> gen_pattern p + +and gen_literal lit = + match lit with + | LitInt (n, _) -> string_of_int n + | LitFloat (f, _) -> + let s = string_of_float f in + if String.length s > 0 && s.[String.length s - 1] = '.' then s ^ "0" else s + | LitBool (true, _) -> "true" + | LitBool (false, _) -> "false" + | LitString (s, _) -> "\"" ^ String.escaped s ^ "\"" + | LitChar (c, _) -> "'" ^ Char.escaped c ^ "'" + | LitUnit _ -> "()" + +and gen_lit = function + | LitInt (n, _) -> string_of_int n + | LitFloat (f, _) -> + let s = string_of_float f in + if String.length s > 0 && s.[String.length s - 1] = '.' then s ^ "0" else s + | LitBool (true, _) -> "true" + | LitBool (false, _) -> "false" + | LitString (s, _) -> "\"" ^ String.escaped s ^ "\"" + | LitChar (c, _) -> "'" ^ Char.escaped c ^ "'" + | LitUnit _ -> "()" + +and gen_block (blk : block) : string = + let stmts = List.map gen_stmt blk.blk_stmts in + let tail = match blk.blk_expr with Some e -> gen_expr e | None -> "()" in + match stmts with + | [] -> tail + | _ -> "(" ^ String.concat "; " stmts ^ "; " ^ tail ^ ")" + +and gen_stmt = function + | StmtLet { sl_pat = PatVar id; sl_value; _ } -> + Printf.sprintf "let %s = %s in ()" (mangle id.name) (gen_expr sl_value) + (* Note: this fragment is then sequenced by gen_block. For trivial MVP + we approximate with `let _ = ... in ()` style sequencing, which + compiles even though it's not idiomatic. *) + | StmtLet _ -> "()" + | StmtExpr e -> gen_expr e + | StmtAssign (lhs, _, rhs) -> + Printf.sprintf "%s := %s" (gen_expr lhs) (gen_expr rhs) + | StmtWhile (cond, body) -> + Printf.sprintf "while %s do %s done" (gen_expr cond) (gen_block body) + | StmtFor _ -> "()" + +let gen_function (fd : fn_decl) : string = + let name = mangle fd.fd_name.name in + let params = match fd.fd_params with + | [] -> "()" + | _ -> String.concat " " (List.map + (fun (p : param) -> "(" ^ mangle p.p_name.name ^ " : " ^ ml_type p.p_ty ^ ")") + fd.fd_params) + in + let ret = ret_type fd.fd_ret_ty in + let body = match fd.fd_body with + | FnExpr e -> gen_expr e + | FnBlock b -> gen_block b + in + Printf.sprintf "let %s %s : %s =\n %s\n" name params ret body + +let gen_type_decl (td : type_decl) : string = + let name = mangle_ty td.td_name.name in + match td.td_body with + | TyAlias (TyRecord (fields, _)) -> + let fs = List.map (fun (rf : row_field) -> + Printf.sprintf " %s : %s" (mangle rf.rf_name.name) (ml_type rf.rf_ty) + ) fields in + Printf.sprintf "type %s = {\n%s\n}\n\n" name (String.concat ";\n" fs) + | TyAlias t -> Printf.sprintf "type %s = %s\n\n" name (ml_type t) + | TyStruct fields -> + let fs = List.map (fun (sf : struct_field) -> + Printf.sprintf " %s : %s" (mangle sf.sf_name.name) (ml_type sf.sf_ty) + ) fields in + Printf.sprintf "type %s = {\n%s\n}\n\n" name (String.concat ";\n" fs) + | TyEnum variants -> + let vs = List.map (fun (vd : variant_decl) -> + let tys = List.map ml_type vd.vd_fields in + let body = if tys = [] then "" else " of " ^ String.concat " * " tys in + Printf.sprintf " | %s%s" (mangle vd.vd_name.name) body + ) variants in + Printf.sprintf "type %s =\n%s\n\n" name (String.concat "\n" vs) + +let generate (program : program) (_symbols : Symbol.t) : string = + let buf = Buffer.create 1024 in + Buffer.add_string buf "(* Generated by AffineScript compiler *)\n"; + Buffer.add_string buf "(* SPDX-License-Identifier: PMPL-1.0-or-later *)\n\n"; + (* Type decls precede functions so the typechecker sees the schema. *) + List.iter (function + | TopType td -> Buffer.add_string buf (gen_type_decl td) + | _ -> () + ) program.prog_decls; + List.iter (fun decl -> + match decl with + | TopFn fd -> Buffer.add_string buf (gen_function fd); Buffer.add_char buf '\n' + | TopConst { tc_name; tc_ty; tc_value; _ } -> + Buffer.add_string buf + (Printf.sprintf "let %s : %s = %s\n\n" + (mangle tc_name.name) (ml_type tc_ty) (gen_expr tc_value)) + | _ -> () + ) program.prog_decls; + let has_main = List.exists (function + | TopFn fd -> fd.fd_name.name = "main" + | _ -> false) program.prog_decls in + if has_main then + Buffer.add_string buf "let () = ignore (main ())\n"; + Buffer.contents buf + +let codegen_ocaml (program : program) (symbols : Symbol.t) : (string, string) result = + try Ok (generate program symbols) + with + | Failure msg -> Error ("OCaml codegen error: " ^ msg) + | e -> Error ("OCaml codegen error: " ^ Printexc.to_string e) diff --git a/lib/onnx_codegen.ml b/lib/onnx_codegen.ml new file mode 100644 index 0000000..12f1d1b --- /dev/null +++ b/lib/onnx_codegen.ml @@ -0,0 +1,257 @@ +(* SPDX-License-Identifier: PMPL-1.0-or-later *) +(* SPDX-FileCopyrightText: 2026 Jonathan D.A. Jewell *) + +(** ONNX Backend (MVP — proof of wire format). + + Emits a binary [.onnx] [ModelProto] from a strict subset of AffineScript: + + - one entry function (named [graph], [main], or the first [fn]) + - parameters typed [Array[Float]] are graph inputs (1-D, dimension param) + - return type [Array[Float]] is the graph output + - body is a chain of [let] bindings whose RHS is one of: + * a variable reference (passes the SSA value forward) + * a call to a recognised ONNX op by name + and a final expression naming the output + + The recognised op set is small: arithmetic ([Add], [Sub], [Mul], [Div]), + common activations ([Relu], [Sigmoid], [Tanh]), and identity-shaped ops + ([Identity]). Each is matched by AffineScript function name (case- + insensitive lowercased — [add] maps to [Add], etc.). Unrecognised calls + produce a hard error so the regression is loud. + + What this does NOT do (intentional, MVP scope): + - tensor shape inference + - broadcasting rules + - attribute encoding (axes, alpha, etc.) + - initializer tensors / weights + - any verification beyond "the bytes parse as ONNX" + + Validation strategy: round-trip the output through any ONNX reader + (oxionnx-proto, onnxruntime, tract). All three accept the same wire + format; if the bytes decode they're ONNX-conformant. +*) + +open Ast + +exception Onnx_unsupported of string +let unsupported msg = raise (Onnx_unsupported msg) + +(* ============================================================================ + Op recognition + + Map an AffineScript function name to (ONNX op type, expected arity). + ============================================================================ *) + +let recognise_op (name : string) : (string * int) option = + match String.lowercase_ascii name with + | "add" -> Some ("Add", 2) + | "sub" -> Some ("Sub", 2) + | "mul" -> Some ("Mul", 2) + | "div" -> Some ("Div", 2) + | "relu" -> Some ("Relu", 1) + | "sigmoid" -> Some ("Sigmoid", 1) + | "tanh" -> Some ("Tanh", 1) + | "neg" -> Some ("Neg", 1) + | "abs" -> Some ("Abs", 1) + | "exp" -> Some ("Exp", 1) + | "log" -> Some ("Log", 1) + | "sqrt" -> Some ("Sqrt", 1) + | "identity" -> Some ("Identity", 1) + | _ -> None + +(* ============================================================================ + Type validation + + Graph inputs and outputs must be [Array[Float]]. We accept the AffineScript + surface forms [Array[Float]], [ref Array[Float]], [mut Array[Float]]. + ============================================================================ *) + +let rec strip_ownership = function + | TyOwn t | TyRef t | TyMut t -> strip_ownership t + | t -> t + +let is_array_float (te : type_expr) : bool = + match strip_ownership te with + | TyApp (id, [TyArg (TyCon e)]) when id.name = "Array" && e.name = "Float" -> true + | _ -> false + +(* ============================================================================ + Build the ONNX graph from an AffineScript function body + + We lower a chain of [let v = call(...)] bindings into a list of ONNX + nodes. The graph inputs are the function parameters, the output is the + final expression (which must be a variable reference). + ============================================================================ *) + +type build_state = { + mutable nodes : Onnx_proto.node list; + mutable next_id : int; +} + +let fresh_node_name (st : build_state) (op : string) : string = + let id = st.next_id in + st.next_id <- id + 1; + Printf.sprintf "%s_%d" op id + +(** Lower a value-producing expression into either a variable name (already + in scope) or a new node whose output we name and return. *) +let rec lower_expr (st : build_state) (e : expr) : string = + match e with + | ExprVar id -> id.name + | ExprSpan (inner, _) -> lower_expr st inner + | ExprApp (callee, args) -> + let fn_name = match callee with + | ExprVar id -> id.name + | _ -> unsupported "indirect calls not supported in ONNX backend" + in + let (op_type, expected_arity) = match recognise_op fn_name with + | Some pair -> pair + | None -> unsupported ("unknown ONNX op (no name match): " ^ fn_name) + in + let actual = List.length args in + if actual <> expected_arity then + unsupported + (Printf.sprintf "%s expects %d args, got %d" op_type expected_arity actual); + let arg_names = List.map (lower_expr st) args in + let out_name = fresh_node_name st op_type in + let node = { + Onnx_proto.n_input = arg_names; + n_output = [out_name]; + n_name = out_name; + n_op_type = op_type; + n_domain = ""; + } in + st.nodes <- node :: st.nodes; + out_name + | ExprLet { el_pat; el_value; el_body; el_mut = _; el_quantity = _; el_ty = _ } -> + (* Ignore the binder for the SSA-style lowering: each call already + produces a uniquely-named output. We propagate the *output of the + RHS* to the body, threading variables through scope by their + user-visible name. *) + let rhs_name = lower_expr st el_value in + let var = match el_pat with + | PatVar id -> id.name + | _ -> unsupported "destructuring let not supported in ONNX backend" + in + (* Rename the latest node's output to the bound name so the user's + identifier survives into the graph. Only safe when the RHS produced + a fresh node (not a passthrough variable). *) + (match st.nodes with + | latest :: rest when latest.n_output = [rhs_name] -> + let renamed = { latest with + Onnx_proto.n_output = [var]; + n_name = var; + } in + st.nodes <- renamed :: rest + | _ -> + (* RHS was a variable — emit an Identity node so the alias is + real in the graph. *) + let node = { + Onnx_proto.n_input = [rhs_name]; + n_output = [var]; + n_name = var; + n_op_type = "Identity"; + n_domain = ""; + } in + st.nodes <- node :: st.nodes); + (match el_body with + | Some body -> lower_expr st body + | None -> unsupported "let without body cannot produce graph output") + | ExprBlock blk -> + List.fold_left (fun _last s -> + match s with + | StmtLet { sl_pat = PatVar id; sl_value; _ } -> + let rhs_name = lower_expr st sl_value in + (match st.nodes with + | latest :: rest when latest.n_output = [rhs_name] -> + st.nodes <- { latest with + Onnx_proto.n_output = [id.name]; n_name = id.name } :: rest + | _ -> + st.nodes <- { + Onnx_proto.n_input = [rhs_name]; n_output = [id.name]; + n_name = id.name; n_op_type = "Identity"; n_domain = ""; + } :: st.nodes); + id.name + | StmtExpr e -> lower_expr st e + | _ -> unsupported "only let-bindings and trailing expressions allowed in ONNX block" + ) "" blk.blk_stmts + |> ignore; + (match blk.blk_expr with + | Some e -> lower_expr st e + | None -> unsupported "block must end with an expression naming the output") + | ExprLit _ -> + unsupported "literal not supported as ONNX value (need Constant op + initializer)" + | _ -> + unsupported "expression form not supported in ONNX kernel" + +(* ============================================================================ + Driver + ============================================================================ *) + +let pick_entry (program : program) : fn_decl = + let fns = List.filter_map (function TopFn fd -> Some fd | _ -> None) + program.prog_decls in + let by_name n = List.find_opt (fun fd -> fd.fd_name.name = n) fns in + match by_name "graph" with + | Some fd -> fd + | None -> + match by_name "main" with + | Some fd -> fd + | None -> + match fns with + | fd :: _ -> fd + | [] -> unsupported "no function found to lower as ONNX graph" + +let validate_entry (fd : fn_decl) : unit = + List.iter (fun (p : param) -> + if not (is_array_float p.p_ty) then + unsupported + (Printf.sprintf "parameter %s must be Array[Float]" p.p_name.name) + ) fd.fd_params; + match fd.fd_ret_ty with + | Some t when is_array_float t -> () + | None -> unsupported "graph function must declare a return type" + | _ -> unsupported "graph function must return Array[Float]" + +(** Produce a ValueInfoProto for a parameter or result name. We use a single + dynamic dimension named [N] so consumers can pass any-length tensors. *) +let value_info_for (name : string) : Onnx_proto.value_info = { + vi_name = name; + vi_type = Onnx_proto.TensorType { + elem_type = 1; (* FLOAT *) + shape = [Onnx_proto.DimParam "N"]; + }; +} + +let generate (program : program) (_symbols : Symbol.t) : string = + let entry = pick_entry program in + validate_entry entry; + let st = { nodes = []; next_id = 0 } in + let output_name = match entry.fd_body with + | FnExpr e -> lower_expr st e + | FnBlock b -> lower_expr st (ExprBlock b) + in + let inputs = List.map (fun (p : param) -> value_info_for p.p_name.name) + entry.fd_params in + let outputs = [value_info_for output_name] in + let graph = { + Onnx_proto.g_node = List.rev st.nodes; + g_name = entry.fd_name.name; + g_input = inputs; + g_output = outputs; + } in + let model = { + Onnx_proto.m_ir_version = 7; (* ONNX 1.10+ *) + m_producer_name = "affinescript"; + m_producer_version = "0.1.0"; + m_opset_import = [{ op_domain = ""; op_version = 13 }]; + m_graph = graph; + } in + Onnx_proto.serialize_model model + +let codegen_onnx (program : program) (symbols : Symbol.t) : (string, string) result = + try Ok (generate program symbols) + with + | Onnx_unsupported msg -> Error ("ONNX backend: " ^ msg) + | Failure msg -> Error ("ONNX codegen error: " ^ msg) + | e -> Error ("ONNX codegen error: " ^ Printexc.to_string e) diff --git a/lib/onnx_proto.ml b/lib/onnx_proto.ml new file mode 100644 index 0000000..a653dec --- /dev/null +++ b/lib/onnx_proto.ml @@ -0,0 +1,165 @@ +(* SPDX-License-Identifier: PMPL-1.0-or-later *) +(* SPDX-FileCopyrightText: 2026 Jonathan D.A. Jewell *) + +(** ONNX schema subset, encoded directly to protobuf wire bytes. + + Implements the messages we need to construct a valid [ModelProto] from + scratch, in the order they appear in the upstream [onnx.proto]: + + - [TensorShapeProto.Dimension] (`dim_value` or `dim_param`) + - [TensorShapeProto] + - [TypeProto.Tensor] + - [TypeProto] + - [ValueInfoProto] + - [NodeProto] + - [GraphProto] + - [OperatorSetIdProto] + - [ModelProto] + + Field numbers and wire types are taken from the canonical + [https://github.com/onnx/onnx/blob/main/onnx/onnx.proto] schema. The + ONNX runtime, ORT, tract, and oxionnx-proto all read this exact wire + format; round-tripping our output through any of them validates we got + the field numbers right. + + Tensor element-type codes match [TensorProto.DataType]: + 1 = FLOAT, 7 = INT64, 11 = DOUBLE, 9 = BOOL — see ONNX's enum table. +*) + +open Protobuf + +(* ============================================================================ + TensorShapeProto.Dimension + ============================================================================ *) + +type dimension = + | DimValue of int + | DimParam of string + +let encode_dimension (buf : Buffer.t) (d : dimension) : unit = + match d with + | DimValue n -> encode_int64_field buf 1 n + | DimParam s -> encode_string_field buf 2 s + +(* ============================================================================ + TensorShapeProto + ============================================================================ *) + +type shape = dimension list + +let encode_shape (buf : Buffer.t) (dims : shape) : unit = + encode_repeated_message_field buf 1 encode_dimension dims + +(* ============================================================================ + TypeProto.Tensor + ============================================================================ *) + +type tensor_type = { + elem_type : int; (* TensorProto.DataType *) + shape : shape; +} + +let encode_tensor_type (buf : Buffer.t) (t : tensor_type) : unit = + encode_int32_field buf 1 t.elem_type; + encode_message_field buf 2 (fun b -> encode_shape b t.shape) + +(* ============================================================================ + TypeProto + + We only emit the [tensor_type] variant — sequence/map/optional types + exist in newer ONNX but aren't relevant for the kernel sublanguage. + ============================================================================ *) + +type type_proto = + | TensorType of tensor_type + +let encode_type_proto (buf : Buffer.t) (t : type_proto) : unit = + match t with + | TensorType tt -> encode_message_field buf 1 (fun b -> encode_tensor_type b tt) + +(* ============================================================================ + ValueInfoProto + ============================================================================ *) + +type value_info = { + vi_name : string; + vi_type : type_proto; +} + +let encode_value_info (buf : Buffer.t) (v : value_info) : unit = + encode_string_field buf 1 v.vi_name; + encode_message_field buf 2 (fun b -> encode_type_proto b v.vi_type) + +(* ============================================================================ + NodeProto + ============================================================================ *) + +type node = { + n_input : string list; + n_output : string list; + n_name : string; + n_op_type : string; + n_domain : string; (* "" for default opset *) +} + +let encode_node (buf : Buffer.t) (n : node) : unit = + encode_repeated_string_field buf 1 n.n_input; + encode_repeated_string_field buf 2 n.n_output; + encode_string_field buf 3 n.n_name; + encode_string_field buf 4 n.n_op_type; + if n.n_domain <> "" then + encode_string_field buf 7 n.n_domain + +(* ============================================================================ + GraphProto + ============================================================================ *) + +type graph = { + g_node : node list; + g_name : string; + g_input : value_info list; + g_output : value_info list; +} + +let encode_graph (buf : Buffer.t) (g : graph) : unit = + encode_repeated_message_field buf 1 encode_node g.g_node; + encode_string_field buf 2 g.g_name; + encode_repeated_message_field buf 11 encode_value_info g.g_input; + encode_repeated_message_field buf 12 encode_value_info g.g_output + +(* ============================================================================ + OperatorSetIdProto + ============================================================================ *) + +type opset = { + op_domain : string; (* "" for the default ONNX opset *) + op_version : int; +} + +let encode_opset (buf : Buffer.t) (o : opset) : unit = + encode_string_field buf 1 o.op_domain; + encode_int64_field buf 2 o.op_version + +(* ============================================================================ + ModelProto + ============================================================================ *) + +type model = { + m_ir_version : int; + m_producer_name : string; + m_producer_version : string; + m_opset_import : opset list; + m_graph : graph; +} + +let encode_model (buf : Buffer.t) (m : model) : unit = + encode_int64_field buf 1 m.m_ir_version; + encode_repeated_message_field buf 8 encode_opset m.m_opset_import; + encode_message_field buf 7 (fun b -> encode_graph b m.m_graph); + encode_string_field buf 2 m.m_producer_name; + encode_string_field buf 3 m.m_producer_version + +let serialize_model (m : model) : string = + let buf = Buffer.create 1024 in + encode_model buf m; + Buffer.contents buf diff --git a/lib/opencl_codegen.ml b/lib/opencl_codegen.ml new file mode 100644 index 0000000..8a89905 --- /dev/null +++ b/lib/opencl_codegen.ml @@ -0,0 +1,133 @@ +(* SPDX-License-Identifier: PMPL-1.0-or-later *) +(* SPDX-FileCopyrightText: 2026 Jonathan D.A. Jewell *) + +(** OpenCL C kernel emitter (MVP). Output validates with [clang -cc1 + -triple spir64 -xcl] or any conformant OpenCL implementation. *) + +open Ast + +exception Cl_unsupported of string +let unsupported m = raise (Cl_unsupported m) + +let scalar_of_type_name = function + | "Int" -> "int" | "Float" -> "float" | "Bool" -> "bool" + | n -> unsupported ("type not allowed in OpenCL kernel: " ^ n) + +let rec scalar_of (te : type_expr) : string = + match te with + | TyCon id -> scalar_of_type_name id.name + | TyOwn t | TyRef t | TyMut t -> scalar_of t + | _ -> unsupported "complex type not allowed" + +let array_element (te : type_expr) : string = + let rec strip = function + | TyOwn t | TyRef t | TyMut t -> strip t + | t -> t + in + match strip te with + | TyApp (id, [TyArg inner]) when id.name = "Array" -> scalar_of inner + | _ -> unsupported "expected Array[Int|Float] for kernel buffer" + +let access_qual = function + | Some Mut -> "" (* read+write *) + | _ -> "const " (* read-only *) + +let rec gen_expr (e : expr) : string = + match e with + | ExprLit lit -> gen_lit lit + | ExprVar id -> id.name + | ExprBinary (a, op, b) -> + let s = match op with + | OpAdd -> "+" | OpSub -> "-" | OpMul -> "*" | OpDiv -> "/" | OpMod -> "%" + | OpEq -> "==" | OpNe -> "!=" + | OpLt -> "<" | OpLe -> "<=" | OpGt -> ">" | OpGe -> ">=" + | OpAnd -> "&&" | OpOr -> "||" + | OpBitAnd -> "&" | OpBitOr -> "|" | OpBitXor -> "^" + | OpShl -> "<<" | OpShr -> ">>" + | OpConcat -> unsupported "concat not supported" + in + "(" ^ gen_expr a ^ " " ^ s ^ " " ^ gen_expr b ^ ")" + | ExprUnary (OpNeg, x) -> "(-" ^ gen_expr x ^ ")" + | ExprUnary (OpNot, x) -> "(!" ^ gen_expr x ^ ")" + | ExprUnary (OpBitNot, x) -> "(~" ^ gen_expr x ^ ")" + | ExprUnary _ -> unsupported "unary op not supported" + | ExprIf { ei_cond; ei_then; ei_else } -> + let f = match ei_else with Some e -> gen_expr e | None -> "0" in + Printf.sprintf "(%s ? %s : %s)" (gen_expr ei_cond) (gen_expr ei_then) f + | ExprIndex (a, i) -> Printf.sprintf "%s[%s]" (gen_expr a) (gen_expr i) + | ExprApp (callee, args) -> + let name = match callee with ExprVar id -> id.name | _ -> unsupported "indirect call" in + let known = ["sin"; "cos"; "tan"; "sqrt"; "exp"; "log"; "pow"; + "fabs"; "floor"; "ceil"; "min"; "max"; "tanh"] in + if not (List.mem name known) then + unsupported ("call to non-builtin in OpenCL kernel: " ^ name); + Printf.sprintf "%s(%s)" name (String.concat ", " (List.map gen_expr args)) + | ExprSpan (inner, _) -> gen_expr inner + | _ -> unsupported "expression form not supported" + +and gen_lit = function + | LitInt (n, _) -> string_of_int n + | LitFloat (f, _) -> + let s = string_of_float f in + let s = if String.length s > 0 && s.[String.length s - 1] = '.' then s ^ "0" else s in + s ^ "f" + | LitBool (true, _) -> "true" + | LitBool (false, _) -> "false" + | _ -> unsupported "literal form not supported" + +let rec gen_stmt (s : stmt) : string = + match s with + | StmtLet { sl_pat = PatVar id; sl_value; sl_ty; _ } -> + let ty = match sl_ty with Some t -> scalar_of t | None -> "int" in + Printf.sprintf "%s %s = %s;" ty id.name (gen_expr sl_value) + | StmtLet _ -> unsupported "destructuring let" + | StmtAssign (lhs, op, rhs) -> + let s = match op with + | AssignEq -> "=" | AssignAdd -> "+=" | AssignSub -> "-=" + | AssignMul -> "*=" | AssignDiv -> "/=" in + Printf.sprintf "%s %s %s;" (gen_expr lhs) s (gen_expr rhs) + | StmtExpr e -> gen_expr e ^ ";" + | StmtWhile (c, b) -> + Printf.sprintf "while (%s) { %s }" (gen_expr c) + (String.concat " " (List.map gen_stmt b.blk_stmts)) + | StmtFor _ -> unsupported "for-in" + +let pick_kernel (program : program) : fn_decl = + let fns = List.filter_map (function TopFn fd -> Some fd | _ -> None) program.prog_decls in + match List.find_opt (fun fd -> fd.fd_name.name = "kernel") fns with + | Some fd -> fd + | None -> match fns with fd :: _ -> fd | [] -> unsupported "no function found" + +let generate (program : program) (_symbols : Symbol.t) : string = + let buf = Buffer.create 1024 in + Buffer.add_string buf "// Generated by AffineScript compiler (OpenCL C)\n"; + Buffer.add_string buf "// SPDX-License-Identifier: PMPL-1.0-or-later\n\n"; + let fd = pick_kernel program in + let idx = match fd.fd_params with first :: _ -> first.p_name.name | _ -> "i" in + let bufs = match fd.fd_params with _ :: rest -> rest | [] -> [] in + let buf_decls = List.map (fun (p : param) -> + Printf.sprintf "__global %s%s *%s" + (access_qual p.p_ownership) (array_element p.p_ty) p.p_name.name + ) bufs in + Buffer.add_string buf + (Printf.sprintf "__kernel void %s(%s) {\n" fd.fd_name.name + (String.concat ", " buf_decls)); + Buffer.add_string buf + (Printf.sprintf " int %s = (int)get_global_id(0);\n" idx); + (match fd.fd_body with + | FnExpr e -> + Buffer.add_string buf (Printf.sprintf " (void)(%s);\n" (gen_expr e)) + | FnBlock b -> + List.iter (fun s -> Buffer.add_string buf (" " ^ gen_stmt s ^ "\n")) b.blk_stmts; + (match b.blk_expr with + | Some e -> Buffer.add_string buf (Printf.sprintf " (void)(%s);\n" (gen_expr e)) + | None -> ())); + Buffer.add_string buf "}\n"; + Buffer.contents buf + +let codegen_opencl (program : program) (symbols : Symbol.t) : (string, string) result = + try Ok (generate program symbols) + with + | Cl_unsupported m -> Error ("OpenCL backend: " ^ m) + | Failure m -> Error ("OpenCL codegen error: " ^ m) + | e -> Error ("OpenCL codegen error: " ^ Printexc.to_string e) diff --git a/lib/protobuf.ml b/lib/protobuf.ml new file mode 100644 index 0000000..690cef3 --- /dev/null +++ b/lib/protobuf.ml @@ -0,0 +1,142 @@ +(* SPDX-License-Identifier: PMPL-1.0-or-later *) +(* SPDX-FileCopyrightText: 2026 Jonathan D.A. Jewell *) + +(** Minimal Protocol Buffers wire-format encoder. + + Implements the subset needed to write ONNX [.onnx] files (and any other + proto2/proto3 message that uses the same wire types): + + - varint (wire type 0) — for ints, enums, and lengths + - 64-bit fixed (wire type 1) — for [double], [fixed64], [sfixed64] + - length-delimited (wire type 2) — for strings, bytes, embedded messages, + and packed-repeated scalars + - 32-bit fixed (wire type 5) — for [float], [fixed32], [sfixed32] + + Wire tag layout (per Google's spec): [(field_number << 3) | wire_type], + encoded as a varint. The decoder reads the tag, splits it back into + field number and wire type, and uses the wire type to know how many + bytes follow. + + This module is intentionally schema-agnostic: it emits the bytes for a + single field at a time. Higher layers (e.g. {!Onnx_proto}) compose + these calls to build typed messages. *) + +(* ============================================================================ + Wire types + ============================================================================ *) + +type wire_type = + | Varint (* 0 *) + | Fixed64 (* 1 *) + | LengthDelimited (* 2 *) + | Fixed32 (* 5 *) + +let wire_type_int = function + | Varint -> 0 + | Fixed64 -> 1 + | LengthDelimited -> 2 + | Fixed32 -> 5 + +(* ============================================================================ + Varint encoding + + 7 bits of payload per byte; MSB set on every byte except the last. Negative + ints (two's complement) require 10 bytes; unsigned ints are encoded as the + smallest number of bytes that fits. + ============================================================================ *) + +let encode_varint (buf : Buffer.t) (n : int) : unit = + let rec loop n = + if n < 0x80 then Buffer.add_char buf (Char.chr n) + else begin + Buffer.add_char buf (Char.chr ((n land 0x7F) lor 0x80)); + loop (n lsr 7) + end + in + if n < 0 then + (* Two's complement extension: emit the low 64 bits worth of bytes, + which is at most 10 7-bit groups. *) + let rec loop10 n i = + if i = 9 then Buffer.add_char buf (Char.chr (n land 0x7F)) + else begin + Buffer.add_char buf (Char.chr ((n land 0x7F) lor 0x80)); + loop10 (n lsr 7) (i + 1) + end + in + loop10 n 0 + else + loop n + +(* ============================================================================ + Field tags + + Combine field number and wire type into a single varint. ONNX field + numbers stay below 32, so the tag fits in one byte. + ============================================================================ *) + +let encode_tag (buf : Buffer.t) (field_number : int) (wt : wire_type) : unit = + let tag = (field_number lsl 3) lor (wire_type_int wt) in + encode_varint buf tag + +(* ============================================================================ + Field encoders (one per scalar type) + + Each writes the tag followed by the value in the appropriate wire format. + Empty/zero values are still emitted: the protobuf default-elision rule + applies only to *messages with proto3 default-zero suppression*; for our + purposes (constructing ONNX models from scratch) we always emit fields we + set explicitly, which is what every ONNX writer in the wild does. + ============================================================================ *) + +let encode_int32_field (buf : Buffer.t) (field : int) (n : int) : unit = + encode_tag buf field Varint; + encode_varint buf n + +let encode_int64_field = encode_int32_field +let encode_uint32_field = encode_int32_field + +let encode_float_field (buf : Buffer.t) (field : int) (f : float) : unit = + encode_tag buf field Fixed32; + let bits = Int32.bits_of_float f in + for i = 0 to 3 do + let b = Int32.to_int (Int32.logand (Int32.shift_right_logical bits (i * 8)) 0xFFl) in + Buffer.add_char buf (Char.chr b) + done + +let encode_double_field (buf : Buffer.t) (field : int) (f : float) : unit = + encode_tag buf field Fixed64; + let bits = Int64.bits_of_float f in + for i = 0 to 7 do + let b = Int64.to_int (Int64.logand (Int64.shift_right_logical bits (i * 8)) 0xFFL) in + Buffer.add_char buf (Char.chr b) + done + +let encode_string_field (buf : Buffer.t) (field : int) (s : string) : unit = + encode_tag buf field LengthDelimited; + encode_varint buf (String.length s); + Buffer.add_string buf s + +let encode_bytes_field = encode_string_field + +(** Embed a sub-message: encode its bytes once into a temporary buffer to + learn the length, then emit length-delimited. *) +let encode_message_field (buf : Buffer.t) (field : int) + (encode : Buffer.t -> unit) : unit = + let inner = Buffer.create 64 in + encode inner; + let payload = Buffer.contents inner in + encode_tag buf field LengthDelimited; + encode_varint buf (String.length payload); + Buffer.add_string buf payload + +(** Repeated message field: just emit the same field number multiple times. *) +let encode_repeated_message_field (buf : Buffer.t) (field : int) + (encode_one : Buffer.t -> 'a -> unit) (items : 'a list) : unit = + List.iter (fun item -> + encode_message_field buf field (fun b -> encode_one b item) + ) items + +(** Repeated string field: same convention. *) +let encode_repeated_string_field (buf : Buffer.t) (field : int) + (items : string list) : unit = + List.iter (encode_string_field buf field) items diff --git a/lib/rescript_codegen.ml b/lib/rescript_codegen.ml new file mode 100644 index 0000000..768c3a0 --- /dev/null +++ b/lib/rescript_codegen.ml @@ -0,0 +1,217 @@ +(* SPDX-License-Identifier: PMPL-1.0-or-later *) +(* SPDX-FileCopyrightText: 2026 Jonathan D.A. Jewell *) + +(** ReScript emitter (typed-JS path). + + Lowers a subset of AffineScript to ReScript source. ReScript is + syntactically OCaml-flavoured but compiles to JavaScript via the + [rescript] compiler. Semantic mapping is close to OCaml; we emit + ReScript surface syntax (curly-brace functions, [let] expressions, + typed records). *) + +open Ast + +let res_reserved = [ + "and"; "as"; "assert"; "constraint"; "else"; "exception"; "external"; + "false"; "for"; "if"; "in"; "include"; "let"; "list"; "match"; "module"; + "mutable"; "of"; "open"; "private"; "rec"; "switch"; "then"; "to"; "true"; + "try"; "type"; "when"; "while"; "with"; +] + +let mangle s = if List.mem s res_reserved then s ^ "_" else s + +(* Forward decl: ReScript type names (and the mangle_ty rule) live below + gen_function; we mirror OCaml here. *) +let res_mangle_ty s = + if String.length s = 0 then s + else String.uncapitalize_ascii s + +let rec res_type = function + | TyCon id when id.name = "Int" -> "int" + | TyCon id when id.name = "Float" -> "float" + | TyCon id when id.name = "Bool" -> "bool" + | TyCon id when id.name = "String" -> "string" + | TyCon id when id.name = "Unit" -> "unit" + | TyCon id -> res_mangle_ty id.name + | TyTuple [] -> "unit" + | TyTuple ts -> "(" ^ String.concat ", " (List.map res_type ts) ^ ")" + | TyOwn t | TyRef t | TyMut t -> res_type t + | _ -> "_" + +let ret_type = function None -> "unit" | Some t -> res_type t + +let rec gen_expr (e : expr) : string = + match e with + | ExprLit lit -> gen_lit lit + | ExprVar id -> mangle id.name + | ExprApp (callee, args) -> + gen_expr callee ^ "(" ^ String.concat ", " (List.map gen_expr args) ^ ")" + | ExprBinary (a, op, b) -> + (* ReScript distinguishes int and float arithmetic operators. We emit + the int form by default; programs whose lhs is Float should use + dedicated builtins, which in turn map to [+.], [-.] etc. For MVP + this matches AS's monomorphic int/float dispatch — the typechecker + already pinned the operator type by this point. *) + let s = match op with + | OpAdd -> "+" | OpSub -> "-" | OpMul -> "*" | OpDiv -> "/" | OpMod -> "mod" + | OpEq -> "==" | OpNe -> "!=" + | OpLt -> "<" | OpLe -> "<=" | OpGt -> ">" | OpGe -> ">=" + | OpAnd -> "&&" | OpOr -> "||" + | OpConcat -> "++" + | OpBitAnd -> "land" | OpBitOr -> "lor" | OpBitXor -> "lxor" + | OpShl -> "lsl" | OpShr -> "lsr" + in + "(" ^ gen_expr a ^ " " ^ s ^ " " ^ gen_expr b ^ ")" + | ExprUnary (op, x) -> + (match op with + | OpNeg -> "(- " ^ gen_expr x ^ ")" + | OpNot -> "(!" ^ gen_expr x ^ ")" + | OpBitNot -> "(lnot(" ^ gen_expr x ^ "))" + | OpRef -> "ref(" ^ gen_expr x ^ ")" + | OpDeref -> gen_expr x ^ ".contents") + | ExprIf { ei_cond; ei_then; ei_else } -> + let c = gen_expr ei_cond in + let t = gen_expr ei_then in + let f = match ei_else with Some e -> gen_expr e | None -> "()" in + Printf.sprintf "(if %s { %s } else { %s })" c t f + | ExprLet { el_pat; el_value; el_body; _ } -> + let var = match el_pat with PatVar id -> mangle id.name | _ -> "_" in + let v = gen_expr el_value in + let body = match el_body with Some e -> gen_expr e | None -> "()" in + Printf.sprintf "{ let %s = %s; %s }" var v body + | ExprBlock blk -> gen_block blk + | ExprTuple es -> "(" ^ String.concat ", " (List.map gen_expr es) ^ ")" + | ExprTupleIndex (e, n) -> + (* ReScript projections — match what's in scope. For arity-2 use fst/snd. *) + (match n with + | 0 -> Printf.sprintf "fst(%s)" (gen_expr e) + | 1 -> Printf.sprintf "snd(%s)" (gen_expr e) + | _ -> Printf.sprintf "failwith(\"ReScript: tuple index %d unsupported\")" n) + | ExprArray es -> "[" ^ String.concat ", " (List.map gen_expr es) ^ "]" + | ExprIndex (a, i) -> gen_expr a ^ "[" ^ gen_expr i ^ "]" + | ExprRecord { er_fields; _ } -> + let fs = List.map (fun (id, e_opt) -> + let v = match e_opt with Some e -> gen_expr e | None -> mangle id.name in + Printf.sprintf "%s: %s" (mangle id.name) v + ) er_fields in + "{ " ^ String.concat ", " fs ^ " }" + | ExprField (record, field) -> gen_expr record ^ "." ^ mangle field.name + | ExprVariant (_ty, ctor) -> mangle ctor.name + | ExprMatch { em_scrutinee; em_arms } -> + let arms = List.map (fun arm -> + Printf.sprintf "| %s => %s" (gen_pattern arm.ma_pat) (gen_expr arm.ma_body) + ) em_arms in + Printf.sprintf "switch %s { %s }" (gen_expr em_scrutinee) + (String.concat " " arms) + | ExprSpan (inner, _) -> gen_expr inner + | ExprReturn (Some e) -> gen_expr e + | _ -> "(failwith(\"ReScript backend: unsupported expression\"))" + +and gen_pattern (p : pattern) : string = + match p with + | PatWildcard _ -> "_" + | PatVar id -> mangle id.name + | PatLit lit -> gen_lit lit + | PatCon (id, args) -> + if args = [] then mangle id.name + else mangle id.name ^ "(" ^ + String.concat ", " (List.map gen_pattern args) ^ ")" + | PatTuple ps -> "(" ^ String.concat ", " (List.map gen_pattern ps) ^ ")" + | PatRecord (fields, _) -> + let fs = List.map (fun (id, sub) -> + match sub with + | None -> mangle id.name + | Some sub -> Printf.sprintf "%s: %s" (mangle id.name) (gen_pattern sub) + ) fields in + "{ " ^ String.concat ", " fs ^ " }" + | PatAs (id, _) -> mangle id.name + | PatOr (p, _) -> gen_pattern p + +and gen_lit = function + | LitInt (n, _) -> string_of_int n + | LitFloat (f, _) -> + let s = string_of_float f in + if String.length s > 0 && s.[String.length s - 1] = '.' then s ^ "0" else s + | LitBool (true, _) -> "true" + | LitBool (false, _) -> "false" + | LitString (s, _) -> "\"" ^ String.escaped s ^ "\"" + | LitChar (c, _) -> "'" ^ Char.escaped c ^ "'" + | LitUnit _ -> "()" + +and gen_block (blk : block) : string = + let stmts = List.map gen_stmt blk.blk_stmts in + let tail = match blk.blk_expr with Some e -> gen_expr e | None -> "()" in + "{ " ^ String.concat " " stmts ^ " " ^ tail ^ " }" + +and gen_stmt = function + | StmtLet { sl_pat = PatVar id; sl_value; _ } -> + Printf.sprintf "let %s = %s;" (mangle id.name) (gen_expr sl_value) + | StmtLet _ -> "" + | StmtExpr e -> gen_expr e ^ ";" + | StmtAssign (lhs, _, rhs) -> + Printf.sprintf "%s.contents = %s;" (gen_expr lhs) (gen_expr rhs) + | StmtWhile (cond, body) -> + Printf.sprintf "while %s { %s }" (gen_expr cond) (gen_block body) + | StmtFor _ -> "" + +let gen_function (fd : fn_decl) : string = + let name = mangle fd.fd_name.name in + let params = String.concat ", " + (List.map (fun (p : param) -> + Printf.sprintf "%s: %s" (mangle p.p_name.name) (res_type p.p_ty)) + fd.fd_params) in + let ret = ret_type fd.fd_ret_ty in + let body = match fd.fd_body with + | FnExpr e -> gen_expr e + | FnBlock b -> gen_block b + in + Printf.sprintf "let %s = (%s): %s => %s\n\n" name params ret body + +(* ReScript type names, like OCaml's, must be lowercase. *) +let mangle_ty s = + if String.length s = 0 then s + else + let lowered = String.uncapitalize_ascii s in + if List.mem lowered res_reserved then "_" ^ lowered else lowered + +let gen_type_decl (td : type_decl) : string = + let name = mangle_ty td.td_name.name in + match td.td_body with + | TyAlias (TyRecord (fields, _)) -> + let fs = List.map (fun (rf : row_field) -> + Printf.sprintf " %s: %s" (mangle rf.rf_name.name) (res_type rf.rf_ty) + ) fields in + Printf.sprintf "type %s = {\n%s,\n}\n\n" name (String.concat ",\n" fs) + | TyAlias t -> Printf.sprintf "type %s = %s\n\n" name (res_type t) + | TyStruct fields -> + let fs = List.map (fun (sf : struct_field) -> + Printf.sprintf " %s: %s" (mangle sf.sf_name.name) (res_type sf.sf_ty) + ) fields in + Printf.sprintf "type %s = {\n%s,\n}\n\n" name (String.concat ",\n" fs) + | TyEnum variants -> + let vs = List.map (fun (vd : variant_decl) -> + let tys = List.map res_type vd.vd_fields in + let body = if tys = [] then "" else "(" ^ String.concat ", " tys ^ ")" in + Printf.sprintf " | %s%s" (mangle vd.vd_name.name) body + ) variants in + Printf.sprintf "type %s =\n%s\n\n" name (String.concat "\n" vs) + +let generate (program : program) (_symbols : Symbol.t) : string = + let buf = Buffer.create 1024 in + Buffer.add_string buf "// Generated by AffineScript compiler\n"; + Buffer.add_string buf "// SPDX-License-Identifier: PMPL-1.0-or-later\n\n"; + List.iter (function + | TopType td -> Buffer.add_string buf (gen_type_decl td) + | _ -> () + ) program.prog_decls; + List.iter (function + | TopFn fd -> Buffer.add_string buf (gen_function fd) + | _ -> () + ) program.prog_decls; + Buffer.contents buf + +let codegen_rescript (program : program) (symbols : Symbol.t) : (string, string) result = + try Ok (generate program symbols) + with + | Failure msg -> Error ("ReScript codegen error: " ^ msg) + | e -> Error ("ReScript codegen error: " ^ Printexc.to_string e) diff --git a/lib/rust_codegen.ml b/lib/rust_codegen.ml new file mode 100644 index 0000000..3db6e78 --- /dev/null +++ b/lib/rust_codegen.ml @@ -0,0 +1,273 @@ +(* SPDX-License-Identifier: PMPL-1.0-or-later *) +(* SPDX-FileCopyrightText: 2026 Jonathan D.A. Jewell *) + +(** Rust emitter (MVP). + + Lowers a subset of AffineScript to safe Rust. AS's affine semantics map + naturally to Rust's ownership system, but in MVP we don't lean on that + yet — every Float/Int/Bool is by-value (Copy), strings are [&'static str]. + + This is enough to run [add(40, 2) = 42]-style programs and validate via + [rustc --edition=2021]. *) + +open Ast + +let rust_reserved = [ + "as"; "async"; "await"; "box"; "break"; "const"; "continue"; "crate"; + "do"; "dyn"; "else"; "enum"; "extern"; "false"; "final"; "fn"; "for"; + "if"; "impl"; "in"; "let"; "loop"; "macro"; "match"; "mod"; "move"; + "mut"; "override"; "priv"; "pub"; "ref"; "return"; "self"; "Self"; + "static"; "struct"; "super"; "trait"; "true"; "try"; "type"; "typeof"; + "union"; "unsafe"; "unsized"; "use"; "virtual"; "where"; "while"; "yield"; +] + +let mangle s = if List.mem s rust_reserved then "r#" ^ s else s + +let rec rust_type = function + | TyCon id when id.name = "Int" -> "i64" + | TyCon id when id.name = "Float" -> "f64" + | TyCon id when id.name = "Bool" -> "bool" + | TyCon id when id.name = "String" -> "&'static str" + | TyCon id when id.name = "Unit" -> "()" + | TyCon id -> mangle id.name + | TyTuple [] -> "()" + | TyTuple ts -> "(" ^ String.concat ", " (List.map rust_type ts) ^ ")" + | TyOwn t | TyRef t | TyMut t -> rust_type t + | _ -> "()" + +let rec gen_expr (e : expr) : string = + match e with + | ExprLit lit -> gen_lit lit + | ExprVar id -> mangle id.name + | ExprApp (callee, args) -> + gen_expr callee ^ "(" ^ String.concat ", " (List.map gen_expr args) ^ ")" + | ExprBinary (a, op, b) -> + let s = match op with + | OpAdd -> "+" | OpSub -> "-" | OpMul -> "*" | OpDiv -> "/" | OpMod -> "%" + | OpEq -> "==" | OpNe -> "!=" + | OpLt -> "<" | OpLe -> "<=" | OpGt -> ">" | OpGe -> ">=" + | OpAnd -> "&&" | OpOr -> "||" + | OpBitAnd -> "&" | OpBitOr -> "|" | OpBitXor -> "^" + | OpShl -> "<<" | OpShr -> ">>" + | OpConcat -> "+" (* approximate *) + in + "(" ^ gen_expr a ^ " " ^ s ^ " " ^ gen_expr b ^ ")" + | ExprUnary (op, x) -> + (match op with + | OpNeg -> "(-" ^ gen_expr x ^ ")" + | OpNot -> "(!" ^ gen_expr x ^ ")" + | OpBitNot -> "(!" ^ gen_expr x ^ ")" + | OpRef -> "(&" ^ gen_expr x ^ ")" + | OpDeref -> "(*" ^ gen_expr x ^ ")") + | ExprIf { ei_cond; ei_then; ei_else } -> + let c = gen_expr ei_cond in + let t = gen_expr ei_then in + let f = match ei_else with Some e -> gen_expr e | None -> "()" in + Printf.sprintf "(if %s { %s } else { %s })" c t f + | ExprLet { el_pat; el_value; el_body; _ } -> + let var = match el_pat with PatVar id -> mangle id.name | _ -> "_" in + let v = gen_expr el_value in + let body = match el_body with Some e -> gen_expr e | None -> "()" in + Printf.sprintf "{ let %s = %s; %s }" var v body + | ExprBlock blk -> gen_block blk + | ExprTuple es -> "(" ^ String.concat ", " (List.map gen_expr es) ^ ")" + | ExprTupleIndex (e, n) -> Printf.sprintf "%s.%d" (gen_expr e) n + | ExprArray es -> "[" ^ String.concat ", " (List.map gen_expr es) ^ "]" + | ExprIndex (a, i) -> + Printf.sprintf "%s[%s as usize]" (gen_expr a) (gen_expr i) + | ExprRecord { er_fields; er_spread = _ } -> + let fs = List.map (fun (id, e_opt) -> + let v = match e_opt with Some e -> gen_expr e | None -> mangle id.name in + Printf.sprintf "%s: %s" (mangle id.name) v + ) er_fields in + "{ " ^ String.concat ", " fs ^ " }" + (* Rust requires a struct name on the literal. The TopType decl emits + a struct; for body-position records we wrap with `Self {...}` only + inside an `impl`, which we don't use. The bare brace-form is enough + when the destination is a typed `let` whose type is known by Rust. *) + | ExprField (record, field) -> gen_expr record ^ "." ^ mangle field.name + | ExprVariant (ty, ctor) -> + Printf.sprintf "%s::%s" (mangle ty.name) (mangle ctor.name) + | ExprMatch { em_scrutinee; em_arms } -> + let scrut = gen_expr em_scrutinee in + let arms = List.map (fun arm -> + let pat = gen_pattern arm.ma_pat in + let body = gen_expr arm.ma_body in + Printf.sprintf "%s => %s," pat body + ) em_arms in + Printf.sprintf "match %s { %s }" scrut (String.concat " " arms) + | ExprSpan (inner, _) -> gen_expr inner + | ExprReturn (Some e) -> "return " ^ gen_expr e + | ExprReturn None -> "return" + | _ -> "unimplemented!(\"Rust backend: unsupported expression\")" + +and gen_pattern (p : pattern) : string = + match p with + | PatWildcard _ -> "_" + | PatVar id -> mangle id.name + | PatLit lit -> gen_lit lit + | PatCon (id, args) -> + let arg_pats = List.map gen_pattern args in + if arg_pats = [] then mangle id.name + else mangle id.name ^ "(" ^ String.concat ", " arg_pats ^ ")" + | PatTuple ps -> "(" ^ String.concat ", " (List.map gen_pattern ps) ^ ")" + | PatRecord (fields, _) -> + let fs = List.map (fun (id, sub) -> + match sub with + | None -> mangle id.name + | Some sub -> Printf.sprintf "%s: %s" (mangle id.name) (gen_pattern sub) + ) fields in + "{ " ^ String.concat ", " fs ^ " }" + | PatAs (id, _) -> mangle id.name + | PatOr (p, _) -> gen_pattern p + +and gen_lit = function + | LitInt (n, _) -> Printf.sprintf "%di64" n + | LitFloat (f, _) -> + let s = string_of_float f in + let s = if String.length s > 0 && s.[String.length s - 1] = '.' then s ^ "0" else s in + s ^ "f64" + | LitBool (true, _) -> "true" + | LitBool (false, _) -> "false" + | LitString (s, _) -> "\"" ^ String.escaped s ^ "\"" + | LitChar (c, _) -> "'" ^ Char.escaped c ^ "'" + | LitUnit _ -> "()" + +and gen_block (blk : block) : string = + let stmts = List.map gen_stmt blk.blk_stmts in + let tail = match blk.blk_expr with Some e -> gen_expr e | None -> "()" in + "{ " ^ String.concat " " stmts ^ " " ^ tail ^ " }" + +and gen_stmt = function + | StmtLet { sl_pat = PatVar id; sl_value; sl_mut; sl_ty; _ } -> + let kw = if sl_mut then "let mut" else "let" in + (* Rust struct literals need a name; if the source said `let p: T = { ... }`, + lift T onto the record literal so it parses. *) + let v = gen_value_with_hint sl_ty sl_value in + Printf.sprintf "%s %s = %s;" kw (mangle id.name) v + | StmtLet _ -> "" + | StmtExpr e -> gen_expr e ^ ";" + | StmtAssign (lhs, op, rhs) -> + let s = match op with + | AssignEq -> "=" | AssignAdd -> "+=" | AssignSub -> "-=" + | AssignMul -> "*=" | AssignDiv -> "/=" in + Printf.sprintf "%s %s %s;" (gen_expr lhs) s (gen_expr rhs) + | StmtWhile (cond, body) -> + Printf.sprintf "while %s %s" (gen_expr cond) (gen_block body) + | StmtFor _ -> "" + +(* When we know the destination type for a record literal (from a let + annotation), prepend the struct name so Rust's parser accepts the + `Name { ... }` form. *) +and gen_value_with_hint (ty_hint : type_expr option) (e : expr) : string = + match e, ty_hint with + | ExprRecord _, Some (TyCon id) -> + mangle id.name ^ " " ^ gen_expr e + | _ -> gen_expr e + +(* Rust requires [fn main()] to return [()] (or Termination). If the AS + source defines a [main], we rename it to [__as_main] and emit a real + Rust [main] that propagates its result as an exit code. *) +let gen_function ?(rename_to = None) (fd : fn_decl) : string = + let name = match rename_to with + | Some n -> n + | None -> mangle fd.fd_name.name + in + let params = String.concat ", " + (List.map (fun (p : param) -> + Printf.sprintf "%s: %s" (mangle p.p_name.name) (rust_type p.p_ty)) + fd.fd_params) in + let ret = match fd.fd_ret_ty with + | None -> "" + | Some t -> " -> " ^ rust_type t + in + let body = match fd.fd_body with + | FnExpr e -> gen_expr e + | FnBlock b -> gen_block b + in + Printf.sprintf "fn %s(%s)%s %s\n\n" name params ret body + +(* Heuristic: a struct can derive Copy if every field's Rust type is Copy. + Our scalar map (i64/f64/bool, plain tuples of those) is uniformly Copy; + referenced types (`void *` fallback) are also Copy at the type level. + AS's affine semantics permit moves through function calls in current + surface code; deriving Copy here keeps the Rust output usable without + forcing every record into clone-ceremony. *) +let rec is_copyable = function + | TyCon id when id.name = "Int" || id.name = "Float" || id.name = "Bool" + || id.name = "Char" || id.name = "Unit" -> true + | TyTuple ts -> List.for_all is_copyable ts + | TyOwn t | TyRef t | TyMut t -> is_copyable t + | _ -> false + +let emit_struct (name : string) (fields : (string * type_expr) list) : string = + let derives = + if List.for_all (fun (_, t) -> is_copyable t) fields + then "#[derive(Copy, Clone, Debug)]" + else "#[derive(Clone, Debug)]" + in + let fs = List.map (fun (n, ty) -> + Printf.sprintf " pub %s: %s," (mangle n) (rust_type ty)) fields in + Printf.sprintf "%s\nstruct %s {\n%s\n}\n\n" + derives name (String.concat "\n" fs) + +let gen_type_decl (td : type_decl) : string = + let name = mangle td.td_name.name in + match td.td_body with + (* `type Foo = { ... }` parses as TyAlias-of-TyRecord. We treat it as a + nominal struct rather than a literal type alias because Rust requires + a name for record-shaped types. *) + | TyAlias (TyRecord (fields, _)) -> + let pairs = List.map (fun (rf : row_field) -> (rf.rf_name.name, rf.rf_ty)) fields in + emit_struct name pairs + | TyAlias t -> Printf.sprintf "type %s = %s;\n\n" name (rust_type t) + | TyStruct fields -> + let pairs = List.map (fun (sf : struct_field) -> (sf.sf_name.name, sf.sf_ty)) fields in + emit_struct name pairs + | TyEnum variants -> + let vs = List.map (fun (vd : variant_decl) -> + let tys = List.map rust_type vd.vd_fields in + let body = if tys = [] then "" else "(" ^ String.concat ", " tys ^ ")" in + Printf.sprintf " %s%s," (mangle vd.vd_name.name) body + ) variants in + Printf.sprintf "#[derive(Clone, Debug)]\nenum %s {\n%s\n}\n\n" + name (String.concat "\n" vs) + +let generate (program : program) (_symbols : Symbol.t) : string = + let buf = Buffer.create 1024 in + Buffer.add_string buf "// Generated by AffineScript compiler\n"; + Buffer.add_string buf "// SPDX-License-Identifier: PMPL-1.0-or-later\n"; + Buffer.add_string buf "#![allow(unused, dead_code, non_snake_case, unused_parens, non_camel_case_types)]\n\n"; + (* Type decls come first so functions can reference them. *) + List.iter (function + | TopType td -> Buffer.add_string buf (gen_type_decl td) + | _ -> () + ) program.prog_decls; + (* `use ::*` opens variant constructors so user code can write + [Circle(5)] instead of [Shape::Circle(5)] — matches AffineScript's + surface where variant names are resolved without the type prefix. *) + List.iter (function + | TopType { td_name; td_body = TyEnum _; _ } -> + Buffer.add_string buf + (Printf.sprintf "use %s::*;\n" (mangle td_name.name)) + | _ -> () + ) program.prog_decls; + Buffer.add_char buf '\n'; + let has_main = ref false in + List.iter (function + | TopFn fd when fd.fd_name.name = "main" -> + has_main := true; + Buffer.add_string buf (gen_function ~rename_to:(Some "__as_main") fd) + | TopFn fd -> + Buffer.add_string buf (gen_function fd) + | _ -> () + ) program.prog_decls; + if !has_main then + Buffer.add_string buf "fn main() { std::process::exit(__as_main() as i32) }\n"; + Buffer.contents buf + +let codegen_rust (program : program) (symbols : Symbol.t) : (string, string) result = + try Ok (generate program symbols) + with + | Failure msg -> Error ("Rust codegen error: " ^ msg) + | e -> Error ("Rust codegen error: " ^ Printexc.to_string e) diff --git a/lib/spirv_codegen.ml b/lib/spirv_codegen.ml new file mode 100644 index 0000000..a55b362 --- /dev/null +++ b/lib/spirv_codegen.ml @@ -0,0 +1,154 @@ +(* SPDX-License-Identifier: PMPL-1.0-or-later *) +(* SPDX-FileCopyrightText: 2026 Jonathan D.A. Jewell *) + +(** SPIR-V binary emitter (MVP, proof of wire format). + + SPIR-V is a 32-bit-word-oriented binary format. Each instruction is a + sequence of words: the first word packs [WordCount << 16 | OpCode], and + operands follow. + + This MVP emits a minimal valid module: capability + memory model + a + single [GLCompute] entry point [@kernel] that takes one [Int] + parameter (the index) and does nothing. It is enough to round-trip + through [spirv-val] and prove the wire format is correct. Real GPU + work is left to Phase 2. + + Reference: https://www.khronos.org/registry/SPIR-V/specs/unified1/SPIRV.html +*) + +open Ast + +exception Spirv_unsupported of string +let unsupported m = raise (Spirv_unsupported m) + +(* ============================================================================ + Word emission + + SPIR-V words are little-endian 32-bit unsigned. Strings are NUL-terminated + and zero-padded to 4-byte alignment. + ============================================================================ *) + +let emit_word (buf : Buffer.t) (w : int) : unit = + Buffer.add_char buf (Char.chr (w land 0xFF)); + Buffer.add_char buf (Char.chr ((w lsr 8) land 0xFF)); + Buffer.add_char buf (Char.chr ((w lsr 16) land 0xFF)); + Buffer.add_char buf (Char.chr ((w lsr 24) land 0xFF)) + +let emit_string (buf : Buffer.t) (s : string) : int = + (* Returns the number of words consumed. NUL-terminated, zero-padded. *) + let n = String.length s in + let total = n + 1 in + let padded = ((total + 3) / 4) * 4 in + for i = 0 to padded - 1 do + let c = if i < n then Char.code s.[i] else 0 in + Buffer.add_char buf (Char.chr c) + done; + padded / 4 + +let emit_op (buf : Buffer.t) (opcode : int) (operands : int list) : unit = + let word_count = 1 + List.length operands in + emit_word buf ((word_count lsl 16) lor opcode); + List.iter (emit_word buf) operands + +(* String-bearing ops are emitted manually because their length depends on + the string padding rather than on a fixed operand count. *) +let emit_op_with_string (buf : Buffer.t) (opcode : int) (prefix : int list) + (s : string) : unit = + let str_buf = Buffer.create 64 in + let str_words = emit_string str_buf s in + let word_count = 1 + List.length prefix + str_words in + emit_word buf ((word_count lsl 16) lor opcode); + List.iter (emit_word buf) prefix; + Buffer.add_string buf (Buffer.contents str_buf) + +(* ============================================================================ + Module emission + + Layout (per spec section 2.4): + 1. Magic + version + generator + bound + reserved (header, 5 words) + 2. Capabilities + 3. Extensions + 4. ExtInstImports + 5. MemoryModel + 6. EntryPoints + 7. ExecutionModes + 8. Debug info (Source, OpName, OpString) + 9. Annotations + 10. Type, constant, global declarations + 11. Function declarations + 12. Function definitions + ============================================================================ *) + +let pick_kernel (program : program) : fn_decl = + let fns = List.filter_map (function TopFn fd -> Some fd | _ -> None) program.prog_decls in + match List.find_opt (fun fd -> fd.fd_name.name = "kernel") fns with + | Some fd -> fd + | None -> match fns with + | fd :: _ -> fd + | [] -> unsupported "no function found" + +let generate (program : program) (_symbols : Symbol.t) : string = + let entry = pick_kernel program in + let entry_name = entry.fd_name.name in + let buf = Buffer.create 512 in + + (* SPIR-V op codes (subset, from the spec) *) + let op_capability = 17 in + let op_memory_model = 14 in + let op_entry_point = 15 in + let op_execution_mode = 16 in + let op_type_void = 19 in + let op_type_function = 33 in + let op_function = 54 in + let op_function_end = 56 in + let op_label = 248 in + let op_return = 253 in + + (* SSA IDs — we hand-allocate just the few we need. *) + let id_void_ty = 1 in + let id_fn_ty = 2 in + let id_main = 3 in + let id_label = 4 in + let bound = 5 in (* one past the highest used ID *) + + (* Header: magic, version 1.0, generator (anything; we use 0), bound, schema *) + emit_word buf 0x07230203; + emit_word buf 0x00010000; (* version 1.0 *) + emit_word buf 0; (* generator magic — 0 means 'unknown', valid *) + emit_word buf bound; + emit_word buf 0; (* schema *) + + (* Capability Shader (1) — required for GLCompute entry points *) + emit_op buf op_capability [1]; + + (* OpMemoryModel Logical (0) GLSL450 (1) *) + emit_op buf op_memory_model [0; 1]; + + (* OpEntryPoint GLCompute (5) %id_main "kernel" *) + emit_op_with_string buf op_entry_point [5; id_main] entry_name; + + (* OpExecutionMode %id_main LocalSize 64 1 1 (mode 17) *) + emit_op buf op_execution_mode [id_main; 17; 64; 1; 1]; + + (* %id_void_ty = OpTypeVoid *) + emit_op buf op_type_void [id_void_ty]; + (* %id_fn_ty = OpTypeFunction %id_void_ty *) + emit_op buf op_type_function [id_fn_ty; id_void_ty]; + + (* %id_main = OpFunction %id_void_ty None %id_fn_ty *) + emit_op buf op_function [id_void_ty; id_main; 0; id_fn_ty]; + (* %id_label = OpLabel *) + emit_op buf op_label [id_label]; + (* OpReturn *) + emit_op buf op_return []; + (* OpFunctionEnd *) + emit_op buf op_function_end []; + + Buffer.contents buf + +let codegen_spirv (program : program) (symbols : Symbol.t) : (string, string) result = + try Ok (generate program symbols) + with + | Spirv_unsupported m -> Error ("SPIR-V backend: " ^ m) + | Failure m -> Error ("SPIR-V codegen error: " ^ m) + | e -> Error ("SPIR-V codegen error: " ^ Printexc.to_string e) diff --git a/lib/verilog_codegen.ml b/lib/verilog_codegen.ml new file mode 100644 index 0000000..deebe00 --- /dev/null +++ b/lib/verilog_codegen.ml @@ -0,0 +1,109 @@ +(* SPDX-License-Identifier: PMPL-1.0-or-later *) +(* SPDX-FileCopyrightText: 2026 Jonathan D.A. Jewell *) + +(** Verilog emitter (combinational MVP). + + Lowers a strict subset of AffineScript to a Verilog module: a single + pure function on integer inputs becomes a combinational module with + an always_comb block. No clock, no state, no buses — that's a Phase 2 + feature. *) + +open Ast + +exception Vlog_unsupported of string +let unsupported m = raise (Vlog_unsupported m) + +let bit_width = 32 (* MVP: every Int is 32-bit signed *) + +let ty_to_decl ty kind = + (match ty with + | TyCon id when id.name = "Int" -> () + | TyCon id when id.name = "Bool" -> () + | _ -> unsupported "Verilog backend accepts only Int/Bool ports/wires"); + let width = match ty with + | TyCon id when id.name = "Bool" -> 1 + | _ -> bit_width + in + if width = 1 then kind ^ " " + else Printf.sprintf "%s signed [%d:0] " kind (width - 1) + +(* ============================================================================ + Pure expression compiler + + Returns a Verilog expression string. No side effects — combinational + logic only. + ============================================================================ *) + +let rec gen_expr (e : expr) : string = + match e with + | ExprLit (LitInt (n, _)) -> string_of_int n + | ExprLit (LitBool (true, _)) -> "1'b1" + | ExprLit (LitBool (false, _))-> "1'b0" + | ExprLit _ -> unsupported "non-numeric literal" + | ExprVar id -> id.name + | ExprBinary (a, op, b) -> + let s = match op with + | OpAdd -> "+" | OpSub -> "-" | OpMul -> "*" | OpDiv -> "/" | OpMod -> "%" + | OpEq -> "==" | OpNe -> "!=" + | OpLt -> "<" | OpLe -> "<=" | OpGt -> ">" | OpGe -> ">=" + | OpAnd -> "&&" | OpOr -> "||" + | OpBitAnd -> "&" | OpBitOr -> "|" | OpBitXor -> "^" + | OpShl -> "<<" | OpShr -> ">>" + | OpConcat -> unsupported "string concat not supported in Verilog" + in + "(" ^ gen_expr a ^ " " ^ s ^ " " ^ gen_expr b ^ ")" + | ExprUnary (op, x) -> + (match op with + | OpNeg -> "(-" ^ gen_expr x ^ ")" + | OpNot -> "(!" ^ gen_expr x ^ ")" + | OpBitNot -> "(~" ^ gen_expr x ^ ")" + | _ -> unsupported "unary op not supported in Verilog") + | ExprIf { ei_cond; ei_then; ei_else } -> + let c = gen_expr ei_cond in + let t = gen_expr ei_then in + let f = match ei_else with Some e -> gen_expr e | None -> "0" in + Printf.sprintf "(%s ? %s : %s)" c t f + | ExprSpan (inner, _) -> gen_expr inner + | ExprReturn (Some e) -> gen_expr e + | ExprApp _ -> + unsupported "function calls not supported in Verilog MVP \ + (lift body of called function inline first)" + | _ -> unsupported "expression form not supported in Verilog MVP" + +let gen_module (buf : Buffer.t) (fd : fn_decl) : unit = + let name = fd.fd_name.name in + let body_expr = match fd.fd_body with + | FnExpr e -> e + | FnBlock { blk_stmts = []; blk_expr = Some e } -> e + | FnBlock _ -> + unsupported "Verilog backend MVP only handles single-expression bodies" + in + let ports = List.map (fun (p : param) -> + let kind = ty_to_decl p.p_ty "input" in + Printf.sprintf "%s%s" kind p.p_name.name + ) fd.fd_params in + let out_kind = ty_to_decl + (match fd.fd_ret_ty with Some t -> t | None -> TyCon { name="Int"; span=Span.dummy }) + "output" in + let port_list = String.concat ",\n " (ports @ [out_kind ^ "out"]) in + Buffer.add_string buf (Printf.sprintf "module %s (\n %s\n);\n" name port_list); + Buffer.add_string buf + (Printf.sprintf " assign out = %s;\n" (gen_expr body_expr)); + Buffer.add_string buf "endmodule\n\n" + +let generate (program : program) (_symbols : Symbol.t) : string = + let buf = Buffer.create 1024 in + Buffer.add_string buf "// Generated by AffineScript compiler\n"; + Buffer.add_string buf "// SPDX-License-Identifier: PMPL-1.0-or-later\n\n"; + List.iter (function + | TopFn fd -> gen_module buf fd + | _ -> () + ) program.prog_decls; + Buffer.contents buf + +let codegen_verilog (program : program) (symbols : Symbol.t) : (string, string) result = + try Ok (generate program symbols) + with + | Vlog_unsupported m -> Error ("Verilog backend: " ^ m) + | Failure m -> Error ("Verilog codegen error: " ^ m) + | e -> Error ("Verilog codegen error: " ^ Printexc.to_string e) diff --git a/lib/wgsl_codegen.ml b/lib/wgsl_codegen.ml new file mode 100644 index 0000000..85003bb --- /dev/null +++ b/lib/wgsl_codegen.ml @@ -0,0 +1,326 @@ +(* SPDX-License-Identifier: PMPL-1.0-or-later *) +(* SPDX-FileCopyrightText: 2026 Jonathan D.A. Jewell *) + +(** WGSL Kernel Sublanguage Emitter (MVP). + + Lowers a strict subset of AffineScript to a WebGPU compute shader. + + Source shape (the kernel sublanguage): + - exactly one [fn] declaration is the kernel, named [kernel] or [main] + (or, if neither, the first [fn] in the file) + - the first parameter is [Int] and represents the global invocation index + - remaining parameters are [Array[Int]] or [Array[Float]] and become + WGSL storage buffers; ownership selects access: + - [ref T] -> [var] + - [mut T] -> [var] + - bare T -> [var] + - return type is [Unit] (kernels produce side effects in buffers) + - body uses arithmetic, comparison, [if]/[let]/blocks, [arr[i]] index + reads, and [out[i] = expr] assignments + + Anything outside this subset emits an explicit error rather than + silently miscompiling. The output is a single WGSL file consumable by + any WebGPU host (browser, [wgpu], Dawn, naga-cli). +*) + +open Ast + +(* ============================================================================ + Errors + ============================================================================ *) + +exception Wgsl_unsupported of string +let unsupported msg = raise (Wgsl_unsupported msg) + +(* ============================================================================ + Context + ============================================================================ *) + +type ctx = { + output : Buffer.t; + indent : int; + index_param : string; (* mangled name of the i: Int parameter *) + buffer_tys : (string * string) list; (* (param name, element type "i32"|"f32") *) +} + +let new_ctx () = { + output = Buffer.create 1024; + indent = 0; + index_param = "_gid"; + buffer_tys = []; +} + +let emit ctx s = Buffer.add_string ctx.output s +let emit_line ctx s = + Buffer.add_string ctx.output (String.make (ctx.indent * 2) ' '); + Buffer.add_string ctx.output s; + Buffer.add_char ctx.output '\n' +let inc ctx = { ctx with indent = ctx.indent + 1 } +let dec ctx = { ctx with indent = max 0 (ctx.indent - 1) } + +(* ============================================================================ + Identifier sanitisation + ============================================================================ *) + +let wgsl_reserved = [ + "array"; "atomic"; "bool"; "break"; "case"; "const"; "continue"; "default"; + "discard"; "else"; "enable"; "false"; "fn"; "for"; "if"; "let"; "loop"; + "private"; "ptr"; "return"; "storage"; "struct"; "switch"; "true"; "type"; + "uniform"; "var"; "vec2"; "vec3"; "vec4"; "while"; "workgroup"; + "i32"; "u32"; "f32"; "f16"; + "main"; "kernel"; (* avoid colliding with our entry point *) +] + +let mangle s = + if List.mem s wgsl_reserved then s ^ "_" else s + +(* ============================================================================ + Type lowering + ============================================================================ *) + +let scalar_of_type_name = function + | "Int" -> "i32" + | "Float" -> "f32" + | "Bool" -> "bool" + | other -> unsupported ("type not allowed in WGSL kernel: " ^ other) + +let rec scalar_of (te : type_expr) : string = + match te with + | TyCon id -> scalar_of_type_name id.name + | TyOwn t | TyRef t | TyMut t -> scalar_of t + | _ -> unsupported "complex type not allowed in WGSL kernel" + +let array_element (te : type_expr) : string = + let rec strip = function + | TyOwn t | TyRef t | TyMut t -> strip t + | t -> t + in + match strip te with + | TyApp (id, [TyArg inner]) when id.name = "Array" -> scalar_of inner + | _ -> unsupported "expected Array[Int] or Array[Float] for kernel buffer" + +let access_for_ownership (own : ownership option) : string = + match own with + | Some Mut -> "read_write" + | _ -> "read" + +(* ============================================================================ + Expressions + ============================================================================ *) + +let rec gen_expr ctx (e : expr) : string = + match e with + | ExprLit lit -> gen_lit lit + | ExprVar id -> mangle id.name + | ExprBinary (a, op, b) -> + let s = match op with + | OpAdd -> "+" | OpSub -> "-" | OpMul -> "*" | OpDiv -> "/" | OpMod -> "%" + | OpEq -> "==" | OpNe -> "!=" + | OpLt -> "<" | OpLe -> "<=" | OpGt -> ">" | OpGe -> ">=" + | OpAnd -> "&&" | OpOr -> "||" + | OpBitAnd -> "&" | OpBitOr -> "|" | OpBitXor -> "^" + | OpShl -> "<<" | OpShr -> ">>" + | OpConcat -> unsupported "string/array concat not supported in WGSL" + in + "(" ^ gen_expr ctx a ^ " " ^ s ^ " " ^ gen_expr ctx b ^ ")" + | ExprUnary (op, x) -> + (match op with + | OpNeg -> "(-(" ^ gen_expr ctx x ^ "))" + | OpNot -> "(!" ^ gen_expr ctx x ^ ")" + | OpBitNot -> "(~" ^ gen_expr ctx x ^ ")" + | OpRef | OpDeref -> unsupported "ref/deref not supported in WGSL kernel") + | ExprIf { ei_cond; ei_then; ei_else } -> + (* WGSL has no expression-form if; fold to select() for scalars only. + Block expressions (`if` whose branches are blocks) must appear in + statement position — see gen_stmt. *) + let c = gen_expr ctx ei_cond in + let t = gen_expr ctx ei_then in + let f = match ei_else with + | Some e -> gen_expr ctx e + | None -> unsupported "if without else cannot be an expression in WGSL" + in + Printf.sprintf "select(%s, %s, %s)" f t c + | ExprIndex (arr, idx) -> + Printf.sprintf "%s[u32(%s)]" (gen_expr ctx arr) (gen_expr ctx idx) + | ExprApp (callee, args) -> + (* Permit calls to a small set of WGSL built-ins by name. Anything else + fails — kernels can't call user-defined helper fns in MVP. *) + let name = match callee with + | ExprVar id -> id.name + | _ -> unsupported "indirect calls not supported in WGSL kernel" + in + let known = ["abs"; "min"; "max"; "clamp"; "sqrt"; "floor"; "ceil"; + "round"; "sin"; "cos"; "tan"; "exp"; "log"; "pow"; + "mix"; "step"; "smoothstep"; "f32"; "i32"; "u32"] in + if not (List.mem name known) then + unsupported ("call to non-builtin function in WGSL kernel: " ^ name); + let args_s = List.map (gen_expr ctx) args in + Printf.sprintf "%s(%s)" name (String.concat ", " args_s) + | ExprSpan (inner, _) -> gen_expr ctx inner + | ExprBlock _ -> unsupported "block expression must be in statement position" + | ExprLet _ -> unsupported "let must be a statement, not an expression" + | ExprMatch _ -> unsupported "match not supported in WGSL kernel" + | ExprLambda _ -> unsupported "lambdas not supported in WGSL" + | ExprTuple _ | ExprArray _ | ExprRecord _ + | ExprField _ | ExprTupleIndex _ | ExprRowRestrict _ -> + unsupported "compound values not supported in WGSL kernel (yet)" + | ExprReturn _ | ExprTry _ | ExprHandle _ + | ExprResume _ | ExprUnsafe _ | ExprVariant _ -> + unsupported "control-flow construct not supported in WGSL kernel" + +and gen_lit (lit : literal) : string = + match lit with + | LitInt (n, _) -> Printf.sprintf "%d" n + | LitFloat (f, _) -> + let s = string_of_float f in + if String.length s > 0 && s.[String.length s - 1] = '.' then s ^ "0" else s + | LitBool (true, _) -> "true" + | LitBool (false, _) -> "false" + | LitChar _ -> unsupported "char literals not supported in WGSL" + | LitString _ -> unsupported "string literals not supported in WGSL" + | LitUnit _ -> unsupported "unit literal in expression position" + +(* Statement-form lowering. Block expressions and let-statements live here. *) +let rec gen_stmt ctx (s : stmt) : unit = + match s with + | StmtLet { sl_pat; sl_value; sl_mut; sl_quantity = _; sl_ty } -> + let var = match sl_pat with + | PatVar id -> mangle id.name + | PatWildcard _ -> "_" + | _ -> unsupported "destructuring let not supported in WGSL" + in + let kw = if sl_mut then "var" else "let" in + let ty_anno = match sl_ty with + | Some t -> ": " ^ scalar_of t + | None -> "" + in + emit_line ctx (Printf.sprintf "%s %s%s = %s;" kw var ty_anno (gen_expr ctx sl_value)) + | StmtExpr e -> + gen_stmt_expr ctx e + | StmtAssign (lhs, op, rhs) -> + let op_str = match op with + | AssignEq -> "=" | AssignAdd -> "+=" + | AssignSub -> "-=" | AssignMul -> "*=" + | AssignDiv -> "/=" + in + emit_line ctx + (Printf.sprintf "%s %s %s;" (gen_expr ctx lhs) op_str (gen_expr ctx rhs)) + | StmtWhile (cond, body) -> + emit_line ctx (Printf.sprintf "while (%s) {" (gen_expr ctx cond)); + gen_block (inc ctx) body; + emit_line ctx "}" + | StmtFor _ -> + unsupported "for-in loop not supported in WGSL kernel (use while)" + +and gen_stmt_expr ctx e = + (* Statement-position expression: emit if/blocks as control flow, scalar + expressions as `_ = expr;` (rare; usually a builtin call). *) + match e with + | ExprIf { ei_cond; ei_then; ei_else } -> + emit_line ctx (Printf.sprintf "if (%s) {" (gen_expr ctx ei_cond)); + gen_branch (inc ctx) ei_then; + (match ei_else with + | Some else_br -> + emit_line ctx "} else {"; + gen_branch (inc ctx) else_br; + emit_line ctx "}" + | None -> + emit_line ctx "}") + | ExprBlock blk -> gen_block ctx blk + | _ -> + emit_line ctx (Printf.sprintf "_ = %s;" (gen_expr ctx e)) + +and gen_branch ctx (e : expr) = + match e with + | ExprBlock blk -> gen_block ctx blk + | _ -> emit_line ctx (gen_expr ctx e ^ ";") + +and gen_block ctx (blk : block) = + List.iter (gen_stmt ctx) blk.blk_stmts; + (match blk.blk_expr with + | Some e -> gen_stmt_expr ctx e + | None -> ()) + +(* ============================================================================ + Top-level: pick the kernel function and emit it + ============================================================================ *) + +let pick_kernel (program : program) : fn_decl = + let fns = List.filter_map (function TopFn fd -> Some fd | _ -> None) + program.prog_decls + in + match List.find_opt (fun fd -> fd.fd_name.name = "kernel") fns with + | Some fd -> fd + | None -> + match List.find_opt (fun fd -> fd.fd_name.name = "main") fns with + | Some fd -> fd + | None -> + match fns with + | fd :: _ -> fd + | [] -> unsupported "no function found to lower as kernel" + +let validate_kernel (fd : fn_decl) : unit = + (match fd.fd_ret_ty with + | None -> () + | Some (TyCon id) when id.name = "Unit" -> () + | Some (TyTuple []) -> () (* `() ` parses as TyTuple [], synonymous with Unit *) + | _ -> unsupported "kernel function must return Unit or ()"); + match fd.fd_params with + | [] -> unsupported "kernel must take at least an Int index parameter" + | first :: _ -> + (match first.p_ty with + | TyCon id when id.name = "Int" -> () + | _ -> unsupported "first kernel parameter must be Int (the global index)") + +let emit_buffer_bindings ctx (params : param list) : ctx = + (* Skip the first param (the index); the rest become storage buffers. *) + let rec go i ctx = function + | [] -> ctx + | (p : param) :: rest -> + let elem = array_element p.p_ty in + let access = access_for_ownership p.p_ownership in + let name = mangle p.p_name.name in + emit_line ctx + (Printf.sprintf "@group(0) @binding(%d) var %s : array<%s>;" + i access name elem); + go (i + 1) { ctx with buffer_tys = (name, elem) :: ctx.buffer_tys } rest + in + match params with + | [] | [_] -> ctx + | _ :: bufs -> go 0 ctx bufs + +let generate (program : program) (_symbols : Symbol.t) : string = + let ctx = new_ctx () in + emit_line ctx "// Generated by AffineScript compiler (WGSL kernel sublanguage)"; + emit_line ctx "// SPDX-License-Identifier: PMPL-1.0-or-later"; + emit ctx "\n"; + let fd = pick_kernel program in + validate_kernel fd; + let ctx = emit_buffer_bindings ctx fd.fd_params in + let idx_name = match fd.fd_params with + | first :: _ -> mangle first.p_name.name + | _ -> "i" + in + let ctx = { ctx with index_param = idx_name } in + + emit ctx "\n"; + emit_line ctx "@compute @workgroup_size(64)"; + emit_line ctx + (Printf.sprintf "fn %s(@builtin(global_invocation_id) gid : vec3) {" + (mangle fd.fd_name.name)); + let body_ctx = inc ctx in + emit_line body_ctx (Printf.sprintf "let %s : i32 = i32(gid.x);" idx_name); + (match fd.fd_body with + | FnExpr e -> + gen_stmt_expr body_ctx e + | FnBlock blk -> + gen_block body_ctx blk); + emit_line ctx "}"; + Buffer.contents ctx.output + +let codegen_wgsl (program : program) (symbols : Symbol.t) : (string, string) result = + try Ok (generate program symbols) + with + | Wgsl_unsupported msg -> Error ("WGSL backend: " ^ msg) + | Failure msg -> Error ("WGSL codegen error: " ^ msg) + | e -> Error ("WGSL codegen error: " ^ Printexc.to_string e) diff --git a/lib/why3_codegen.ml b/lib/why3_codegen.ml new file mode 100644 index 0000000..7b9e715 --- /dev/null +++ b/lib/why3_codegen.ml @@ -0,0 +1,126 @@ +(* SPDX-License-Identifier: PMPL-1.0-or-later *) +(* SPDX-FileCopyrightText: 2026 Jonathan D.A. Jewell *) + +(** Why3 / WhyML emitter (MVP, formal verification target). + + Lowers Int/Bool functions to WhyML inside a single theory module. + Output is consumable by [why3 prove] / [why3 ide] for proof + obligation generation against Z3 / Alt-Ergo / CVC. *) + +open Ast + +let why3_reserved = [ + "abstract"; "absurd"; "alias"; "any"; "as"; "assert"; "assume"; "at"; + "axiom"; "begin"; "break"; "by"; "check"; "clone"; "coinductive"; + "constant"; "continue"; "diverges"; "do"; "done"; "downto"; "else"; + "end"; "ensures"; "epsilon"; "exception"; "exists"; "export"; "false"; + "float"; "for"; "forall"; "fun"; "function"; "ghost"; "goal"; "if"; + "import"; "in"; "inductive"; "invariant"; "label"; "lemma"; "let"; + "match"; "meta"; "module"; "mutable"; "not"; "old"; "predicate"; "private"; + "pure"; "raise"; "raises"; "range"; "reads"; "rec"; "ref"; "requires"; + "return"; "returns"; "scope"; "so"; "then"; "theory"; "to"; "true"; "try"; + "type"; "use"; "val"; "variant"; "while"; "with"; "writes"; +] + +let mangle s = if List.mem s why3_reserved then s ^ "_" else s + +let rec why3_type = function + | TyCon id when id.name = "Int" -> "int" + | TyCon id when id.name = "Bool" -> "bool" + | TyCon id when id.name = "Float" -> "real" + | TyCon id when id.name = "Unit" -> "unit" + | TyCon id -> mangle id.name + | TyOwn t | TyRef t | TyMut t -> why3_type t + | _ -> "int" + +let ret_type = function None -> "unit" | Some t -> why3_type t + +let rec gen_expr (e : expr) : string = + match e with + | ExprLit lit -> gen_lit lit + | ExprVar id -> mangle id.name + | ExprApp (callee, args) -> + let f = gen_expr callee in + let xs = List.map (fun a -> "(" ^ gen_expr a ^ ")") args in + f ^ " " ^ String.concat " " xs + | ExprBinary (a, op, b) -> + let s = match op with + | OpAdd -> "+" | OpSub -> "-" | OpMul -> "*" | OpDiv -> "/" | OpMod -> "%" + | OpEq -> "=" | OpNe -> "<>" + | OpLt -> "<" | OpLe -> "<=" | OpGt -> ">" | OpGe -> ">=" + | OpAnd -> "&&" | OpOr -> "||" + | OpConcat -> "++" + | _ -> "+" + in + "(" ^ gen_expr a ^ " " ^ s ^ " " ^ gen_expr b ^ ")" + | ExprUnary (OpNeg, x) -> "(- " ^ gen_expr x ^ ")" + | ExprUnary (OpNot, x) -> "(not " ^ gen_expr x ^ ")" + | ExprUnary _ -> "0" + | ExprIf { ei_cond; ei_then; ei_else } -> + let f = match ei_else with Some e -> gen_expr e | None -> "()" in + Printf.sprintf "(if %s then %s else %s)" (gen_expr ei_cond) (gen_expr ei_then) f + | ExprLet { el_pat; el_value; el_body; _ } -> + let var = match el_pat with PatVar id -> mangle id.name | _ -> "_" in + let body = match el_body with Some e -> gen_expr e | None -> "()" in + Printf.sprintf "(let %s = %s in %s)" var (gen_expr el_value) body + | ExprBlock blk -> gen_block blk + | ExprSpan (inner, _) -> gen_expr inner + | ExprReturn (Some e) -> gen_expr e + | _ -> "0" + +and gen_lit = function + | LitInt (n, _) -> string_of_int n + | LitFloat (f, _) -> + let s = string_of_float f in + if String.length s > 0 && s.[String.length s - 1] = '.' then s ^ "0" else s + | LitBool (true, _) -> "true" + | LitBool (false, _) -> "false" + | LitString (s, _) -> "\"" ^ String.escaped s ^ "\"" + | LitChar (c, _) -> "'" ^ Char.escaped c ^ "'" + | LitUnit _ -> "()" + +and gen_block (blk : block) : string = + let rec fold = function + | [] -> + (match blk.blk_expr with Some e -> gen_expr e | None -> "()") + | StmtLet { sl_pat = PatVar id; sl_value; _ } :: rest -> + Printf.sprintf "(let %s = %s in %s)" (mangle id.name) (gen_expr sl_value) (fold rest) + | _ :: rest -> fold rest + in + fold blk.blk_stmts + +let gen_function (fd : fn_decl) : string = + let name = mangle fd.fd_name.name in + let params = match fd.fd_params with + | [] -> "()" + | _ -> String.concat " " (List.map (fun (p : param) -> + Printf.sprintf "(%s : %s)" (mangle p.p_name.name) (why3_type p.p_ty)) + fd.fd_params) + in + let ret = ret_type fd.fd_ret_ty in + let body = match fd.fd_body with + | FnExpr e -> gen_expr e + | FnBlock b -> gen_block b + in + Printf.sprintf " let %s %s : %s =\n %s\n\n" name params ret body + +let generate (program : program) (_symbols : Symbol.t) : string = + let buf = Buffer.create 1024 in + Buffer.add_string buf "(* Generated by AffineScript compiler (Why3 / WhyML) *)\n"; + Buffer.add_string buf "(* SPDX-License-Identifier: PMPL-1.0-or-later *)\n\n"; + Buffer.add_string buf "module Generated\n"; + Buffer.add_string buf " use int.Int\n"; + Buffer.add_string buf " use bool.Bool\n"; + Buffer.add_string buf " use real.Real\n\n"; + List.iter (function + | TopFn fd -> Buffer.add_string buf (gen_function fd) + | _ -> () + ) program.prog_decls; + Buffer.add_string buf "end\n"; + Buffer.contents buf + +let codegen_why3 (program : program) (symbols : Symbol.t) : (string, string) result = + try Ok (generate program symbols) + with + | Failure m -> Error ("Why3 codegen error: " ^ m) + | e -> Error ("Why3 codegen error: " ^ Printexc.to_string e)