/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/runtime/contrib/dnnl/dnnl_json_runtime.cc
* \brief A simple JSON runtime for DNNL.
*/
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/registry.h>
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "../json/json_node.h"
#include "../json/json_runtime.h"
#include "dnnl.hpp"
namespace tvm {
namespace runtime {
namespace contrib {
using namespace tvm::runtime;
using namespace tvm::runtime::json;
class DNNLJSONRuntime : public JSONRuntimeBase {
using tag = dnnl::memory::format_tag;
using dt = dnnl::memory::data_type;
public:
DNNLJSONRuntime(const std::string& symbol_name, const std::string& graph_json,
const Array<String>& const_names)
: JSONRuntimeBase(symbol_name, graph_json, const_names) {}
const char* type_key() const { return "dnnl_json"; }
void Init(const Array<NDArray>& consts) override {
BuildEngine();
ICHECK_EQ(consts.size(), const_idx_.size())
<< "The number of input constants must match the number of required.";
// Setup constants entries for weights.
SetupConstants(consts);
}
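// Run the subgraph: copy inputs into their DNNL memories, execute the primitive stream, then
// copy the results back out.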
void Run() override {
// Fill in the input buffers.
for (size_t i = 0; i < input_nodes_.size(); ++i) {
auto eid = EntryID(input_nodes_[i], 0);
// TODO(@comaniac): Support other data lengths.
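// Offsets are recorded in float32 elements (BindDNNLMemory asserts 32-bit types), so
// convert to bytes here; the output loop below does the same.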
size_t offset_in_bytes = entry_out_mem_[eid].second * 4;
size_t buffer_size = GetDataSize(*data_entry_[eid]);
write_to_dnnl_memory(data_entry_[eid]->data, entry_out_mem_[eid].first, buffer_size,
offset_in_bytes);
}
// Invoke the engine by executing the primitive stream.
for (size_t i = 0; i < net_.size(); ++i) {
net_.at(i).execute(stream_, net_args_.at(i));
}
stream_.wait();
// Read output buffers.
for (size_t i = 0; i < outputs_.size(); ++i) {
auto eid = EntryID(outputs_[i]);
size_t offset_in_bytes = entry_out_mem_[eid].second * 4;
size_t buffer_size = GetDataSize(*data_entry_[eid]);
read_from_dnnl_memory(data_entry_[eid]->data, entry_out_mem_[eid].first, buffer_size,
offset_in_bytes);
}
}
private:
// Build up the engine based on the input graph.
void BuildEngine() {
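// All primitives in the subgraph share a single CPU engine (device 0) and one stream.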
engine_ = dnnl::engine(dnnl::engine::kind::cpu, 0);
stream_ = dnnl::stream(engine_);
// Build a DNNL primitive for each operator node in the subgraph.
for (size_t nid = 0; nid < nodes_.size(); ++nid) {
const auto& node = nodes_[nid];
if (node.GetOpType() == "kernel") {
auto op_name = node.GetOpName();
if ("nn.conv2d" == op_name) {
Conv2d(nid);
} else if ("dnnl.conv2d_relu" == op_name) {
Conv2d(nid, true, false, dnnl::algorithm::eltwise_relu);
} else if ("dnnl.conv2d_tanh" == op_name) {
Conv2d(nid, true, false, dnnl::algorithm::eltwise_tanh);
} else if ("dnnl.conv2d_sigmoid" == op_name) {
Conv2d(nid, true, false, dnnl::algorithm::eltwise_logistic);
} else if ("dnnl.conv2d_bias" == op_name) {
Conv2d(nid, false, true);
} else if ("dnnl.conv2d_bias_relu" == op_name) {
Conv2d(nid, true, true, dnnl::algorithm::eltwise_relu);
} else if ("dnnl.conv2d_bias_tanh" == op_name) {
Conv2d(nid, true, true, dnnl::algorithm::eltwise_tanh);
} else if ("dnnl.conv2d_bias_sigmoid" == op_name) {
Conv2d(nid, true, true, dnnl::algorithm::eltwise_logistic);
} else if ("nn.dense" == op_name) {
Dense(nid);
} else if ("dnnl.dense_bias" == op_name) {
Dense(nid, true);
} else if ("nn.batch_norm" == op_name) {
BatchNorm(nid);
} else if ("nn.relu" == op_name) {
Eltwise(nid, dnnl::algorithm::eltwise_relu);
} else if ("tanh" == op_name) {
Eltwise(nid, dnnl::algorithm::eltwise_tanh);
} else if ("sigmoid" == op_name) {
Eltwise(nid, dnnl::algorithm::eltwise_logistic);
} else if ("add" == op_name) {
Binary(nid, dnnl::algorithm::binary_add);
} else if ("multiply" == op_name) {
Binary(nid, dnnl::algorithm::binary_mul);
} else {
LOG(FATAL) << "Unsupported op: " << op_name;
}
}
}
}
// Bind a JSON graph node entry to a DNNL memory, creating the memory from the descriptor if
// the entry has not been bound yet.
dnnl::memory BindDNNLMemory(const JSONGraphNodeEntry& entry, dnnl::memory::desc mem_desc,
size_t offset = 0) {
auto eid = EntryID(entry);
if (entry_out_mem_.count(eid) == 0) {
return BindDNNLMemory(entry, dnnl::memory(mem_desc, engine_), offset);
}
return entry_out_mem_[eid].first;
}
// Bind a JSON graph node entry to a given DNNL memory.
dnnl::memory BindDNNLMemory(const JSONGraphNodeEntry& entry, dnnl::memory mem,
size_t offset = 0) {
auto eid = EntryID(entry);
// Since the DNNL memory has been created before calling this function, we assume the entry
// has not yet been bound to another DNNL memory; otherwise it may leak memory.
ICHECK_EQ(entry_out_mem_.count(eid), 0);
// TODO(@comaniac): Support other data types (e.g., int8).
auto data_node = nodes_[entry.id_];
auto dltype = data_node.GetOpDataType()[entry.index_];
ICHECK_EQ(dltype.bits, 32);
entry_out_mem_[eid] = {mem, offset};
return entry_out_mem_[eid].first;
}
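// Create a conv2d primitive from the JSON node, optionally fusing an elementwise post-op
// (has_elt, using algo) and a bias input (has_bias).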
void Conv2d(const size_t& nid, const bool has_elt = false, const bool has_bias = false,
dnnl::algorithm algo = dnnl::algorithm::eltwise_relu) {
auto node = nodes_[nid];
// Set up attributes.
auto data_entry = node.GetInputs()[0];
auto weight_entry = node.GetInputs()[1];
dnnl::memory::dims input_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
dnnl::memory::dims weight_shape = nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_];
std::vector<std::string> str_strides = node.GetAttr<std::vector<std::string>>("strides");
std::vector<std::string> str_dilates = node.GetAttr<std::vector<std::string>>("dilation");
std::vector<std::string> str_padding = node.GetAttr<std::vector<std::string>>("padding");
dnnl::memory::dim groups = std::stoi(node.GetAttr<std::vector<std::string>>("groups")[0]);
dnnl::memory::dim N = input_shape[0], // batch size
IC = input_shape[1], // input channels
IH = input_shape[2], // input height
IW = input_shape[3], // input width
OC = weight_shape[0], // output channels
KH = weight_shape[2], // weight height
KW = weight_shape[3], // weight width
PW_L = std::stoi(str_padding[1]), // width padding: left
PW_R = std::stoi(str_padding[3]), // width padding: right
PH_L = std::stoi(str_padding[0]), // height padding: top
PH_R = std::stoi(str_padding[2]), // height padding: bottom
SH = std::stoi(str_strides[0]), // height-wise stride
SW = std::stoi(str_strides[1]), // width-wise stride
DH = std::stoi(str_dilates[0]) - 1, // height-wise dilation
DW = std::stoi(str_dilates[1]) - 1, // width-wise dilation
DKH = 1 + (KH - 1) * (DH + 1), // dilated weight height
DKW = 1 + (KW - 1) * (DW + 1), // dilated weight width
OH = (IH - DKH + PH_L + PH_R) / SH + 1, // output height
OW = (IW - DKW + PW_L + PW_R) / SW + 1; // output width
// Memory shapes.
dnnl::memory::dims src_dims = {N, IC, IH, IW};
dnnl::memory::dims weights_dims = {OC, IC, KH, KW};
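// Grouped convolution takes GOIHW weights, i.e. {G, OC/G, IC/G, KH, KW}; the dims below
// appear to assume OC == groups (depthwise-style grouping), so OC/G == 1.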
if (groups > 1) {
weights_dims = {groups, 1, IC / groups, KH, KW};
}
dnnl::memory::dims bias_dims = {OC};
dnnl::memory::dims dst_dims = {N, OC, OH, OW};
dnnl::memory::dims strides_dims = {SH, SW};
dnnl::memory::dims dilates_dims = {DH, DW};
dnnl::memory::dims padding_dims_l = {PH_L, PW_L};
dnnl::memory::dims padding_dims_r = {PH_R, PW_R};
// Memory descriptors.
auto conv_src_md = dnnl::memory::desc(src_dims, dt::f32, tag::any);
auto conv_weights_md = dnnl::memory::desc(weights_dims, dt::f32, tag::any);
auto conv_bias_md = dnnl::memory::desc(bias_dims, dt::f32, tag::any);
auto conv_dst_md = dnnl::memory::desc(dst_dims, dt::f32, tag::nchw);
// Conv2d description.
auto conv_desc = dnnl::convolution_forward::desc(
dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_direct, conv_src_md,
conv_weights_md, conv_bias_md, conv_dst_md, strides_dims, dilates_dims, padding_dims_l,
padding_dims_r);
// Optionally fuse an elementwise post-op (e.g. ReLU) into the convolution.
dnnl::primitive_attr attr;
if (has_elt) {
dnnl::post_ops ops;
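// append_eltwise(scale, algorithm, alpha, beta) fuses the activation into the conv output.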
ops.append_eltwise(1.f, algo, 0.f, 0.f);
attr.set_post_ops(ops);
}
auto conv2d_prim_desc = dnnl::convolution_forward::primitive_desc(conv_desc, attr, engine_);
// Push to the network.
auto conv = dnnl::convolution_forward(conv2d_prim_desc);
net_.push_back(conv);
// Data memory.
ICHECK_EQ(node.GetAttr<std::vector<std::string>>("data_layout")[0], "NCHW");
auto conv2d_src_memory = BindDNNLMemory(data_entry, {src_dims, dt::f32, tag::nchw});
// Weight memory.
ICHECK_EQ(node.GetAttr<std::vector<std::string>>("kernel_layout")[0], "OIHW");
auto conv2d_weights_memory = BindDNNLMemory(
weight_entry, {weights_dims, dt::f32, (groups > 1) ? tag::goihw : tag::oihw});
// Bias memory.
auto conv2d_bias_memory = dnnl::memory({bias_dims, dt::f32, tag::x}, engine_);
if (has_bias) {
auto bias_entry = node.GetInputs()[2];
BindDNNLMemory(bias_entry, conv2d_bias_memory);
} else {
// Variable-length arrays are non-standard C++; use a heap-allocated zero bias instead.
std::vector<float> bias(OC, 0.f);
write_to_dnnl_memory(bias.data(), conv2d_bias_memory, OC * sizeof(float));
}
// Output memory.
JSONGraphNodeEntry out_entry(nid, 0);
auto conv2d_dst_memory = BindDNNLMemory(out_entry, conv2d_prim_desc.dst_desc());
// Bind memory buffers.
net_args_.push_back({{DNNL_ARG_SRC, conv2d_src_memory},
{DNNL_ARG_WEIGHTS, conv2d_weights_memory},
{DNNL_ARG_BIAS, conv2d_bias_memory},
{DNNL_ARG_DST, conv2d_dst_memory}});
}
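// Create a dense (inner product) primitive from the JSON node, optionally with a bias input.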
void Dense(const size_t& nid, const bool has_bias = false) {
auto node = nodes_[nid];
// Set up attributes.
auto data_entry = node.GetInputs()[0];
auto weight_entry = node.GetInputs()[1];
dnnl::memory::dims input_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
dnnl::memory::dims weight_shape = nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_];
dnnl::memory::dim B = input_shape[0], // batch size
IC = input_shape[1], // input channels
OC = weight_shape[0]; // output channels
// Memory shapes.
dnnl::memory::dims data_dims = {B, IC};
dnnl::memory::dims weight_dims = {OC, IC};
dnnl::memory::dims bias_dims = {OC};
dnnl::memory::dims out_dims = {B, OC};
// Memory descriptors.
auto data_md = dnnl::memory::desc({data_dims, dt::f32, tag::nc});
auto weight_md = dnnl::memory::desc({weight_dims, dt::f32, tag::nc});
auto bias_md = dnnl::memory::desc({bias_dims, dt::f32, tag::x});
auto dst_md = dnnl::memory::desc({out_dims, dt::f32, tag::nc});
// Dense description.
auto dense_desc = dnnl::inner_product_forward::desc(dnnl::prop_kind::forward_inference, data_md,
weight_md, bias_md, dst_md);
auto dense_prim_desc = dnnl::inner_product_forward::primitive_desc(dense_desc, engine_);
auto dense = dnnl::inner_product_forward(dense_prim_desc);
net_.push_back(dense);
// Memories.
auto data_memory = BindDNNLMemory(data_entry, data_md);
auto weight_memory = BindDNNLMemory(weight_entry, weight_md);
// Bias memory.
auto bias_memory = dnnl::memory(bias_md, engine_);
if (has_bias) {
auto bias_entry = node.GetInputs()[2];
BindDNNLMemory(bias_entry, bias_memory);
} else {
std::vector<float> bias(OC, 0.f);
write_to_dnnl_memory(bias.data(), bias_memory, OC * sizeof(float));
}
// Output memory.
JSONGraphNodeEntry out_entry(nid, 0);
auto dst_memory = BindDNNLMemory(out_entry, dense_prim_desc.dst_desc());
net_args_.push_back({{DNNL_ARG_SRC, data_memory},
{DNNL_ARG_WEIGHTS, weight_memory},
{DNNL_ARG_BIAS, bias_memory},
{DNNL_ARG_DST, dst_memory}});
}
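// Create an inference batch-norm primitive that uses the stored mean/variance (global
// stats) together with a packed scale-shift (gamma/beta) weight.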
void BatchNorm(const size_t& nid) {
auto node = nodes_[nid];
auto data_entry = node.GetInputs()[0];
auto gamma_entry = node.GetInputs()[1];
auto beta_entry = node.GetInputs()[2];
auto mean_entry = node.GetInputs()[3];
auto variance_entry = node.GetInputs()[4];
dnnl::memory::dims data_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
dnnl::memory::dim IC = data_shape[1];
float epsilon = std::stof(node.GetAttr<std::vector<std::string>>("epsilon")[0]);
// Memory descriptor.
dnnl::memory::desc data_md = GenDNNLMemDescByShape(data_shape, dt::f32);
// BN description.
auto bn_desc = dnnl::batch_normalization_forward::desc(
dnnl::prop_kind::forward_inference, data_md, epsilon,
dnnl::normalization_flags::use_global_stats | dnnl::normalization_flags::use_scale_shift);
auto bn_prim_desc = dnnl::batch_normalization_forward::primitive_desc(bn_desc, engine_);
auto bn = dnnl::batch_normalization_forward(bn_prim_desc);
net_.push_back(bn);
// Memories.
auto data_memory = BindDNNLMemory(data_entry, data_md);
JSONGraphNodeEntry out_entry(nid, 0);
auto out_memory = BindDNNLMemory(out_entry, data_md);
auto mean_memory = BindDNNLMemory(mean_entry, bn_prim_desc.mean_desc());
auto variance_memory = BindDNNLMemory(variance_entry, bn_prim_desc.variance_desc());
// In DNNL, the scale-shift weight packs gamma followed by beta, so we bind both entries to
// the same DNNL memory, placing beta at an element offset of IC (converted to bytes in Run()).
auto weight_memory = BindDNNLMemory(gamma_entry, bn_prim_desc.weights_desc(), 0);
BindDNNLMemory(beta_entry, weight_memory, IC);
net_args_.push_back({{DNNL_ARG_SRC, data_memory},
{DNNL_ARG_DST, out_memory},
{DNNL_ARG_SCALE_SHIFT, weight_memory},
{DNNL_ARG_MEAN, mean_memory},
{DNNL_ARG_VARIANCE, variance_memory}});
}
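// Create a unary elementwise primitive (ReLU, tanh, or sigmoid) over a plain-layout tensor.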
void Eltwise(const size_t& nid, dnnl::algorithm algo) {
auto node = nodes_[nid];
auto data_entry = node.GetInputs()[0];
dnnl::memory::dims shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dt::f32);
auto elt_desc =
dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo, data_md, 0);
auto elt_prim_desc = dnnl::eltwise_forward::primitive_desc(elt_desc, engine_);
ICHECK(data_md == elt_prim_desc.dst_desc());
auto elt = dnnl::eltwise_forward(elt_prim_desc);
net_.push_back(elt);
auto data_memory = BindDNNLMemory(data_entry, data_md);
JSONGraphNodeEntry out_entry(nid, 0);
auto out_memory = BindDNNLMemory(out_entry, data_md);
net_args_.push_back({{DNNL_ARG_SRC, data_memory}, {DNNL_ARG_DST, out_memory}});
}
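// Create a binary primitive (add or multiply) over two inputs of identical shape;
// broadcasting is not supported.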
void Binary(const size_t& nid, dnnl::algorithm algo) {
auto node = nodes_[nid];
// Memory and compute description.
std::vector<dnnl::memory::dims> data_dims;
std::vector<dnnl::memory::desc> data_mds;
std::vector<dnnl::memory> data_memories;
ICHECK_EQ(node.GetInputs().size(), 2U);
for (auto entry : node.GetInputs()) {
auto data_shape = nodes_[entry.id_].GetOpShape()[entry.index_];
dnnl::memory::desc data_md = GenDNNLMemDescByShape(data_shape, dt::f32);
data_dims.push_back(data_shape);
data_mds.push_back(data_md);
data_memories.push_back(BindDNNLMemory(entry, data_md));
}
ICHECK(data_dims[0] == data_dims[1]);
auto out_md = data_mds[0];
JSONGraphNodeEntry out_entry(nid, 0);
auto out_memory = BindDNNLMemory(out_entry, out_md);
auto binary_desc = dnnl::binary::desc(algo, data_mds[0], data_mds[1], out_md);
auto binary_prim_desc = dnnl::binary::primitive_desc(binary_desc, engine_);
auto binary = dnnl::binary(binary_prim_desc);
net_.push_back(binary);
net_args_.push_back({{DNNL_ARG_SRC_0, data_memories[0]},
{DNNL_ARG_SRC_1, data_memories[1]},
{DNNL_ARG_DST, out_memory}});
}
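// The two helpers below assume a CPU engine, where get_data_handle() returns a plain host
// pointer; a GPU engine would require mapping the memory instead.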
// Read from DNNL memory (+offset) and write to the handle.
inline void read_from_dnnl_memory(void* handle, const dnnl::memory& mem, size_t size,
size_t offset = 0) {
uint8_t* src = static_cast<uint8_t*>(mem.get_data_handle());
std::copy(src + offset, src + offset + size, static_cast<uint8_t*>(handle));
}
// Read from the handle and write to DNNL memory (+offset).
inline void write_to_dnnl_memory(void* handle, const dnnl::memory& mem, size_t size,
size_t offset = 0) {
uint8_t* dst = static_cast<uint8_t*>(mem.get_data_handle());
std::copy(reinterpret_cast<uint8_t*>(handle), reinterpret_cast<uint8_t*>(handle) + size,
dst + offset);
}
// Generate a DNNL memory descriptor, choosing a plain row-major format tag from the shape's rank.
inline dnnl::memory::desc GenDNNLMemDescByShape(const dnnl::memory::dims& shape, dt dtype) {
dnnl::memory::desc data_md;
switch (shape.size()) {
case 2:
data_md = dnnl::memory::desc({shape, dtype, tag::ab});
break;
case 3:
data_md = dnnl::memory::desc({shape, dtype, tag::abc});
break;
case 4:
data_md = dnnl::memory::desc({shape, dtype, tag::abcd});
break;
case 5:
data_md = dnnl::memory::desc({shape, dtype, tag::abcde});
break;
default:
LOG(FATAL) << "Unsupported data shape dimension: " << shape.size();
break;
}
return data_md;
}
/* The dnnl engine. */
dnnl::engine engine_;
/* The dnnl stream. */
dnnl::stream stream_;
/* The network layers, i.e. the subgraph operators, as DNNL primitives. */
std::vector<dnnl::primitive> net_;
/* The argument memories consumed by each primitive in net_. */
std::vector<std::unordered_map<int, dnnl::memory>> net_args_;
/* Map from entry ID to its output memory and element offset. */
std::unordered_map<uint32_t, std::pair<dnnl::memory, size_t>> entry_out_mem_;
};
runtime::Module DNNLJSONRuntimeCreate(String symbol_name, String graph_json,
const Array<String>& const_names) {
auto n = make_object<DNNLJSONRuntime>(symbol_name, graph_json, const_names);
return runtime::Module(n);
}
TVM_REGISTER_GLOBAL("runtime.DNNLJSONRuntimeCreate").set_body_typed(DNNLJSONRuntimeCreate);
TVM_REGISTER_GLOBAL("runtime.module.loadbinary_dnnl_json")
.set_body_typed(JSONRuntimeBase::LoadFromBinary<DNNLJSONRuntime>);
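// A minimal usage sketch (illustrative only; graph_json and const_names are assumed to come
// from the DNNL BYOC partitioner):
//   const auto* create = tvm::runtime::Registry::Get("runtime.DNNLJSONRuntimeCreate");
//   runtime::Module mod = (*create)(symbol_name, graph_json, const_names);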
} // namespace contrib
} // namespace runtime
} // namespace tvm