Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 17 additions & 56 deletions src/BoundsInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,52 +189,6 @@ bool is_fused_with_others(const vector<vector<Function>> &fused_groups,
return false;
}

// An inliner that can inline an entire set of functions at once. The inliner in
// Inline.h only handles with one function at a time.
class Inliner : public IRMutator {
public:
std::set<Function, Function::Compare> to_inline;

Expr do_inlining(const Expr &e) {
return common_subexpression_elimination(mutate(e));
}

protected:
std::map<Function, std::map<int, Expr>, Function::Compare> qualified_bodies;

Expr get_qualified_body(const Function &f, int idx) {
auto it = qualified_bodies.find(f);
if (it != qualified_bodies.end()) {
auto it2 = it->second.find(idx);
if (it2 != it->second.end()) {
return it2->second;
}
}
Expr e = qualify(f.name() + ".", f.values()[idx]);
e = do_inlining(e);
qualified_bodies[f][idx] = e;
return e;
}

Expr visit(const Call *op) override {
if (op->func.defined()) {
Function f(op->func);
if (to_inline.count(f)) {
auto args = mutate(op->args);
Expr body = get_qualified_body(f, op->value_index);
const vector<string> &func_args = f.args();
for (size_t i = 0; i < args.size(); i++) {
body = Let::make(f.name() + "." + func_args[i], args[i], body);
}
return body;
}
}
return IRMutator::visit(op);
}

using IRMutator::visit;
};

class BoundsInference : public IRMutator {
public:
const vector<Function> &funcs;
Expand Down Expand Up @@ -686,7 +640,7 @@ class BoundsInference : public IRMutator {
vector<pair<Expr, int>> buffers_to_annotate;
for (const auto &arg : args) {
if (arg.is_expr()) {
bounds_inference_args.push_back(inliner->do_inlining(arg.expr));
bounds_inference_args.push_back((*inliner)(arg.expr));
} else if (arg.is_func()) {
Function input(arg.func);
for (int k = 0; k < input.outputs(); k++) {
Expand Down Expand Up @@ -849,16 +803,23 @@ class BoundsInference : public IRMutator {
// Compute the intrinsic relationships between the stages of
// the functions.

// Figure out which functions will be inlined away
// Figure out which functions will be inlined away.
vector<bool> inlined(f.size());
for (size_t i = 0; i < inlined.size(); i++) {
if (i < f.size() - 1 &&
f[i].schedule().compute_level().is_inlined() &&
f[i].can_be_inlined()) {
inlined[i] = true;
inliner.to_inline.insert(f[i]);
} else {
inlined[i] = false;
inlined[i] = (i < f.size() - 1 &&
f[i].schedule().compute_level().is_inlined() &&
f[i].can_be_inlined());
}
// Register them with the Inliner in consumer-first order. f is in
// realization (producer-first) order, so we iterate backwards: the
// outermost consumer of each chain is added first, the bottom
// producer last. The Inliner's iterative-deepening loop processes
// entries in add() order, so consumers go first -- their materialized
// bodies expose Calls to producers, which the later (deeper) passes
// then substitute. See Inliner's class doc for the full picture.
for (size_t i = inlined.size(); i > 0; i--) {
if (inlined[i - 1]) {
inliner.add(f[i - 1]);
}
}

Expand Down Expand Up @@ -893,7 +854,7 @@ class BoundsInference : public IRMutator {
for (auto &s : stages) {
for (auto &cond_val : s.exprs) {
internal_assert(cond_val.value.defined());
cond_val.value = inliner.do_inlining(cond_val.value);
cond_val.value = inliner(cond_val.value);
}
}

Expand Down
Loading
Loading