Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 131 additions & 27 deletions lib/codegen.ml
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ let rec find_free_vars (bound_vars : string list) (expr : expr) : string list =
List.concat (List.map (find_free_vars bound_vars) args)
| ExprBlock blk ->
(* Statements may introduce bindings *)
let (_, free) = List.fold_left (fun (bound, acc_free) stmt ->
let (bound_after, free) = List.fold_left (fun (bound, acc_free) stmt ->
match stmt with
| StmtLet sl ->
let rhs_free = find_free_vars bound sl.sl_value in
Expand All @@ -320,8 +320,13 @@ let rec find_free_vars (bound_vars : string list) (expr : expr) : string list =
(bound, acc_free @ find_free_vars bound e)
| _ -> (bound, acc_free)
) (bound_vars, []) blk.blk_stmts in
(* The tail expression is in scope of the block's own `let`
bindings, so its free vars must exclude them — use the
threaded [bound_after], not the original [bound_vars]. (Prior
code used [bound_vars], spuriously reporting block-local
binders as free; surfaced by #225 PR3c chained continuations.) *)
let expr_free = match blk.blk_expr with
| Some e -> find_free_vars bound_vars e
| Some e -> find_free_vars bound_after e
| None -> []
in
free @ expr_free
Expand Down Expand Up @@ -417,6 +422,21 @@ let gen_unop (op : unary_op) : instr result =
| OpRef -> Error (UnsupportedFeature "OpRef handled in ExprUnary")
| OpDeref -> Error (UnsupportedFeature "OpDeref handled in ExprUnary")

(** ADR-013 #225 PR3c — recursive CPS hook. The async-boundary transform
([detect_async_base_case] + [gen_async_base_case]) is defined below
[gen_expr] but must be reachable from *inside* the continuation
lambda's body generation so that a continuation which is itself an
async boundary is transformed too (Async→Async chaining). A forward
reference, populated once at module init, breaks the definition-order
cycle without relocating the whole transform into the rec group.
Returns [Some result] when [expr] matched the async shape (and
`thenableThen` is importable), else [None] ⇒ caller lowers normally.
Recursion terminates: each application peels exactly one async
boundary off a finite, strictly-smaller continuation. *)
let async_transform_hook
: (context -> expr -> (context * instr list) result option) ref
= ref (fun _ _ -> None)

(** Generate code for an expression, returning instructions and updated context *)
let rec gen_expr (ctx : context) (expr : expr) : (context * instr list) result =
match expr with
Expand Down Expand Up @@ -567,8 +587,14 @@ let rec gen_expr (ctx : context) (expr : expr) : (context * instr list) result =
(ctx, [I32Const 0l])
in

(* Create fresh context for lambda function *)
let lambda_ctx = { ctx_after_env with locals = []; next_local = 0; loop_depth = 0 } in
(* Create fresh context for lambda function. [next_lambda_id] is
advanced *before* body generation so a nested lambda created
while lowering this body (e.g. a chained CPS continuation, #225
PR3c) gets a distinct id rather than re-using [lambda_id]. *)
let lambda_ctx =
{ ctx_after_env with
locals = []; next_local = 0; loop_depth = 0;
next_lambda_id = lambda_id + 1 } in

(* Environment is always first parameter (even if unused) for uniform calling convention *)
let (ctx_with_env, _) = alloc_local lambda_ctx "__env" in
Expand Down Expand Up @@ -598,7 +624,14 @@ let rec gen_expr (ctx : context) (expr : expr) : (context * instr list) result =
let param_count = env_param_offset + List.length lam.elam_params in

(* Generate lambda body *)
let* (ctx_final, body_code) = gen_expr ctx_with_captured lam.elam_body in
(* #225 PR3c: if the lambda body is itself an async boundary (a
continuation that chains another async call), transform it so
Thenables compose up the chain; otherwise lower normally. *)
let* (ctx_final, body_code) =
match !async_transform_hook ctx_with_captured lam.elam_body with
| Some r -> r
| None -> gen_expr ctx_with_captured lam.elam_body
in

(* Compute additional locals (beyond parameters and captured vars) *)
let local_count = ctx_final.next_local - param_count in
Expand All @@ -613,9 +646,36 @@ let rec gen_expr (ctx : context) (expr : expr) : (context * instr list) result =
let result_type = [I32] in
let func_type = { ft_params = param_types; ft_results = result_type } in

(* Thread the POST-body module-level state forward while keeping the
enclosing scope's local state. The body may have mutated module
accumulators (a nested lambda + its types/globals/datas, #225
PR3c chaining); rebuilding from [ctx_after_env] would silently
drop them. Enclosing locals/next_local/loop_depth/field_layouts
stay from [ctx_after_env] (the lambda's inner locals must not
leak out). For a non-nested body these module fields equal
[ctx_after_env]'s, so behaviour is unchanged. *)
let ctx_post = { ctx_after_env with
types = ctx_final.types;
funcs = ctx_final.funcs;
exports = ctx_final.exports;
imports = ctx_final.imports;
globals = ctx_final.globals;
func_indices = ctx_final.func_indices;
lambda_funcs = ctx_final.lambda_funcs;
next_lambda_id = ctx_final.next_lambda_id;
heap_ptr = ctx_final.heap_ptr;
struct_layouts = ctx_final.struct_layouts;
fn_ret_structs = ctx_final.fn_ret_structs;
variant_tags = ctx_final.variant_tags;
string_data = ctx_final.string_data;
next_string_offset = ctx_final.next_string_offset;
datas = ctx_final.datas;
ownership_annots = ctx_final.ownership_annots;
} in

(* Add type to types list *)
let type_idx = List.length ctx_after_env.types in
let ctx_with_type = { ctx_after_env with types = ctx_after_env.types @ [func_type] } in
let type_idx = List.length ctx_post.types in
let ctx_with_type = { ctx_post with types = ctx_post.types @ [func_type] } in

(* Create lambda function *)
let lambda_func = {
Expand All @@ -624,11 +684,19 @@ let rec gen_expr (ctx : context) (expr : expr) : (context * instr list) result =
f_body = load_captured_code @ body_code;
} in

(* Add lambda function to lifted functions *)
(* The closure's stored function id MUST be this lambda's index in
the final [lambda_funcs] list, because the element segment maps
table slot i -> the i-th lambda's wasm func (see gen_module), and
wrapHandler dispatches via `table.get(fnId)`. The pre-reserved
[lambda_id] equals the list position ONLY for non-nested lambdas
(id order == append order). A nested lambda (e.g. a chained CPS
continuation, #225 PR3c) is appended *before* its enclosing
lambda yet has a *higher* id, so id ≠ position there. Use the
append position (= current list length) instead. *)
let lambda_slot = List.length ctx_with_type.lambda_funcs in
let ctx_with_lambda = {
ctx_with_type with
lambda_funcs = ctx_with_type.lambda_funcs @ [lambda_func];
next_lambda_id = lambda_id + 1;
} in

(* Return a closure: (function_id, env_pointer) as a 2-element tuple *)
Expand All @@ -647,7 +715,7 @@ let rec gen_expr (ctx : context) (expr : expr) : (context * instr list) result =
let closure_code = closure_alloc @ [LocalSet closure_idx] @ [
(* Store function ID at offset 0 *)
LocalGet closure_idx;
I32Const (Int32.of_int lambda_id);
I32Const (Int32.of_int lambda_slot);
I32Store (2, 0);
] @ [
(* Store environment pointer at offset 4 *)
Expand Down Expand Up @@ -1819,18 +1887,27 @@ let mentions_async_prim (e : expr) : bool =
are captured into the continuation env by the proven #199
ExprLambda path (which already marshals N captures). [pre] is
restricted to simple `let`s so the captured set is well-defined.

Single boundary only: no recognised async primitive may appear in
[pre] values or in [cont] (chaining = PR3c). Capture soundness: every
free name in [cont] must be the binder, a param, a top-level
func/const/global, or a [pre]-bound local — anything else ⇒ [None]
(fall back to the unchanged synchronous lowering, zero regression).
The affine/linear single-use obligation (ADR-013 obl. 1) is
discharged by composition: borrow-check runs on this straight-line
AST before the transform, and the once-resumption trap guarantees
[cont] executes exactly once — so no new static machinery here. *)
let detect_async_base_case ~(globals : string list) (params : string list)
(body : expr)
- PR3c: [cont] may itself be an async boundary. It is NOT rejected
here; the recursive [async_transform_hook] re-applies this
recogniser to the continuation lambda's body, so a chain
`let a = async(); let b = async(); …` lowers to nested
continuations whose Thenables compose up the call chain.

[extra] is the set of live-local names available for #199 capture at
the call site (the enclosing context's locals — including, for a
recursively-transformed continuation, the *outer* binder and outer
captured locals). Capture soundness: every free name in [cont] must
be the binder, a param, a top-level func/const/global, a [pre]-bound
local, or one of [extra] — anything else ⇒ [None] (fall back to the
unchanged synchronous lowering, zero regression). A recognised async
primitive nested inside a [pre] value (i.e. not in `let`-binding
head position) is an unsupported shape ⇒ [None]. The affine/linear
single-use obligation (ADR-013 obl. 1) is discharged by composition
(PR3b): borrow-check runs on this straight-line AST before the
transform, and the once-resumption trap guarantees each continuation
executes exactly once — no new static machinery here. *)
let detect_async_base_case ~(globals : string list) ~(params : string list)
~(extra : string list) (body : expr)
: (stmt list * string * expr * expr) option =
let rec unwrap = function
| ExprBlock { blk_stmts = []; blk_expr = Some e } -> unwrap e
Expand All @@ -1848,13 +1925,14 @@ let detect_async_base_case ~(globals : string list) (params : string list)
if List.exists (fun n -> n = None) pre_names then None
else
let pre_bound = List.filter_map (fun x -> x) pre_names in
(* Single boundary: no async primitive in pre values or cont. *)
(* A nested async primitive in a pre value is unsupported (the
boundary must be the `let`-binding head); cont MAY chain. *)
let pre_vals = List.filter_map
(function StmtLet sl -> Some sl.sl_value | _ -> None) pre in
if mentions_async_prim cont
|| List.exists mentions_async_prim pre_vals then None
if List.exists mentions_async_prim pre_vals then None
else
let allowed = binder :: params @ globals @ pre_bound in
let allowed =
binder :: params @ globals @ pre_bound @ extra in
let escaping =
List.filter (fun v -> not (List.mem v allowed))
(dedup (find_free_vars [] cont))
Expand Down Expand Up @@ -1949,6 +2027,31 @@ let gen_async_base_case (ctx : context) (pre : stmt list) (binder : string)
in
Ok (ctx5, body)

(* #225 PR3c: populate the forward reference (declared before [gen_expr])
so the continuation-body site in the [ExprLambda] lowering re-applies
the transform — giving Async→Async chaining where Thenables compose
up the call chain. A continuation has no params (everything it needs
is captured), hence [~params:[]]; [~extra] is the live-local set the
#199 path can capture, which for a recursively-transformed inner
continuation includes the *outer* binder and outer captures. Requires
`thenableThen` to be importable; otherwise [None] ⇒ the ExprLambda
site lowers the body normally (no behaviour change). *)
let () =
async_transform_hook :=
(fun ctx body ->
match List.assoc_opt "thenableThen" ctx.func_indices with
| None -> None
| Some tt ->
begin match detect_async_base_case
~globals:(List.map fst ctx.func_indices)
~params:[]
~extra:(List.map fst ctx.locals)
body with
| Some (pre, binder, call, cont) ->
Some (gen_async_base_case ctx pre binder call cont tt)
| None -> None
end)

let gen_function (ctx : context) (fd : fn_decl) : (context * func) result =
(* Create fresh context for function scope, but preserve lambda_funcs and next_lambda_id *)
let fn_ctx = { ctx with locals = []; next_local = 0; loop_depth = 0 } in
Expand Down Expand Up @@ -1988,7 +2091,8 @@ let gen_function (ctx : context) (fd : fn_decl) : (context * func) result =
if fn_is_async fd then
match detect_async_base_case
~globals:(List.map fst ctx_with_params.func_indices)
(List.map (fun p -> p.p_name.name) fd.fd_params)
~params:(List.map (fun p -> p.p_name.name) fd.fd_params)
~extra:(List.map fst ctx_with_params.locals)
body_expr with
| Some (pre, binder, call, cont) ->
begin match List.assoc_opt "thenableThen"
Expand Down
33 changes: 33 additions & 0 deletions tests/codegen/http_cps_chain.affine
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// SPDX-License-Identifier: PMPL-1.0-or-later
// issue #225 PR 3c — WasmGC CPS transform, Async->Async chaining.
//
// Two sequential async boundaries in one fn: the continuation of the
// first request is ITSELF an async boundary (the second request), and
// the final continuation combines a value derived from the first
// response (`s1`, a prelude local of the inner boundary) with the
// second. The transform recurses (forward-ref async_transform_hook at
// the ExprLambda continuation-body site), so each boundary becomes a
// nested #199 continuation and the Thenables compose up the chain:
//
// launch() -> Thenable(req A)
// .then(outerCont): s1 = status(rA); req B
// -> Thenable(req B)
// .then(innerCont): combine(s1, status(rB))
//
// `s1` is captured across the SECOND split; `rA` is captured by the
// outer continuation; both reach the right continuation via the
// recursively-applied #199 env capture (~extra carries the outer
// binder/locals into the inner detection). Single source surface —
// the author writes straight-line code; the backend threads handles.

use Http::{Thenable, http_request_thenable, thenableThen};

extern fn httpThenableStatus(t: Thenable) -> Int;
extern fn combine(a: Int, b: Int) -> Int;

pub fn launch() -> Int / { Net, Async } {
let rA = http_request_thenable("https://example.test/a", "GET", "");
let s1 = httpThenableStatus(rA);
let rB = http_request_thenable("https://example.test/b", "GET", "");
combine(s1, httpThenableStatus(rB))
}
110 changes: 110 additions & 0 deletions tests/codegen/test_http_cps_chain.mjs
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// SPDX-License-Identifier: PMPL-1.0-or-later
// issue #225 PR 3c — wasm e2e for Async->Async chaining.
//
// Two sequential async boundaries. Proves the recursive transform:
// launch() returns synchronously; the host settles request A, which
// re-enters the OUTER continuation; that continuation itself issues
// request B (a second Thenable) and registers the INNER continuation;
// the host settles B and the inner continuation combines a value
// derived from A's response (s1) with B's. Thenables compose up the
// chain — same #199 wrapHandler dispatch + #205 convention as the
// PR2/PR3a tests, just twice.
import assert from 'node:assert/strict';
import { readFile } from 'node:fs/promises';

// /a -> 200, /b -> 201 (distinct so the combine proves ordering).
globalThis.fetch = async (url, init) => ({
status: url.includes('/b') ? 201 : 200,
text: async () => `ok:${init && init.method}`,
});

let inst = null;
const _handles = new Map();
const _results = new Map();
let _next = 1;
const reqUrls = [];
let combineCalls = [];
let combineReturn = null;

function readString(ptr) {
const dv = new DataView(inst.exports.memory.buffer);
const len = dv.getUint32(ptr, true);
const bytes = new Uint8Array(inst.exports.memory.buffer, ptr + 4, len);
return new TextDecoder('utf-8').decode(bytes);
}

function wrapHandler(closurePtr) {
return () => {
const tbl = inst.exports.__indirect_function_table;
const dv = new DataView(inst.exports.memory.buffer);
const fnId = dv.getInt32(closurePtr, true);
const envPtr = dv.getInt32(closurePtr + 4, true);
const fn = tbl.get(fnId);
const args = [envPtr];
while (args.length < fn.length) args.push(0);
return fn(...args);
};
}

const imports = {
wasi_snapshot_preview1: { fd_write: () => 0 },
env: {
httpThenableStatus: (tHandle) => {
const v = _results.get(tHandle);
return v && typeof v.status === 'number' ? v.status : -1;
},
combine: (a, b) => {
combineCalls.push([a, b]);
return a * 1000 + b;
},
},
Http: {
http_request_thenable: (urlPtr, methodPtr, bodyPtr) => {
const url = readString(urlPtr);
const method = readString(methodPtr);
readString(bodyPtr);
reqUrls.push(url);
const h = _next++;
const p = globalThis
.fetch(url, { method })
.then(async (r) => ({ status: r.status, body: await r.text() }))
.catch((e) => ({ __error: String(e) }));
_handles.set(h, p);
return h;
},
thenableThen: (tHandle, onSettlePtr) => {
const cb = wrapHandler(onSettlePtr);
Promise.resolve(_handles.get(tHandle)).then((v) => {
_results.set(tHandle, v);
combineReturn = cb();
});
return 1;
},
},
};

const buf = await readFile('./tests/codegen/http_cps_chain.wasm');
const m = await WebAssembly.instantiate(buf, imports);
inst = m.instance;

const disposable = inst.exports.launch();
assert.ok(Number.isInteger(disposable), 'launch() returns synchronously');
assert.deepEqual(reqUrls, ['https://example.test/a'],
'only request A issued synchronously; B is deferred to the continuation');
assert.equal(combineCalls.length, 0, 'final continuation has not run yet');

// Flush enough microtask/timer rounds for BOTH settlement hops.
for (let i = 0; i < 6; i++) {
await new Promise((res) => setTimeout(res, 0));
await Promise.resolve();
}

assert.deepEqual(reqUrls,
['https://example.test/a', 'https://example.test/b'],
'request B was issued by the outer continuation (chaining)');
assert.deepEqual(combineCalls, [[200, 201]],
'inner continuation combined A-derived s1 (200) with B status (201)');
assert.equal(combineReturn, 200201,
'chained result threaded back: combine(200,201) = 200201');

console.log('test_http_cps_chain.mjs OK');
Loading