blob: 0cc4b9cf3c214f4e2c0e25ecce82a6ef32cfb925 [file] [log] [blame]
/*!
* Copyright (c) 2017 by Contributors
* \file verilog_ir.cc
*/
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <utility>
#include "verilog_ir.h"
#include "../../arithmetic/compute_expr.h"
namespace tvm {
namespace codegen {
namespace verilog {
using namespace ir;
/*!
 * \brief Construct a control signal.
 * \param type The kind of the control signal, stored in ctrl_type.
 * \param advance_size The advance size stored on the signal.
 * \return A ControlSignal reference wrapping the newly created node.
 */
ControlSignal ControlSignalNode::make(ControlSignalType type,
                                      int advance_size) {
  NodePtr<ControlSignalNode> node = make_node<ControlSignalNode>();
  node->advance_size = advance_size;
  node->ctrl_type = type;
  return ControlSignal(node);
}
/*!
 * \brief Construct a stage input descriptor.
 * \param var The variable the input refers to.
 * \param input_type The kind of the input, stored in input_type.
 * \return A StageInput reference wrapping the newly created node.
 */
StageInput StageInputNode::make(Var var, StageInputType input_type) {
  auto node = make_node<StageInputNode>();
  node->input_type = input_type;
  node->var = var;
  return StageInput(node);
}
// Mutator that substitutes every known stage input with a fresh placeholder
// variable and records, in inputs_, which StageInput each placeholder
// stands for.
class StageInputReplacer : public IRMutator {
 public:
  /*!
   * \param var_info Map from variable to the stage input it denotes.
   *  Held by reference, so it must outlive this mutator.
   */
  explicit StageInputReplacer(
      const std::unordered_map<const Variable*, StageInput>& var_info)
      : var_info_(var_info) {}

  // Variables listed in var_info_ become "<name>.sync" placeholders;
  // any other variable is returned unchanged.
  Expr Mutate_(const Variable* op, const Expr& e) final {
    // Reuse the placeholder if this variable was already replaced.
    auto cached = replace_.find(op);
    if (cached != replace_.end()) {
      return cached->second;
    }
    auto info = var_info_.find(op);
    if (info == var_info_.end()) {
      return e;
    }
    const StageInput& input = info->second;
    Var placeholder(input->var->name_hint + ".sync", op->type);
    inputs_.Set(placeholder, input);
    replace_[op] = placeholder;
    return placeholder;
  }

  // A load is replaced as a whole by a "<name>.load.sync" placeholder.
  // The load must have index zero and its buffer must be in var_info_.
  Expr Mutate_(const Load* op, const Expr& e) final {
    CHECK(is_zero(op->index))
        << "Load should be in its own stage.";
    const Variable* buffer = op->buffer_var.get();
    // Reuse the placeholder if this buffer was already replaced.
    auto cached = replace_.find(buffer);
    if (cached != replace_.end()) {
      return cached->second;
    }
    auto info = var_info_.find(buffer);
    CHECK(info != var_info_.end())
        << "Load from unknown channel";
    const StageInput& input = info->second;
    Var placeholder(input->var->name_hint + ".load.sync", op->type);
    inputs_.Set(placeholder, input);
    replace_[buffer] = placeholder;
    return placeholder;
  }

  // Placeholders created so far, mapped to the input they stand for.
  Map<Var, StageInput> inputs_;
  // Original variable -> placeholder that replaced it.
  std::unordered_map<const Variable*, Var> replace_;
  // Which variables are stage inputs (the replacement plan).
  const std::unordered_map<const Variable*, StageInput>& var_info_;
};
/*!
 * \brief Extract module block
 *
 *  Visits the body of a lowered function and builds a PipelineNode from it:
 *  every Store becomes one ComputeBlock stage, and every channel seen
 *  (either through channel read/write scope attributes or through access
 *  to a handle-typed function argument) becomes a ChannelBlock.
 */
class PipelineExtractor: public IRVisitor {
 public:
  /*!
   * \brief Run the extraction on a lowered function.
   * \param f The function whose body is visited.
   * \return The pipeline holding the collected stages, channels and
   *  the arguments of f.
   */
  Pipeline Extract(LoweredFunc f) {
    // Initialize the memory map channels
    // TODO(tqchen) move the logic to explicit specification.
    for (auto arg : f->args) {
      if (arg.type().is_handle()) {
        arg_handle_[arg.get()] = arg;
      }
    }
    pipeline_ = make_node<PipelineNode>();
    this->Visit(f->body);
    // setup channels, keyed by the handle variable of each channel.
    for (const auto& kv : cmap_) {
      pipeline_->channels.Set(
          kv.second.node->channel->handle_var,
          ChannelBlock(kv.second.node));
    }
    pipeline_->args = f->args;
    return Pipeline(pipeline_);
  }

  void Visit_(const AttrStmt* op) final {
    if (op->attr_key == attr::pipeline_stage_scope) {
      // Pipeline stages must not nest.
      CHECK(!in_pipeline_stage_);
      in_pipeline_stage_ = true;
      // Remember the loop depth at which this attribute was entered; the
      // trigger predicate only covers loops opened after this point.
      trigger_.emplace_back(loop_.size(), op);
      IRVisitor::Visit_(op);
      trigger_.pop_back();
      in_pipeline_stage_ = false;
    } else if (op->attr_key == attr::channel_read_advance ||
               op->attr_key == attr::channel_write_advance) {
      trigger_.emplace_back(loop_.size(), op);
      IRVisitor::Visit_(op);
      trigger_.pop_back();
    } else if (op->attr_key == attr::channel_read_scope ||
               op->attr_key == attr::channel_write_scope) {
      Channel ch(op->node.node_);
      // Creates the entry on first sight of this channel handle.
      ChannelEntry& cb = cmap_[ch->handle_var.get()];
      if (cb.node != nullptr) {
        CHECK(cb.node->channel.same_as(ch));
      } else {
        cb.node = make_node<ChannelBlockNode>();
        cb.node->channel = ch;
      }
      if (op->attr_key == attr::channel_read_scope) {
        CHECK_EQ(cb.read_ref_count, 0)
            << "One channel can only be read from one consumer";
        ++cb.read_ref_count;
        CHECK(arith::GetConstInt(op->value, &(cb.node->read_window)))
            << "Only supprt constant read window";
      } else {
        CHECK_EQ(cb.write_ref_count, 0)
            << "One channel can only be write by one producer";
        ++cb.write_ref_count;
        CHECK(arith::GetConstInt(op->value, &(cb.node->write_window)))
            << "Only supprt constant write window";
      }
      // The channel handle is a stage input only while inside this scope.
      var_info_[ch->handle_var.get()] =
          StageInputNode::make(ch->handle_var, kChannel);
      IRVisitor::Visit_(op);
      var_info_.erase(ch->handle_var.get());
    } else {
      IRVisitor::Visit_(op);
    }
  }

  void Visit_(const Block* op) final {
    CHECK(!in_pipeline_stage_)
        << "Do not support serial execution inside pipeline";
    IRVisitor::Visit_(op);
  }

  void Visit_(const IfThenElse* op) final {
    LOG(FATAL) << "Not implemeneted";
  }

  void Visit_(const For* op) final {
    if (in_pipeline_stage_) {
      // Record the loop header only; the body is a placeholder statement.
      loop_.push_back(
          For::make(op->loop_var, op->min, op->extent,
                    op->for_type, op->device_api, Evaluate::make(0)));
      // The loop variable is a stage input while visiting the loop body.
      var_info_[op->loop_var.get()] =
          StageInputNode::make(Var(op->loop_var.node_), kLoopVar);
      IRVisitor::Visit_(op);
      var_info_.erase(op->loop_var.get());
      loop_.pop_back();
    } else {
      IRVisitor::Visit_(op);
    }
  }

  // Every store produces one compute stage in the pipeline.
  void Visit_(const Store* op) final {
    // Check the access pattern; arg_write is defined only when the store
    // targets a handle-typed function argument.
    Channel arg_write =
        CheckArgHandleAccess(op->buffer_var.get(), op->value.type(), false);
    this->Visit(op->value);
    // The replace logic
    StageInputReplacer repl(var_info_);
    // Setup the compute block.
    NodePtr<ComputeBlockNode> compute =
        make_node<ComputeBlockNode>();
    compute->loop = Array<Stmt>(loop_);
    // setup the advance triggers
    for (const auto& e : trigger_) {
      const AttrStmt* attr = e.second;
      Channel ch;
      if (attr->attr_key == attr::pipeline_stage_scope) {
        // The stage scope only triggers a signal when the store writes
        // to an argument handle.
        ch = arg_write;
        if (!ch.defined()) continue;
      } else {
        ch = Channel(attr->node.node_);
      }
      NodePtr<SignalTriggerNode> trigger
          = make_node<SignalTriggerNode>();
      trigger->channel_var = ch->handle_var;
      // predicate for the trigger: conjunction over all loops opened
      // after the attribute was entered.
      // NOTE(review): the comparison is against (extent - 1), which
      // presumably assumes every loop starts at 0 -- confirm that loops
      // are normalized before this pass.
      Expr predicate = const_true();
      for (size_t i = e.first; i < loop_.size(); ++i) {
        const For* loop = loop_[i].as<For>();
        predicate = predicate &&
            (loop->loop_var == (loop->extent - 1));
      }
      trigger->predicate = ir::Simplify(predicate);
      // Add the signal back to the channels.
      // Look the channel up explicitly so that a missing entry fails with
      // a readable message instead of a bare std::out_of_range.
      auto entry_it = cmap_.find(ch->handle_var.get());
      CHECK(entry_it != cmap_.end())
          << "Advance trigger refers to a channel that has no registered scope";
      ChannelEntry& cb = entry_it->second;
      trigger->signal_index = static_cast<int>(cb.node->ctrl_signals.size());
      // Grab the advance constant size.
      int trigger_size = 0;
      if (attr->attr_key == attr::pipeline_stage_scope) {
        cb.node->ctrl_signals.push_back(
            ControlSignalNode::make(kComputeFinish, 0));
      } else if (attr->attr_key == attr::channel_read_advance) {
        CHECK(arith::GetConstInt(attr->value, &trigger_size))
            << "Only support constant advance size";
        cb.node->ctrl_signals.push_back(
            ControlSignalNode::make(kReadAdvance, trigger_size));
      } else {
        CHECK(arith::GetConstInt(attr->value, &trigger_size))
            << "Only support constant advance size";
        cb.node->ctrl_signals.push_back(
            ControlSignalNode::make(kWriteAdvance, trigger_size));
      }
      compute->triggers.push_back(SignalTrigger(trigger));
    }
    // Check if we are writing to FIFO.
    const Load* load = op->value.as<Load>();
    if (is_zero(op->index) && load) {
      // Keep the load itself and only rewrite its index. The rebuilt load
      // keeps its own predicate rather than inheriting the store's.
      compute->body = Store::make(
          op->buffer_var,
          Load::make(load->type, load->buffer_var,
                     repl.Mutate(load->index), load->predicate),
          op->index, op->predicate);
    } else {
      compute->body = Store::make(
          op->buffer_var, repl.Mutate(op->value),
          repl.Mutate(op->index), op->predicate);
    }
    compute->inputs = repl.inputs_;
    pipeline_->stages.push_back(ComputeBlock(compute));
  }

  void Visit_(const LetStmt* op) final {
    LOG(FATAL) << "cannot pass through let";
  }

  void Visit_(const Evaluate* op) final {
    LOG(FATAL) << "Not implemeneted";
  }

  // Only checks that we are outside a pipeline stage; the children of the
  // allocation are not visited.
  void Visit_(const Allocate* op) final {
    CHECK(!in_pipeline_stage_);
  }

  void Visit_(const AssertStmt* op) final {
    LOG(FATAL) << "Not implemeneted";
  }

  // Registers argument-handle access; the load index is not visited.
  void Visit_(const Load* op) final {
    CheckArgHandleAccess(op->buffer_var.get(), op->type, true);
  }

  /*!
   * \brief Register a channel for an access to a function argument handle.
   * \param var The buffer variable being accessed.
   * \param dtype The data type used to create the channel.
   * \param read_access Whether the access is a read; currently unused.
   * \return The newly created channel, or an undefined Channel when var
   *  is not a handle-typed function argument.
   * \note Fails if the same handle has already been registered.
   */
  Channel CheckArgHandleAccess(const Variable* var, Type dtype, bool read_access) {
    if (!arg_handle_.count(var)) return Channel();
    CHECK(!cmap_.count(var))
        << "Multiple access to the same handle";
    ChannelEntry& cb = cmap_[var];
    cb.node = make_node<ChannelBlockNode>();
    cb.node->channel = ChannelNode::make(arg_handle_.at(var), dtype);
    return cb.node->channel;
  }

 private:
  // The channel information.
  struct ChannelEntry {
    NodePtr<ChannelBlockNode> node;
    // Number of read scopes seen for the channel (checked to be at most 1).
    int read_ref_count{0};
    // Number of write scopes seen for the channel (checked to be at most 1).
    int write_ref_count{0};
  };
  // Whether we are inside the pipeline stage.
  bool in_pipeline_stage_{false};
  // The current loop nest
  std::vector<Stmt> loop_;
  // Advance signal trigger: (loop depth when entered, attribute).
  std::vector<std::pair<size_t, const AttrStmt*> > trigger_;
  // Variables currently acting as stage inputs (loop vars and channels).
  std::unordered_map<const Variable*, StageInput> var_info_;
  // The channel entry;
  std::unordered_map<const Variable*, ChannelEntry> cmap_;
  // The argument handle map
  std::unordered_map<const Variable*, Var> arg_handle_;
  // The result block.
  NodePtr<PipelineNode> pipeline_;
};
/*!
 * \brief Build the pipeline representation of a lowered function.
 * \param f The lowered function to convert.
 * \return The pipeline produced by PipelineExtractor::Extract.
 */
Pipeline MakePipeline(LoweredFunc f) {
  PipelineExtractor extractor;
  return extractor.Extract(f);
}
} // namespace verilog
} // namespace codegen
} // namespace tvm