blob: b5fc8d15f7acb5a207c8355a49385758ddb3538b [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file tensorrt_pass.cc
* \brief Replace TRT compatible subgraphs by TRT engines
* \author Clement Fuji Tsang
*/
#if MXNET_USE_TENSORRT
#include <NvInfer.h>
#include <mxnet/base.h>
#include <mxnet/op_attr_types.h>
#include <mxnet/operator.h>
#include <nnvm/graph_attr_types.h>
#include <onnx/onnx.pb.h>
#include "../operator/contrib/nnvm_to_onnx-inl.h"
#include "./exec_pass.h"
#include "./onnx_to_tensorrt.h"
namespace mxnet {
namespace exec {
// Shorthand for nnvm::NodePtr within this namespace.
using NodePtr = nnvm::NodePtr;
/*!
* \brief Custom graph class, which will contain bi-directional nodes
* we need to compute DFS and reverse DFS for graph partitioning
*/
class BidirectionalGraph {
public:
struct Node {
nnvm::Node* nnvmptr;
std::vector<Node*> inputs;
std::vector<Node*> outputs;
};
std::vector<Node> nodes;
std::unordered_map<nnvm::Node*, uint32_t> nnvm2nid;
std::vector<Node*> outputs;
static const std::unordered_set<std::string> unconditionalTRTop;
explicit BidirectionalGraph(const Graph &g) {
auto& idx = g.indexed_graph();
auto num_nodes = idx.num_nodes();
nodes.reserve(num_nodes);
nnvm2nid.reserve(num_nodes);
outputs.reserve(idx.outputs().size());
DFSVisit(g.outputs, [this](const nnvm::NodePtr& n) {
BidirectionalGraph::Node new_node;
new_node.nnvmptr = n.get();
nnvm2nid[n.get()] = static_cast<uint32_t>(nodes.size());
nodes.emplace_back(std::move(new_node));
});
for (const auto& it : nnvm2nid) {
nnvm::Node* nnvmnode = it.first;
uint32_t nid = it.second;
for (auto& n : nnvmnode->inputs) {
uint32_t input_nid = nnvm2nid[n.node.get()];
nodes[input_nid].outputs.emplace_back(&nodes[nid]);
nodes[nid].inputs.emplace_back(&nodes[input_nid]);
}
}
for (auto& e : g.outputs) {
uint32_t nid = nnvm2nid[e.node.get()];
outputs.emplace_back(&nodes[nid]);
}
}
template <typename FVisit>
void DFS(const std::vector<Node*>& heads, bool reverse, FVisit fvisit) {
std::unordered_set<Node*> visited;
std::vector<Node*> vec(heads.begin(), heads.end());
visited.reserve(heads.size());
while (!vec.empty()) {
Node* vertex = vec.back();
vec.pop_back();
if (visited.count(vertex) == 0) {
visited.insert(vertex);
fvisit(vertex);
std::vector<Node*> nexts = reverse ? vertex->inputs : vertex->outputs;
for (Node* node : nexts) {
if (visited.count(node) == 0) {
vec.emplace_back(node);
}
}
}
}
}
using t_pairset = std::pair<std::unordered_set<Node*>, std::unordered_set<Node*>>;
using t_pairvec = std::pair<std::vector<Node*>, std::vector<Node*>>;
using t_uncomp_map = std::unordered_map<Node*, std::unordered_set<Node*>>;
std::unordered_set<Node*> naive_grow_subgraph(Node* head,
std::unordered_set<Node*>* set_unused,
t_uncomp_map* uncomp_map) {
std::unordered_set<Node*> subgraph;
std::unordered_set<Node*> uncomp_set;
std::deque<Node*> stack;
stack.emplace_back(head);
while (!stack.empty()) {
Node* vertex = stack.back();
stack.pop_back();
if (set_unused->count(vertex) && !uncomp_set.count(vertex)) {
set_unused->erase(vertex);
subgraph.insert(vertex);
uncomp_set.insert((*uncomp_map)[vertex].begin(), (*uncomp_map)[vertex].end());
for (Node* input : vertex->inputs) {
if (set_unused->count(input) && !uncomp_set.count(input)) {
stack.emplace_back(input);
}
}
for (Node* output : vertex->outputs) {
if (set_unused->count(output) && !uncomp_set.count(output)) {
stack.emplace_back(output);
}
}
}
}
return subgraph;
}
std::vector<std::unordered_set<Node*>> get_subsets(
std::unordered_map<std::string, NDArray>* const params_map) {
std::vector<std::unordered_set<Node*>> subgraphs;
std::unordered_set<Node*> set_nonTRTnodes;
std::unordered_set<Node*> set_allnodes(nodes.size());
std::vector<t_pairset> separation_sets;
for (Node& node : nodes) {
if (!IsTRTCompatible(node.nnvmptr)) {
set_nonTRTnodes.insert(&node);
std::unordered_set<Node*> in_graph;
std::unordered_set<Node*> out_graph;
std::vector<Node*> dummy_head;
dummy_head.emplace_back(&node);
DFS(dummy_head, false, [&out_graph](Node* node) {
out_graph.insert(node);
});
DFS(dummy_head, true, [&in_graph](Node* node) {
in_graph.insert(node);
});
separation_sets.emplace_back(std::make_pair(in_graph, out_graph));
}
set_allnodes.emplace(&node);
}
t_uncomp_map uncomp_map;
std::unordered_set<Node*> set_TRTnodes;
set_TRTnodes.insert(set_allnodes.begin(), set_allnodes.end());
for (Node* n : set_nonTRTnodes) {
set_TRTnodes.erase(n);
}
for (Node* n : set_TRTnodes) {
for (t_pairset p : separation_sets) {
if (p.first.count(n)) {
uncomp_map[n].insert(p.second.begin(), p.second.end());
} else if (p.second.count(n)) {
uncomp_map[n].insert(p.first.begin(), p.first.end());
}
}
for (Node* nonTRTn : set_nonTRTnodes) {
uncomp_map[n].erase(nonTRTn);
}
}
std::unordered_set<Node*> set_unused;
set_unused.reserve(set_TRTnodes.size());
for (auto& n : set_TRTnodes) {
if (n->nnvmptr->attrs.op != nullptr || params_map->count(n->nnvmptr->attrs.name)) {
set_unused.insert(n);
}
}
std::unordered_set<Node*> visited;
std::deque<Node*> stack(outputs.begin(), outputs.end());
while (!stack.empty()) {
Node* vertex = stack.front();
stack.pop_front();
if (!visited.count(vertex)) {
visited.insert(vertex);
if (set_unused.count(vertex)) {
subgraphs.emplace_back(naive_grow_subgraph(vertex, &set_unused, &uncomp_map));
}
for (Node* input : vertex->inputs) {
stack.emplace_back(input);
}
}
}
return subgraphs;
}
private:
friend class Graph;
bool IsTRTCompatible(nnvm::Node* nodeptr) {
if (nodeptr->op() == nullptr) {
return true;
}
const std::string op_name = nodeptr->op()->name;
if (op_name == "Pooling") {
return (nodeptr->attrs.dict.at("pool_type") == "avg" ||
nodeptr->attrs.dict.at("pool_type") == "max");
}
if (unconditionalTRTop.count(op_name)) {
return true;
}
if (op_name == "Activation") {
return nodeptr->attrs.dict.at("act_type") == "relu" ||
nodeptr->attrs.dict.at("act_type") == "tanh" ||
nodeptr->attrs.dict.at("act_type") == "sigmoid";
}
return false;
}
}; // class BidirectionalGraph
/*!
* \brief Operator names that BidirectionalGraph::IsTRTCompatible accepts
* without inspecting the node attributes
*/
const std::unordered_set<std::string> BidirectionalGraph::unconditionalTRTop = {
"Convolution",
"BatchNorm",
"elemwise_add",
"elemwise_sub",
"elemwise_mul",
"rsqrt",
"pad",
"Pad",
"mean",
"FullyConnected",
"Flatten",
"SoftmaxOutput",
};
// Hash set of node entries, keyed with nnvm's own entry hash and equality.
using NodeEntrySet = std::unordered_set<nnvm::NodeEntry, nnvm::NodeEntryHash,
nnvm::NodeEntryEqual>;
/*!
 * \brief Collect the entries produced inside the subgraph that are visible outside of it
 *
 * An entry qualifies when its node belongs to set_subgraph and it is either
 * an output of g or an input of a node that is not in set_subgraph.
 * \return the qualifying entries, without duplicates
 */
std::vector<nnvm::NodeEntry> GetSubgraphNodeEntries(Graph g,
    std::unordered_set<nnvm::Node*> set_subgraph) {
  NodeEntrySet boundary_entries;
  const auto produced_inside = [&set_subgraph](const nnvm::NodeEntry& entry) {
    return set_subgraph.count(entry.node.get()) != 0;
  };
  // Graph outputs that come straight out of the subgraph.
  for (auto& head : g.outputs) {
    if (produced_inside(head)) {
      boundary_entries.insert(head);
    }
  }
  // Entries consumed by nodes living outside of the subgraph.
  DFSVisit(g.outputs, [&set_subgraph, &boundary_entries, &produced_inside](
      const nnvm::NodePtr& consumer) {
    if (set_subgraph.count(consumer.get())) {
      return;
    }
    for (auto& used : consumer->inputs) {
      if (produced_inside(used)) {
        boundary_entries.insert(used);
      }
    }
  });
  return std::vector<nnvm::NodeEntry>(boundary_entries.begin(), boundary_entries.end());
}
/*!
 * \brief Collect the entries produced outside of the subgraph and consumed inside of it
 * \return the input entries of subgraph nodes whose node is not in set_subgraph,
 *         without duplicates
 */
std::vector<nnvm::NodeEntry> GetSubgraphInterfaceNodes(Graph g,
    std::unordered_set<nnvm::Node*> set_subgraph) {
  NodeEntrySet external_entries;
  DFSVisit(g.outputs, [&set_subgraph, &external_entries](const nnvm::NodePtr& member) {
    // Only nodes of the subgraph can consume an interface entry.
    if (!set_subgraph.count(member.get())) {
      return;
    }
    for (auto& used : member->inputs) {
      const bool from_outside = set_subgraph.count(used.node.get()) == 0;
      if (from_outside) {
        external_entries.insert(used);
      }
    }
  });
  return std::vector<nnvm::NodeEntry>(external_entries.begin(), external_entries.end());
}
/*!
 * \brief Map the node id of every input node of g to its position in
 *        g.indexed_graph().input_nodes()
 * \param g the graph whose input nodes are indexed
 * \return map from input node id to position
 */
std::unordered_map<uint32_t, uint32_t> GetGraphInputsMap(const Graph& g) {
  std::unordered_map<uint32_t, uint32_t> outputs;
  auto& idx = g.indexed_graph();
  outputs.reserve(idx.num_nodes());
  // Bind by reference: the list is only read, there is no need to copy it.
  const auto& input_nodes = idx.input_nodes();
  for (size_t i = 0; i < input_nodes.size(); ++i) {
    outputs[input_nodes[i]] = static_cast<uint32_t>(i);
  }
  return outputs;
}
/*!
 * \brief Create a "_trt_op" node carrying the serialized ONNX conversion of a graph
 * \param g the graph handed to op::nnvm_to_onnx::ConvertNnvmGraphToOnnx
 * \param params_map parameters forwarded to the conversion
 * \return the new node, with its attributes parsed when the operator has a parser
 */
nnvm::NodePtr ConvertNnvmGraphToOnnx(const nnvm::Graph &g,
    std::unordered_map<std::string, NDArray>* const params_map) {
  auto trt_node = nnvm::Node::Create();
  trt_node->attrs.op = nnvm::Op::Get("_trt_op");

  const op::TRTParam converted = op::nnvm_to_onnx::ConvertNnvmGraphToOnnx(g, params_map);
  auto& attr_dict = trt_node->attrs.dict;
  attr_dict["serialized_output_map"] = converted.serialized_output_map;
  attr_dict["serialized_input_map"] = converted.serialized_input_map;
  attr_dict["serialized_onnx_graph"] = converted.serialized_onnx_graph;

  // Let the operator turn the string dictionary into its parsed form.
  const auto& parse_attrs = trt_node->op()->attr_parser;
  if (parse_attrs != nullptr) {
    parse_attrs(&(trt_node->attrs));
  }
  return trt_node;
}
/*!
 * \brief Fill the shape / dtype / storage type attributes of a subgraph from the main graph
 * \param subgraph the subgraph receiving the attributes
 * \param g the main graph; its "shape", "dtype", "storage_type" attributes and
 *        their "*_inputs" counterparts are read
 * \param old2new maps nodes of g to their counterpart in subgraph
 * \param main_input_entry_to_sub maps entries of g feeding the subgraph to the
 *        placeholder entries standing for them in subgraph
 * \return the subgraph with the six attributes set
 */
Graph UpdateSubgraphAttrs(Graph&& subgraph, const Graph& g,
    const std::unordered_map<nnvm::Node*, nnvm::NodePtr>& old2new,
    const nnvm::NodeEntryMap<nnvm::NodeEntry>& main_input_entry_to_sub) {
  const auto& idx = g.indexed_graph();
  const auto& sub_idx = subgraph.indexed_graph();

  const auto& shape = g.GetAttr<nnvm::ShapeVector>("shape");
  const auto& dtype = g.GetAttr<nnvm::DTypeVector>("dtype");
  const auto& storage_type = g.GetAttr<StorageTypeVector>("storage_type");
  const auto& shape_inputs = g.GetAttr<nnvm::ShapeVector>("shape_inputs");
  const auto& dtype_inputs = g.GetAttr<nnvm::DTypeVector>("dtype_inputs");
  const auto& storage_type_inputs = g.GetAttr<StorageTypeVector>("storage_type_inputs");

  nnvm::ShapeVector sub_shape(sub_idx.num_node_entries());
  nnvm::DTypeVector sub_dtype(sub_idx.num_node_entries());
  StorageTypeVector sub_storage_type(sub_idx.num_node_entries());
  nnvm::ShapeVector sub_shape_inputs(sub_idx.input_nodes().size());
  nnvm::DTypeVector sub_dtype_inputs(sub_idx.input_nodes().size());
  StorageTypeVector sub_storage_type_inputs(sub_idx.input_nodes().size());

  const std::unordered_map<uint32_t, uint32_t> inputsindex2pos = GetGraphInputsMap(g);
  const std::unordered_map<uint32_t, uint32_t> sub_inputsindex2pos = GetGraphInputsMap(subgraph);

  // map attributes from graph to subgraph
  for (const auto& p : old2new) {
    const uint32_t nid = idx.node_id(p.first);
    const uint32_t sub_nid = sub_idx.node_id(p.second.get());
    const nnvm::Op* op = sub_idx[sub_nid].source->op();
    if (op == nullptr) {  // if it's an input node, there is only one output node entry
      const uint32_t sub_i = sub_idx.entry_id(sub_nid, 0);
      const uint32_t sub_input_i = sub_inputsindex2pos.at(sub_nid);
      const uint32_t i = idx.entry_id(nid, 0);
      // Position of the node among the inputs of the main graph, looked up once.
      const uint32_t input_i = inputsindex2pos.at(nid);
      sub_shape[sub_i] = shape[i];
      sub_dtype[sub_i] = dtype[i];
      sub_storage_type[sub_i] = storage_type[i];
      sub_shape_inputs[sub_input_i] = shape_inputs[input_i];
      sub_dtype_inputs[sub_input_i] = dtype_inputs[input_i];
      sub_storage_type_inputs[sub_input_i] = storage_type_inputs[input_i];
    } else {
      for (size_t oi = 0; oi < op->num_outputs; ++oi) {
        const uint32_t sub_i = sub_idx.entry_id(sub_nid, oi);
        const uint32_t i = idx.entry_id(nid, oi);
        sub_shape[sub_i] = shape[i];
        sub_dtype[sub_i] = dtype[i];
        sub_storage_type[sub_i] = storage_type[i];
      }
    }
  }

  // old2new doesn't contain placeholder / interfaces
  for (const auto& p : main_input_entry_to_sub) {
    // References avoid copying the NodeEntry (and its node pointer) per iteration.
    const nnvm::NodeEntry& main_entry = p.first;
    const nnvm::NodeEntry& sub_entry = p.second;
    const uint32_t sub_nid = sub_idx.node_id(sub_entry.node.get());
    const uint32_t sub_i = sub_idx.entry_id(sub_entry);
    const uint32_t i = idx.entry_id(main_entry);
    const uint32_t sub_input_i = sub_inputsindex2pos.at(sub_nid);
    sub_shape[sub_i] = shape[i];
    sub_dtype[sub_i] = dtype[i];
    sub_storage_type[sub_i] = storage_type[i];
    sub_shape_inputs[sub_input_i] = sub_shape[sub_i];
    sub_dtype_inputs[sub_input_i] = sub_dtype[sub_i];
    sub_storage_type_inputs[sub_input_i] = sub_storage_type[sub_i];
  }

  subgraph.attrs["shape"] =
      std::make_shared<dmlc::any>(std::move(sub_shape));
  subgraph.attrs["dtype"] =
      std::make_shared<dmlc::any>(std::move(sub_dtype));
  subgraph.attrs["storage_type"] =
      std::make_shared<dmlc::any>(std::move(sub_storage_type));
  subgraph.attrs["shape_inputs"] =
      std::make_shared<dmlc::any>(std::move(sub_shape_inputs));
  subgraph.attrs["dtype_inputs"] =
      std::make_shared<dmlc::any>(std::move(sub_dtype_inputs));
  subgraph.attrs["storage_type_inputs"] =
      std::make_shared<dmlc::any>(std::move(sub_storage_type_inputs));
  // `subgraph` is a named rvalue reference, hence an lvalue here: without
  // std::move the whole graph would be copied into the return value.
  return std::move(subgraph);
}
/*!
 * \brief Pick a name of the form "TRT_node<N>" that no remaining node of the graph uses
 *
 * Names used inside the subgraph are considered free, the smallest N
 * giving an unused name is chosen.
 */
const std::string GetNewTrtName(const Graph& g, const Graph& subgraph) {
  const std::string name_prefix("TRT_node");
  const auto starts_with_prefix = [&name_prefix](const std::string& name) {
    return name.compare(0, name_prefix.size(), name_prefix) == 0;
  };

  std::unordered_set<std::string> taken_names;
  DFSVisit(g.outputs, [&taken_names, &starts_with_prefix](const nnvm::NodePtr& node) {
    if (starts_with_prefix(node->attrs.name)) {
      taken_names.insert(node->attrs.name);
    }
  });
  // names inside the subgraph become available again, as these nodes get removed
  DFSVisit(subgraph.outputs, [&taken_names, &starts_with_prefix](const nnvm::NodePtr& node) {
    if (starts_with_prefix(node->attrs.name)) {
      taken_names.erase(node->attrs.name);
    }
  });

  for (uint32_t name_suffix = 0; ; ++name_suffix) {
    std::string candidate = name_prefix + std::to_string(name_suffix);
    if (!taken_names.count(candidate)) {
      return candidate;
    }
  }
}
/*!
 * \brief Debug helper printing every node of the graph to std::cout,
 *        prefixed with " Y " when it belongs to the set and " N " otherwise
 */
void dispNodesSet(Graph g, std::unordered_set<nnvm::Node*> s) {
  DFSVisit(g.outputs, [&s](const nnvm::NodePtr n){
    const char* marker = s.count(n.get()) ? " Y " : " N ";
    std::cout << marker << n->attrs.name << std::endl;
  });
}
/*!
 * \brief Replace a set of nodes by a TensorRT node
 * \param g the main graph; the inputs of its nodes and its outputs are
 *        rewired in place to the new node
 * \param set_subgraph the nodes of g to replace
 * \param params_map parameters forwarded to the ONNX conversion
 * \return a graph whose outputs are the rewired outputs of g
 */
Graph ReplaceSubgraph(Graph&& g,
                      const std::unordered_set<nnvm::Node*>& set_subgraph,
                      std::unordered_map<std::string, NDArray>* const params_map) {
  // Create MXNet subgraph
  Graph subgraph;
  const auto sub_outputs_in_main = GetSubgraphNodeEntries(g, set_subgraph);
  subgraph.outputs = sub_outputs_in_main;
  // old2new will link raw pointer of the nodes in the graph to
  // the corresponding shared_ptr of the nodes in the generated subgraph
  std::unordered_map<nnvm::Node*, nnvm::NodePtr> old2new;
  std::deque<nnvm::Node*> stack;
  std::unordered_set<nnvm::Node*> visited;
  const size_t reservation = set_subgraph.size();
  old2new.reserve(reservation);
  visited.reserve(reservation);

  // Every node of the subgraph is copied into a node of its own, so the
  // copies can be rewired without touching the nodes of the main graph.
  for (auto& n : set_subgraph) {
    old2new[n] = std::make_shared<nnvm::Node>(*n);
  }

  // To generate a subgraph, an input has to be replaced by a data node (no op)
  // and it has to be agnostic to the node from which it's an output
  // (for example even if two inputs are two different outputs from the same node)
  nnvm::NodeEntryMap<nnvm::NodeEntry> main_input_entry_to_sub;
  for (auto& e : GetSubgraphInterfaceNodes(g, set_subgraph)) {
    auto node = nnvm::Node::Create();
    node->attrs.name = e.node->attrs.name + "_" + std::to_string(e.index);
    auto new_e = nnvm::NodeEntry{node, 0, 0};
    main_input_entry_to_sub[e] = new_e;
  }

  for (nnvm::NodeEntry& e : subgraph.outputs) {
    e.node = old2new[e.node.get()];
    stack.emplace_back(e.node.get());
  }
  // link all nodes in the subgraph to nodes in the subgraph instead of main graph
  while (!stack.empty()) {
    auto vertex = stack.front();
    stack.pop_front();
    if (!visited.count(vertex)) {
      visited.insert(vertex);
      for (auto& e : vertex->inputs) {
        auto it = main_input_entry_to_sub.find(e);
        if (it != main_input_entry_to_sub.end()) {
          e = it->second;
        } else {
          e.node = old2new[e.node.get()];
        }
        stack.emplace_back(e.node.get());
      }
    }
  }

  // Remove the control dependencies of the subgraph to nodes that are not in the subgraph
  DFSVisit(subgraph.outputs, [&set_subgraph, &old2new](const nnvm::NodePtr& node) {
    auto& deps = node->control_deps;
    // std::remove_if only shifts the kept elements to the front; the tail
    // has to be erased explicitly, otherwise the loop below would look up
    // leftover entries and insert null nodes into old2new.
    deps.erase(std::remove_if(deps.begin(),
                              deps.end(),
                              [&set_subgraph](const nnvm::NodePtr& n_ptr) {
                                return !set_subgraph.count(n_ptr.get());
                              }),
               deps.end());
    // Every remaining dependency is a member of set_subgraph, hence a key of old2new.
    // NOTE(review): DFSVisit may reach nodes through control_deps before they are
    // remapped here - confirm this only ever visits nodes of the subgraph.
    for (nnvm::NodePtr& n_ptr : deps) {
      n_ptr = old2new.at(n_ptr.get());
    }
  });

  subgraph = UpdateSubgraphAttrs(std::move(subgraph), g, old2new, main_input_entry_to_sub);
  auto& sub_idx = subgraph.indexed_graph();

  auto trtnodeptr = ConvertNnvmGraphToOnnx(subgraph, params_map);
  trtnodeptr->attrs.name = GetNewTrtName(g, subgraph);

  // Insert new trt node and unplug replaced nodes
  std::unordered_map<uint32_t, nnvm::NodeEntry> sub_input_entryid_to_main;
  for (auto& p : main_input_entry_to_sub) {
    sub_input_entryid_to_main[sub_idx.entry_id(p.second)] = p.first;
  }

  // Plug the nodes from the main graph as inputs of the trt node,
  // following the order of the input nodes of the subgraph
  trtnodeptr->inputs.resize(main_input_entry_to_sub.size());
  {
    uint32_t counter = 0;
    for (uint32_t i : sub_idx.input_nodes()) {
      auto it = sub_input_entryid_to_main.find(sub_idx.entry_id(i, 0));
      if (it != sub_input_entryid_to_main.end()) {
        trtnodeptr->inputs[counter++] = it->second;
      }
    }
  }

  // Output index on the trt node for every entry leaving the subgraph
  nnvm::NodeEntryMap<uint32_t> sub_outputs_in_main_to_pos;
  for (uint32_t i = 0; i < sub_outputs_in_main.size(); ++i) {
    sub_outputs_in_main_to_pos[sub_outputs_in_main[i]] = i;
  }
  // Plug the trt node as inputs to the main graph nodes
  DFSVisit(g.outputs, [&sub_outputs_in_main_to_pos, &trtnodeptr](const nnvm::NodePtr& n) {
    for (auto& e : n->inputs) {
      auto it = sub_outputs_in_main_to_pos.find(e);
      if (it != sub_outputs_in_main_to_pos.end()) {
        e.index = it->second;
        e.node = trtnodeptr;
      }
    }
  });

  for (auto& output : g.outputs) {
    auto it = sub_outputs_in_main_to_pos.find(output);
    if (it != sub_outputs_in_main_to_pos.end()) {
      output.index = it->second;
      output.node = trtnodeptr;
    }
  }

  Graph new_graph;
  new_graph.outputs = g.outputs;
  return new_graph;
}
/*!
 * \brief Partition a graph into sets of nodes that can each become one TensorRT node
 * \param g the graph to partition
 * \param params_map parameters forwarded to BidirectionalGraph::get_subsets
 * \return one set of nnvm nodes per subset found
 */
std::vector<std::unordered_set<nnvm::Node*>> GetTrtCompatibleSubsets(const Graph& g,
    std::unordered_map<std::string, NDArray>* const params_map) {
  BidirectionalGraph bi_graph(g);
  const std::vector<std::unordered_set<BidirectionalGraph::Node*>> bi_subsets =
      bi_graph.get_subsets(params_map);

  // Translate every set of bi-directional nodes into the nnvm nodes they wrap.
  std::vector<std::unordered_set<nnvm::Node*>> nnvm_subsets;
  nnvm_subsets.reserve(bi_subsets.size());
  for (const auto& bi_subset : bi_subsets) {
    nnvm_subsets.emplace_back();
    std::unordered_set<nnvm::Node*>& converted = nnvm_subsets.back();
    converted.reserve(bi_subset.size());
    for (const auto& member : bi_subset) {
      converted.insert(member->nnvmptr);
    }
  }
  return nnvm_subsets;
}
} // namespace exec
} // namespace mxnet
#endif // MXNET_USE_TENSORRT