diff --git a/src/BoundsInference.cpp b/src/BoundsInference.cpp index dda848e3bc7e..c77ebfeecda7 100644 --- a/src/BoundsInference.cpp +++ b/src/BoundsInference.cpp @@ -189,52 +189,6 @@ bool is_fused_with_others(const vector> &fused_groups, return false; } -// An inliner that can inline an entire set of functions at once. The inliner in -// Inline.h only handles with one function at a time. -class Inliner : public IRMutator { -public: - std::set to_inline; - - Expr do_inlining(const Expr &e) { - return common_subexpression_elimination(mutate(e)); - } - -protected: - std::map, Function::Compare> qualified_bodies; - - Expr get_qualified_body(const Function &f, int idx) { - auto it = qualified_bodies.find(f); - if (it != qualified_bodies.end()) { - auto it2 = it->second.find(idx); - if (it2 != it->second.end()) { - return it2->second; - } - } - Expr e = qualify(f.name() + ".", f.values()[idx]); - e = do_inlining(e); - qualified_bodies[f][idx] = e; - return e; - } - - Expr visit(const Call *op) override { - if (op->func.defined()) { - Function f(op->func); - if (to_inline.count(f)) { - auto args = mutate(op->args); - Expr body = get_qualified_body(f, op->value_index); - const vector &func_args = f.args(); - for (size_t i = 0; i < args.size(); i++) { - body = Let::make(f.name() + "." + func_args[i], args[i], body); - } - return body; - } - } - return IRMutator::visit(op); - } - - using IRMutator::visit; -}; - class BoundsInference : public IRMutator { public: const vector &funcs; @@ -686,7 +640,7 @@ class BoundsInference : public IRMutator { vector> buffers_to_annotate; for (const auto &arg : args) { if (arg.is_expr()) { - bounds_inference_args.push_back(inliner->do_inlining(arg.expr)); + bounds_inference_args.push_back((*inliner)(arg.expr)); } else if (arg.is_func()) { Function input(arg.func); for (int k = 0; k < input.outputs(); k++) { @@ -849,16 +803,23 @@ class BoundsInference : public IRMutator { // Compute the intrinsic relationships between the stages of // the functions. - // Figure out which functions will be inlined away + // Figure out which functions will be inlined away. vector inlined(f.size()); for (size_t i = 0; i < inlined.size(); i++) { - if (i < f.size() - 1 && - f[i].schedule().compute_level().is_inlined() && - f[i].can_be_inlined()) { - inlined[i] = true; - inliner.to_inline.insert(f[i]); - } else { - inlined[i] = false; + inlined[i] = (i < f.size() - 1 && + f[i].schedule().compute_level().is_inlined() && + f[i].can_be_inlined()); + } + // Register them with the Inliner in consumer-first order. f is in + // realization (producer-first) order, so we iterate backwards: the + // outermost consumer of each chain is added first, the bottom + // producer last. The Inliner's iterative-deepening loop processes + // entries in add() order, so consumers go first -- their materialized + // bodies expose Calls to producers, which the later (deeper) passes + // then substitute. See Inliner's class doc for the full picture. + for (size_t i = inlined.size(); i > 0; i--) { + if (inlined[i - 1]) { + inliner.add(f[i - 1]); } } @@ -893,7 +854,7 @@ class BoundsInference : public IRMutator { for (auto &s : stages) { for (auto &cond_val : s.exprs) { internal_assert(cond_val.value.defined()); - cond_val.value = inliner.do_inlining(cond_val.value); + cond_val.value = inliner(cond_val.value); } } diff --git a/src/Inline.cpp b/src/Inline.cpp index ce829f1c7326..79bb80610e90 100644 --- a/src/Inline.cpp +++ b/src/Inline.cpp @@ -1,12 +1,10 @@ -#include - +#include "Inline.h" #include "CSE.h" #include "Debug.h" #include "ExternFuncArgument.h" #include "IRMutator.h" #include "IROperator.h" #include "IRPrinter.h" -#include "Inline.h" #include "Qualify.h" #include "Substitute.h" @@ -97,103 +95,238 @@ void validate_schedule_inlined_function(Function f) { } } -class Inliner : public IRMutator { - using IRMutator::visit; - - Function func; - - Expr visit(const Call *op) override { - if (op->name == func.name()) { - - // Mutate the args - auto args = mutate(op->args); - - // Grab the body - Expr body = qualify(func.name() + ".", func.values()[op->value_index]); - - const vector func_args = func.args(); - - // Bind the args using Let nodes - internal_assert(args.size() == func_args.size()); - - for (size_t i = 0; i < args.size(); i++) { - if (is_const(args[i]) || args[i].as()) { - body = substitute(func.name() + "." + func_args[i], args[i], body); - } else { - body = Let::make(func.name() + "." + func_args[i], args[i], body); - } - } - - found++; +// --------------------------------------------------------------------------- +// Inliner: design notes +// +// Picking how many functions to inline per CSE invocation is a balance +// between two penalties: +// +// - Doing all N functions in one pass is bad. CSE's RemoveLets, while +// substituting Let bindings away, can re-walk shared subtrees +// exponentially in their nesting depth, and the materialized body at a +// call site after N levels of inlining is a DAG of exactly that shape. +// Per-CSE-invocation cost grows roughly exponentially in the batch size. +// +// - Doing one function per pass is also bad. Each pass walks the entire +// current IR, which grows as bodies get materialized; N passes × +// O(|s|) is quadratic in N. +// +// The exponential bites much harder than the quadratic, so we use a small +// constant batch_size (8). Empirically the optimum drifts roughly like +// log(N) (K≈8 at N=100, K≈12 at N=300), so a constant works well across +// the range we measured. +// +// Implementation: each add() call assigns the entry an order_id (its +// position in the add() sequence). operator() processes the set by +// iterative deepening through that sequence -- a series of passes that +// raise an active_limit by batch_size each time, with visit(Call) only +// inlining entries whose order_id is below the current limit. The CSE +// that runs between passes (per-Provide in the Stmt mutator, top-level +// in the Expr form) flattens shared subtrees into named Let references, +// so the next pass's RemoveLets input has bounded shared-Let nesting. +// +// Correct for any add() order: a Call that survives a pass (because it +// got wrapped inside a body materialized by an earlier limit) is picked +// up by a later pass once its order_id falls under the limit. But add() +// in consumer-first (reverse-topological) order is best for performance: +// each pass's substitutions then expose the next layer of producer Calls +// for the following pass. With the wrong order, the work piles up in the +// final pass once the limit hits the entries the call sites actually +// reference, defeating the bounded-per-pass cost. +// --------------------------------------------------------------------------- + +Inliner::Inliner(const Function &f) { + internal_assert(f.can_be_inlined()) << "Illegal to inline " << f.name() << "\n"; + validate_schedule_inlined_function(f); + add(f); +} - return body; +void Inliner::add(const Function &f) { + for (int i = 0; i < f.outputs(); i++) { + Key k{f.name(), i}; + auto [it, inserted] = to_inline.insert({k, Entry{}}); + if (inserted) { + it->second.func = f; + // order_id is the entry's position in the add() sequence, used + // by operator()'s iterative-deepening loop to decide when this + // entry first becomes eligible to inline. + it->second.order_id = to_inline.size() - 1; + } + } +} - } else { - return IRMutator::visit(op); +Expr Inliner::operator()(const Expr &e) { + if (active_limit != SIZE_MAX || to_inline.size() <= batch_size) { + return common_subexpression_elimination(mutate(e)); + } + Expr result = e; + size_t limit = batch_size; + while (true) { + active_limit = limit; + min_skipped_order_id = SIZE_MAX; + Expr mutated = mutate(result); + active_limit = SIZE_MAX; + // Only run CSE if mutate actually changed anything; a pass that + // only discovered above-the-limit Calls produces an unchanged + // result and there's nothing for CSE to do. + if (!mutated.same_as(result)) { + result = common_subexpression_elimination(mutated); } + // Re-processing of cached bodies in visit(Call) bubbles their + // remaining un-inlined Call into min_skipped_order_id, so this + // truly means "nothing inlinable is left anywhere in the result." + if (min_skipped_order_id == SIZE_MAX) { + break; + } + // Jump directly to the next un-inlined entry instead of stepping + // by batch_size through regions of order-space the input doesn't + // reference. (No need to check limit against to_inline.size(): + // if every entry were below the limit, none could have been + // skipped, so min_skipped_order_id would be SIZE_MAX already.) + limit = min_skipped_order_id + batch_size; } + return result; +} - Expr visit(const Variable *op) override { - if (op->name == func.name() + ".buffer") { - const Call *call = func.is_wrapper(); - internal_assert(call); - // Do a whole-image inline. Substitute the .buffer symbol - // for the wrapped object's .buffer symbol. - string buf_name; - if (call->call_type == Call::Halide) { - buf_name = call->name; - if (Function(call->func).outputs() > 1) { - buf_name += "." + std::to_string(call->value_index); - } - buf_name += ".buffer"; - return Variable::make(type_of(), buf_name); - } else if (call->param.defined()) { - return Variable::make(type_of(), call->name + ".buffer", call->param); - } else { - internal_assert(call->image.defined()); - return Variable::make(type_of(), call->name + ".buffer", call->image); - } - } else { - return op; +Stmt Inliner::operator()(const Stmt &s) { + if (active_limit != SIZE_MAX || to_inline.size() <= batch_size) { + return mutate(s); + } + // Same deepening loop as the Expr version. No top-of-loop CSE here + // because the CSE that bounds per-pass work happens in two places + // already: the per-Provide CSE inside visit(Provide) flattens each + // Provide after this pass's substitutions, and any recursive + // operator()(Expr) from get_qualified_body sees active_limit set and + // CSEs the qualified body it produces. + Stmt result = s; + size_t limit = batch_size; + while (true) { + active_limit = limit; + min_skipped_order_id = SIZE_MAX; + result = mutate(result); + active_limit = SIZE_MAX; + if (min_skipped_order_id == SIZE_MAX) { + break; } + limit = min_skipped_order_id + batch_size; } + return result; +} + +Expr Inliner::visit(const Call *op) { + auto it = to_inline.find({op->name, op->value_index}); + if (it != to_inline.end()) { + // If this entry's order_id is past the current limit, leave the + // Call alone; a later pass with a higher limit will pick it up. + // Remember the smallest such order_id so the outer loop can jump + // the limit directly to it instead of stepping by batch_size. + if (it->second.order_id >= active_limit) { + min_skipped_order_id = std::min(min_skipped_order_id, it->second.order_id); + return IRMutator::visit(op); + } + // Below the limit: actually substitute. + Entry &entry = it->second; + const Function &func = entry.func; + // Mutate the args + auto args = mutate(op->args); + + // Compute (or re-process) the cached qualified body when the + // current active_limit lets us pull more inlinable Calls into it. + if (!entry.qualified_body.defined()) { + entry.qualified_body = qualify(func.name() + ".", func.values()[op->value_index]); + entry.lowest_pending_order_id = 0; + } + if (active_limit > entry.lowest_pending_order_id) { + size_t saved_min = min_skipped_order_id; + min_skipped_order_id = SIZE_MAX; + entry.qualified_body = (*this)(entry.qualified_body); + entry.lowest_pending_order_id = min_skipped_order_id; + min_skipped_order_id = std::min(saved_min, min_skipped_order_id); + } + // Whether or not we re-processed, the cached body still has Calls + // at order_id == lowest_pending un-inlined. Propagate that up so + // the outer deepening loop knows there's more work. + min_skipped_order_id = std::min(min_skipped_order_id, entry.lowest_pending_order_id); + Expr body = entry.qualified_body; + + const vector func_args = func.args(); - Stmt visit(const Provide *op) override { - ScopedValue old_found(found, 0); - Stmt stmt = IRMutator::visit(op); + // Bind the args using Let nodes + internal_assert(args.size() == func_args.size()); - // TODO: making this > 1 should be desirable, - // but explodes compiletimes in some situations. - if (found > 0) { - stmt = common_subexpression_elimination(stmt); + for (size_t i = 0; i < args.size(); i++) { + body = Let::make(func.name() + "." + func_args[i], args[i], body); } - return stmt; + return body; } + return IRMutator::visit(op); +} -public: - int found = 0; +Expr Inliner::visit(const Variable *op) { + // Whole-image inline for wrappers: if op is ".buffer" for some inlined + // wrapper f, rewrite it to the wrapped buffer's Variable. Extract the + // "" prefix and look up directly rather than scanning to_inline. + const string suffix = ".buffer"; + if (op->name.size() <= suffix.size() || + op->name.compare(op->name.size() - suffix.size(), suffix.size(), suffix) != 0) { + return op; + } + // Wrappers always have a single output, so look up by (name, 0). + auto it = to_inline.find({op->name.substr(0, op->name.size() - suffix.size()), 0}); + if (it == to_inline.end()) { + return op; + } + const Function &func = it->second.func; + const Call *call = func.is_wrapper(); + internal_assert(call); + if (call->call_type == Call::Halide) { + string buf_name = call->name; + if (Function(call->func).outputs() > 1) { + buf_name += "." + std::to_string(call->value_index); + } + buf_name += ".buffer"; + return Variable::make(type_of(), buf_name); + } else if (call->param.defined()) { + return Variable::make(type_of(), call->name + ".buffer", call->param); + } else { + internal_assert(call->image.defined()); + return Variable::make(type_of(), call->name + ".buffer", call->image); + } +} - Inliner(const Function &f) - : func(f) { - internal_assert(f.can_be_inlined()) << "Illegal to inline " << f.name() << "\n"; - validate_schedule_inlined_function(f); +Stmt Inliner::visit(const Provide *op) { + Stmt stmt = IRMutator::visit(op); + // CSE on the Provide if it changed -- IRMutator returns op unchanged + // if no child Expr was rewritten, so this skips CSE on Provides where + // no inlining (or wrapper .buffer rewrite) touched anything inside. + // Running CSE on the whole Stmt rather than each value/index separately + // catches shared subexpressions between them. + if (!stmt.same_as(op)) { + stmt = common_subexpression_elimination(stmt); } -}; + return stmt; +} Stmt inline_function(const Stmt &s, const Function &f) { return Inliner(f)(s); } Expr inline_function(Expr e, const Function &f) { - Inliner i(f); - e = i(e); - // TODO: making this > 1 should be desirable, - // but explodes compiletimes in some situations. - if (i.found > 0) { - e = common_subexpression_elimination(e); + return Inliner(f)(e); +} + +Stmt inline_functions(const Stmt &s, const vector &fs) { + if (fs.empty()) { + return s; + } + Inliner i; + for (const Function &f : fs) { + internal_assert(f.can_be_inlined()) << "Illegal to inline " << f.name() << "\n"; + validate_schedule_inlined_function(f); + i.add(f); } - return e; + return i(s); } // Inline all calls to 'f' inside 'caller' diff --git a/src/Inline.h b/src/Inline.h index 344e7c7ddf6d..c84405167e09 100644 --- a/src/Inline.h +++ b/src/Inline.h @@ -5,12 +5,79 @@ * Methods for replacing calls to functions with their definitions. */ +#include +#include +#include + #include "Expr.h" +#include "Function.h" +#include "IRMutator.h" namespace Halide { namespace Internal { -class Function; +/** A mutator that inlines a set of pure functions wherever they're called + * in an IR. Usage: + * + * Inliner inliner; + * for (Function f : to_inline) inliner.add(f); + * Stmt result = inliner(stmt); + * // or: Expr result = inliner(expr); + * + * For a single function, Inliner(f) gives an equivalent shortcut and + * `inline_function(s/e, f)` packages it up further. + * + * For best performance, add() the functions in consumer-first + * (reverse-topological) order: outermost consumers first, innermost + * producers last. Any other order is also correct, just slower. See the + * implementation comments in Inline.cpp for why. */ +class Inliner : public IRMutator { +public: + Inliner() = default; + + /** Construct an Inliner that will inline a single function. */ + explicit Inliner(const Function &f); + + /** Insert f into the set of functions to be inlined. */ + void add(const Function &f); + + /** Inline all calls to the added functions within e/s, returning the + * result with CSE applied. Shadows the inherited inline operator() + * from IRMutator. */ + Expr operator()(const Expr &e); + Stmt operator()(const Stmt &s); + +protected: + Expr visit(const Call *op) override; + Expr visit(const Variable *op) override; + Stmt visit(const Provide *op) override; + + using IRMutator::visit; + +private: + /** Per-(function, value_index) inlining state. */ + struct Entry { + Function func; + Expr qualified_body; + size_t order_id; + /** Min order_id of any inlinable Call still present inside this + * entry's qualified_body. SIZE_MAX means the body has no pending + * inlines (i.e. it's fully inlined, or hasn't been computed yet). + * When active_limit later exceeds this, we recompute the cached + * body so that the freshly-active functions are pulled in too. */ + size_t lowest_pending_order_id = SIZE_MAX; + }; + using Key = std::pair; + std::map to_inline; + + size_t active_limit = SIZE_MAX; + /** Min order_id of any inlinable Call still un-inlined anywhere in + * the working Expr/Stmt this pass. Updated by visit(Call) for Calls + * above the limit and bubbled up from re-processing cached bodies. + * SIZE_MAX means nothing's left to inline; the deepening loop stops. */ + size_t min_skipped_order_id = SIZE_MAX; + static constexpr size_t batch_size = 8; +}; /** Inline a single named function, which must be pure. For a pure function to * be inlined, it must not have any specializations (i.e. it can only have one @@ -21,6 +88,12 @@ Expr inline_function(Expr e, const Function &f); void inline_function(Function caller, const Function &f); // @} +/** Inline a set of pure functions. Equivalent in effect to calling + * inline_function(s, f) for each f, but the shared Inliner lets a chain + * of nested inlines share work via the qualified-body cache; see Inliner + * above for how the batch is processed. */ +Stmt inline_functions(const Stmt &s, const std::vector &fs); + /** Check if the schedule of an inlined function is legal, throwing an error * if it is not. */ void validate_schedule_inlined_function(Function f); diff --git a/src/ScheduleFunctions.cpp b/src/ScheduleFunctions.cpp index a0b4a4c18876..62ec36399a45 100644 --- a/src/ScheduleFunctions.cpp +++ b/src/ScheduleFunctions.cpp @@ -2576,13 +2576,64 @@ Stmt schedule_functions(const vector &outputs, validate_fused_groups_schedule(fused_groups, env); - for (const auto &group : reverse_view(fused_groups)) { - vector funcs; - vector is_output_list; + // Collect consecutive inlinable groups and apply them in one + // inline_functions pass. We flush the batch before each realization so + // the realization's validate_schedule sees the post-inline 's' (its + // callers, if they were inlined, will have been substituted in by then). + vector pending_inlines; + auto flush_pending_inlines = [&]() { + if (pending_inlines.empty()) { + return; + } + debug(1) << "Inlining group of " << pending_inlines.size() + << " function(s): " << pending_inlines << "\n"; + s = inline_functions(s, pending_inlines); + pending_inlines.clear(); + debug(2) << "Lowering after inlining group of functions:\n" + << s << "\n"; + }; + for (const auto &group : reverse_view(fused_groups)) { + vector group_funcs; + group_funcs.reserve(group.size()); for (const string &name : group) { - Function f = env.find(name)->second; + group_funcs.push_back(env.find(name)->second); + } + + if (group_should_be_inlined(group_funcs)) { + // Inlinable groups have a single pure func. Check the + // schedule-property errors directly here; we can't call + // validate_schedule (which walks 's' for call sites) because + // batched inline chains may not yet have their inner call + // sites exposed in 's'. + const Function &f = group_funcs[0]; + const LoopLevel &store_at = f.schedule().store_level(); + const LoopLevel &hoist_storage_at = f.schedule().hoist_storage_level(); + if (store_at.is_root()) { + user_error << "Func \"" << f.name() << "\" is scheduled store_root(), but is inlined. Funcs that use store_root must also call compute_root or compute_at.\n"; + } else if (!store_at.is_inlined()) { + user_error << "Func \"" << f.name() << "\" is scheduled store_at(), but is inlined. Funcs that use store_at must also call compute_at.\n"; + } + if (hoist_storage_at.is_root()) { + user_error << "Func \"" << f.name() << "\" is scheduled hoist_storage_root(), but is inlined. Funcs that use hoist_storage_root must also call compute_root or compute_at.\n"; + } else if (!hoist_storage_at.is_inlined()) { + user_error << "Func \"" << f.name() << "\" is scheduled hoist_storage(), but is inlined. Funcs that use hoist_storage_root must also call compute_at.\n"; + } + validate_schedule_inlined_function(f); + pending_inlines.push_back(f); + continue; + } + // Realization: flush any pending inlines first so that + // validate_schedule and the InjectFunctionRealization walk see the + // post-inline 's'. In particular, ComputeLegalSchedules inside + // validate_schedule needs the inlined call sites to be visible to + // find this group's funcs. + flush_pending_inlines(); + + vector funcs; + vector is_output_list; + for (const Function &f : group_funcs) { bool is_output = false; for (const Function &o : outputs) { is_output = is_output | o.same_as(f); @@ -2605,18 +2656,13 @@ Stmt schedule_functions(const vector &outputs, continue; } - if (group_should_be_inlined(funcs)) { - debug(1) << "Inlining " << funcs[0].name() << "\n"; - s = inline_function(s, funcs[0]); - } else { - debug(1) << "Injecting realization of " << funcs << "\n"; - InjectFunctionRealization injector(funcs, is_output_list, target, env); - s = injector(s); - internal_assert(injector.found_store_level() && injector.found_compute_level() && injector.found_hoist_storage_level()); - } - + debug(1) << "Injecting realization of " << funcs << "\n"; + InjectFunctionRealization injector(funcs, is_output_list, target, env); + s = injector(s); + internal_assert(injector.found_store_level() && injector.found_compute_level() && injector.found_hoist_storage_level()); debug(2) << s << "\n"; } + flush_pending_inlines(); // We can remove the loop over root now const For *root_loop = s.as(); diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index cae811328223..e87452e60b51 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -83,6 +83,7 @@ tests(GROUPS correctness debug_to_file_multiple_outputs.cpp debug_to_file_reorder.cpp decompose_vector_shuffle.cpp + deep_inline_chain.cpp deferred_loop_level.cpp deinterleave4.cpp device_buffer_copies_with_profile.cpp diff --git a/test/correctness/deep_inline_chain.cpp b/test/correctness/deep_inline_chain.cpp new file mode 100644 index 000000000000..7ac21f9c26fe --- /dev/null +++ b/test/correctness/deep_inline_chain.cpp @@ -0,0 +1,82 @@ +#include "Halide.h" + +using namespace Halide; + +// Stress-test for the inliner on a long chain of inline-scheduled Funcs. +// +// The test builds a sequence of Funcs where each new Func calls its previous +// 10 predecessors (at two different values of an extra coordinate) and +// passes the sum through a fresh per-level lookup-table Func. All Funcs are +// left at their default schedule (compute_inline), so every one of them +// must be inlined into the output during lowering. +// +// Before src/Inline.cpp learned to batch and CSE between batches, this was +// exponentially expensive in the chain length: ScheduleFunctions inlined +// one Func at a time, paying O(N) walks of 's', and BoundsInference inlined +// every Func at once into a giant Let-nested DAG that CSE's RemoveLets then +// re-walked exponentially under each Let. With the batched/iterative- +// deepening Inliner, lowering this pipeline takes well under a second. +// +// Failure modes this test guards against: +// - The test hangs or times out in CI: a regression has reintroduced the +// exponential/quadratic lowering cost. +// - The test crashes during JIT compilation or execution: the inliner +// produced malformed IR. +// - The test prints "Mismatch": the inliner produced incorrect IR that +// still lowers and runs, but computes the wrong value. +// +// The mismatch case is checked by computing one output value out-of-band +// (by walking the Func vector at C++ level) and comparing. +int main(int argc, char **argv) { + Var x, c; + + // Reference computation, run in C++ alongside the pipeline build to + // give us an expected output value to compare against. + std::vector> ref; // ref[i][c] == funcs[i](x=0, c) + + std::vector funcs; + auto add_leaf = [&](Func f, int v0, int v1) { + funcs.push_back(std::move(f)); + ref.push_back({v0, v1}); + }; + add_leaf(lambda(x, c, x + c), 0, 1); + add_leaf(lambda(x, c, x + c + 1), 1, 2); + add_leaf(lambda(x, c, x + c + 2), 2, 3); + add_leaf(lambda(x, c, x + c + 3), 3, 4); + + // Number of layers added on top of the four leaf Funcs. Each layer is + // one Func, with a unique LUT also inlined into it. 100 layers is + // plenty to exercise the batched Inliner (the batch size is 8, so this + // produces many sequential batches). + const int N = 100; + + for (int i = 0; i < N; i++) { + Func next, lut; + lut(x) = x * x + i; + Expr e = 0; + int e_ref = 0; + for (int k = 0; k < 10 && k < (int)funcs.size(); k++) { + Func &f = funcs[funcs.size() - 1 - k]; + e += f(x, 0) * f(x, 1); + const auto &fr = ref[ref.size() - 1 - k]; + e_ref += fr[0] * fr[1]; + } + next(x, c) = lut(e); + funcs.push_back(std::move(next)); + // lut(e_ref) = e_ref * e_ref + i. next doesn't depend on c, so + // both c=0 and c=1 give the same value. + int v = e_ref * e_ref + i; + ref.push_back({v, v}); + } + + Buffer out = funcs.back().realize({1, 2}); + int expected = ref.back()[0]; + if (out(0, 0) != expected || out(0, 1) != expected) { + printf("Mismatch: got (%d, %d), expected (%d, %d)\n", + out(0, 0), out(0, 1), expected, expected); + return 1; + } + + printf("Success!\n"); + return 0; +}