blob: 564401b083912617e5e6098afce428c0d94f7184 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tensorrt-inl.h
* \brief TensorRT operation registration
* \author Marek Kolodziej, Clement Fuji Tsang
*/
#ifndef MXNET_OPERATOR_SUBGRAPH_TENSORRT_TENSORRT_INL_H_
#define MXNET_OPERATOR_SUBGRAPH_TENSORRT_TENSORRT_INL_H_
#if MXNET_USE_TENSORRT
#include <onnx-tensorrt/NvOnnxParser.h>
#include <utility>
#include <string>
#include <vector>
#include "../../nn/activation-inl.h"
#include "../../nn/batch_norm-inl.h"
#include "../../nn/concat-inl.h"
#include "../../nn/convolution-inl.h"
#include "../../nn/deconvolution-inl.h"
#include "../../nn/dropout-inl.h"
#include "../../nn/fully_connected-inl.h"
#include "../../nn/pooling-inl.h"
#include "../common.h"
#include "../subgraph_property.h"
#include "nnvm_to_onnx-inl.h"
#include "./onnx_to_tensorrt.h"
namespace mxnet {
namespace op {
using int64 = ::google::protobuf::int64;
/*!
 * \brief Per-subgraph metadata attached to a "_TensorRT" node's attrs.parsed.
 *
 * Populated by TensorrtProperty when the subgraph node is created and wired up
 * (see CreateSubgraphNode / ConnectSubgraphInputs / ConnectSubgraphOutputs).
 */
struct TRTParam {
// Maps a subgraph input name to its position among the subgraph node's inputs.
std::unordered_map<std::string, uint32_t> inputs_to_idx;
// Maps a subgraph output (producing node's name) to its output slot index.
std::unordered_map<std::string, uint32_t> outputs_to_idx;
// Weights/aux NDArrays captured for the subgraph, keyed by parameter name
// (copies made in CreateSubgraphNode, so they outlive the original executor).
std::unordered_map<std::string, NDArray> params_map;
};
/*!
 * \brief Owns a built TensorRT engine plus the objects needed to run it.
 *
 * Takes ownership of the engine, parser and logger, records for every engine
 * binding slot whether it is an input or an output and which MXNet-side index
 * it corresponds to, and finally creates an execution context on the engine.
 */
struct TRTEngineParam {
  TRTEngineParam(onnx_to_tensorrt::unique_ptr<nvinfer1::ICudaEngine> _trt_engine,
                 onnx_to_tensorrt::unique_ptr<nvonnxparser::IParser> _trt_parser,
                 std::unique_ptr<onnx_to_tensorrt::TRT_Logger> _trt_logger,
                 const std::unordered_map<std::string, uint32_t>& input_map,
                 const std::unordered_map<std::string, uint32_t>& output_map)
      : trt_engine(std::move(_trt_engine)),
        trt_parser(std::move(_trt_parser)),
        trt_logger(std::move(_trt_logger)),
        binding_order(std::make_shared<std::vector<std::pair<uint32_t, bool>>>()),
        bindings(std::make_shared<std::vector<void*>>()) {
    const int num_bindings = trt_engine->getNbBindings();
    binding_order->reserve(num_bindings);
    bindings->resize(num_bindings);
    // For each TensorRT binding slot, remember the MXNet-side index it maps to
    // and whether it is an input (true) or an output (false).
    for (int slot = 0; slot < num_bindings; ++slot) {
      const std::string& name = trt_engine->getBindingName(slot);
      const bool is_input     = trt_engine->bindingIsInput(slot);
      if (is_input) {
        binding_order->emplace_back(input_map.at(name), true);
      } else {
        binding_order->emplace_back(output_map.at(name), false);
      }
    }
    trt_executor = onnx_to_tensorrt::InferObject(trt_engine->createExecutionContext());
  }

  onnx_to_tensorrt::unique_ptr<nvinfer1::ICudaEngine> trt_engine;
  onnx_to_tensorrt::unique_ptr<nvinfer1::IExecutionContext> trt_executor;
  onnx_to_tensorrt::unique_ptr<nvonnxparser::IParser> trt_parser;
  std::unique_ptr<onnx_to_tensorrt::TRT_Logger> trt_logger;
  // (mxnet index, is_input) per engine binding slot, in binding order.
  std::shared_ptr<std::vector<std::pair<uint32_t, bool>>> binding_order;
  // Raw device pointers handed to TensorRT at execution time, one per slot.
  std::shared_ptr<std::vector<void*>> bindings;
};
/*!
 * \brief Graph selector that grows a connected set of TensorRT-compatible
 *        operators to be fused into a single "_TensorRT" subgraph node.
 */
class TensorrtSelector : public SubgraphSelector {
 public:
  /*! \brief Operators TensorRT supports regardless of their parameters. */
  const std::unordered_set<std::string> unconditionalTRTops = {
      "_copy",
      "clip",
      "elemwise_add",
      "elemwise_sub",
      "elemwise_mul",
      "Flatten",
      "Pad",
      "relu",
      "rsqrt",
  };

  /*!
   * \brief Operators whose weight/aux variable inputs must be pulled into the
   *        subgraph (they become TensorRT weights rather than runtime inputs).
   */
  const std::unordered_set<std::string> withWeightsOps = {"BatchNorm",
                                                          "Convolution",
                                                          "Deconvolution",
                                                          "FullyConnected"};

  /*!
   * \brief Checks whether a single operator node can be offloaded to TensorRT.
   * \param n a non-variable graph node (callers guard against variables before
   *          dereferencing n.op()).
   * \return true if the node's op and parameters are convertible.
   */
  bool isTRTCompatible(const nnvm::Node& n) {
    const std::string op_name = n.op()->name;
    if (op_name == "FullyConnected") {
      const auto& param = nnvm::get<FullyConnectedParam>(n.attrs.parsed);
      // Only the biased form is converted by the ONNX export path.
      return !param.no_bias;
    }

    if (op_name == "Pooling") {
      const auto& param = nnvm::get<PoolingParam>(n.attrs.parsed);
      if (param.layout.has_value()) {
        // Channel-last layouts are rejected; TensorRT expects NC(D)HW here.
        if (param.layout.value() == mshadow::kNHWC) {
          LOG(INFO) << "Warning: NHWC layout (node: " << n.attrs.name
                    << ") is not supported by TensorRT";
          return false;
        } else if (param.layout.value() == mshadow::kNDHWC) {
          LOG(INFO) << "Warning: NDHWC layout (node: " << n.attrs.name
                    << ") is not supported by TensorRT";
          return false;
        }
      }
      // Only "valid" pooling convention converts cleanly (global pooling is
      // exempt since the window covers the whole spatial extent anyway).
      if (param.pooling_convention != pool_enum::kValid && !param.global_pool)
        return false;
      if (param.pool_type == pool_enum::kAvgPooling) {
        // Non-global average pooling must explicitly exclude padding from the
        // count to match TensorRT's averaging semantics.
        if ((!param.global_pool) &&
            (!param.count_include_pad.has_value() || param.count_include_pad.value()))
          return false;
        return true;
      } else if (param.pool_type == pool_enum::kMaxPooling) {
        return true;
      } else {
        // Other pool types (e.g. sum/lp) are not converted.
        return false;
      }
    }

    if (op_name == "Convolution") {
      const auto& param = nnvm::get<ConvolutionParam>(n.attrs.parsed);
      if (!param.layout.has_value())
        return true;
      switch (param.layout.value()) {
        case mshadow::kNCHW:
        case mshadow::kNCW:
        case mshadow::kNCDHW:
          return true;
        case mshadow::kNHWC:
          LOG(INFO) << "Warning: NHWC layout (node: " << n.attrs.name
                    << ") is not supported by TensorRT";
          return false;
        case mshadow::kNDHWC:
          LOG(INFO) << "Warning: NDHWC layout (node: " << n.attrs.name
                    << ") is not supported by TensorRT";
          return false;
        default:
          LOG(INFO) << "Warning: Layout (node: " << n.attrs.name
                    << ") is unknown (so unsupported by TensorRT)";
          return false;
      }
    }

    if (op_name == "Deconvolution") {
      const auto& param = nnvm::get<DeconvolutionParam>(n.attrs.parsed);
      if (!param.layout.has_value())
        return true;
      switch (param.layout.value()) {
        case mshadow::kNCHW:
        case mshadow::kNCW:
        case mshadow::kNCDHW:
          return true;
        case mshadow::kNHWC:
          // Fixed typo in diagnostic: was "is no tsupported".
          LOG(INFO) << "Warning: NHWC layout (node: " << n.attrs.name
                    << ") is not supported by TensorRT";
          return false;
        case mshadow::kNDHWC:
          LOG(INFO) << "Warning: NDHWC layout (node: " << n.attrs.name
                    << ") is not supported by TensorRT";
          return false;
        default:
          LOG(INFO) << "Warning: Layout (node: " << n.attrs.name
                    << ") is unknown (so unsupported by TensorRT)";
          return false;
      }
    }

    if (op_name == "Concat") {
      const auto& param = nnvm::get<ConcatParam>(n.attrs.parsed);
      // Concatenation along the batch axis (dim 0) is not converted.
      return (param.dim != 0);
    }

    if (op_name == "Dropout") {
      const auto& param = nnvm::get<DropoutParam>(n.attrs.parsed);
      // Training-mode dropout (identity at inference) with no explicit axes.
      return param.mode == dropout::kTraining && param.axes.ndim() == 0;
    }

    if (op_name == "Activation") {
      return n.attrs.dict.at("act_type") == "relu" || n.attrs.dict.at("act_type") == "tanh" ||
             n.attrs.dict.at("act_type") == "sigmoid";
    }

    if (op_name == "BatchNorm") {
      const auto& param = nnvm::get<BatchNormParam>(n.attrs.parsed);
      // Only channel-first normalization (axis == 1) maps onto TensorRT.
      if (param.axis != 1) {
        LOG(INFO) << "Warning: Only Layout NC(D)(H)W are supported by TensorRT "
                  << "(node " << n.attrs.name << ")";
        return false;
      }
      return true;
    }

    if (unconditionalTRTops.count(op_name)) {
      return true;
    }

    return false;
  }

  /*! \brief Seed the subgraph only on non-variable, TRT-compatible nodes. */
  bool Select(const nnvm::Node& n) override {
    return !n.is_variable() && isTRTCompatible(n);
  }

  /*!
   * \brief Decide whether an input node joins the subgraph.
   *
   * Variables are absorbed only when they are weights/aux of a withWeightsOps
   * node (and not its data input, which is inputs[0]).
   */
  bool SelectInput(const nnvm::Node& n, const nnvm::Node& new_node) override {
    if (new_node.is_variable()) {
      if (withWeightsOps.count(n.op()->name)) {
        return n.inputs[0].node->attrs.name != new_node.attrs.name;
      } else {
        return false;
      }
    }
    return isTRTCompatible(new_node);
  }

  /*! \brief A consumer node joins the subgraph iff it is TRT-compatible. */
  bool SelectOutput(const nnvm::Node& n, const nnvm::Node& new_node) override {
    return isTRTCompatible(new_node);
  }

  /*!
   * \brief Keep the candidate set only if it contains at least two operator
   *        (non-variable) nodes; a single-op TensorRT engine is not worth the
   *        conversion overhead.
   */
  std::vector<nnvm::Node*> Filter(const std::vector<nnvm::Node*>& candidates) override {
    bool found_one = false;
    for (auto& n : candidates) {
      if (!n->is_variable()) {
        if (found_one) {
          return candidates;
        } else {
          found_one = true;
        }
      }
    }
    return std::vector<nnvm::Node*>();
  }
};
/*!
 * \brief Subgraph property that partitions a graph into "_TensorRT" nodes and
 *        captures the NDArray weights each subgraph needs.
 */
class TensorrtProperty : public SubgraphProperty {
 public:
  static SubgraphPropertyPtr Create() {
    return std::make_shared<TensorrtProperty>();
  }

  /*!
   * \brief Cache name -> NDArray* mappings for args and aux before
   *        partitioning, so CreateSubgraphNode can look up weights by name.
   */
  void PrePartition(const nnvm::Graph& g,
                    const std::unordered_map<std::string, std::string>& options_map) override {
    auto& in_arg_names      = g.GetAttr<std::vector<std::string>>("in_arg_names");
    auto& in_aux_names      = g.GetAttr<std::vector<std::string>>("in_aux_names");
    NDArray** in_args_ptr   = g.GetAttr<NDArray**>("in_args");
    NDArray** in_aux_ptr    = g.GetAttr<NDArray**>("in_aux");
    in_args_dict.clear();
    in_aux_dict.clear();
    // we trust the Python API, len(in_arg_names) == len(in_args_ptr)
    for (size_t i = 0; i < in_arg_names.size(); ++i) {
      in_args_dict[in_arg_names[i]] = in_args_ptr[i];
    }
    for (size_t i = 0; i < in_aux_names.size(); ++i) {
      in_aux_dict[in_aux_names[i]] = in_aux_ptr[i];
    }
  }

  /*!
   * \brief Build the "_TensorRT" node wrapping the selected symbol.
   *
   * Deduplicates adjacent identical outputs, copies every matching arg/aux
   * NDArray into the node's TRTParam (so they survive executor rebinds), and
   * records the captured parameter names in attrs.dict.
   */
  nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol& sym,
                                     const int subgraph_id) const override {
    nnvm::ObjectPtr n = nnvm::Node::Create();
    nnvm::Symbol new_sym;
    // Take NodeEntry by const& — each entry holds a shared_ptr, so by-value
    // parameters would pay an atomic refcount copy per comparison.
    std::unique_copy(sym.outputs.begin(),
                     sym.outputs.end(),
                     std::back_inserter(new_sym.outputs),
                     [](const nnvm::NodeEntry& lhs, const nnvm::NodeEntry& rhs) {
                       return lhs.index == rhs.index && lhs.node.get() == rhs.node.get();
                     });
    n->attrs.name = "TensorRT" + std::to_string(subgraph_id);
    n->attrs.op   = Op::Get("_TensorRT");
    CHECK(n->attrs.op);
    n->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(new_sym));

    // Mapping subgraph params with NDArrays
    TRTParam param;
    std::ostringstream params_oss;
    for (auto& param_name : new_sym.ListInputNames(nnvm::Symbol::kAll)) {
      NDArray* cache = nullptr;
      auto it_args   = in_args_dict.find(param_name);
      if (it_args != in_args_dict.end()) {
        cache = it_args->second;
      } else {
        auto it_aux = in_aux_dict.find(param_name);
        if (it_aux != in_aux_dict.end()) {
          cache = it_aux->second;
        }
      }
      if (cache != nullptr) {
        // Reuse the emplace iterator instead of a second operator[] lookup.
        auto inserted = param.params_map.emplace(param_name, cache->Copy(Context()));
        // The copy is asynchronous; wait so the data is valid when serialized.
        inserted.first->second.WaitToRead();
        params_oss << param_name << ";";
      }
    }
    auto tensorrt_params_names = params_oss.str();
    // Drop the trailing ';' separator, if any parameters were captured.
    if (!tensorrt_params_names.empty()) {
      tensorrt_params_names.pop_back();
    }
    n->attrs.parsed                        = std::move(param);  // avoid copying the NDArray map
    n->attrs.dict["subgraph_params_names"] = tensorrt_params_names;
    return n;
  }

  SubgraphSelectorPtr CreateSubgraphSelector() const override {
    return std::make_shared<TensorrtSelector>();
  }

  /*!
   * \brief Redirect external consumers of subgraph outputs to the new node and
   *        record each output's slot index in TRTParam::outputs_to_idx.
   */
  void ConnectSubgraphOutputs(const nnvm::ObjectPtr subgraph_node,
                              std::vector<nnvm::NodeEntry*>* output_entries) const override {
    std::vector<nnvm::NodeEntry>& outputs = subgraph_node->attrs.subgraphs[0]->outputs;
    TRTParam& _params                     = nnvm::get<TRTParam>(subgraph_node->attrs.parsed);
    for (size_t i = 0; i < outputs.size(); i++) {
      auto& o = outputs[i];
      for (auto& e : *output_entries) {
        if (o.index == e->index && o.node.get() == e->node.get()) {
          e->index = i;
          e->node  = subgraph_node;
          // TODO(cfujitsang): For future support this would fail
          // if the node have multiple outputs
          _params.outputs_to_idx[o.node->attrs.name] = i;
        }
      }
    }
    // NOTE(review): _params aliases the TRTParam inside attrs.parsed, so the
    // mutations above are already visible; this reassignment mirrors the
    // original code's belt-and-braces write-back.
    subgraph_node->attrs.parsed = std::move(_params);
  }

  /*!
   * \brief Wire the original input entries into the subgraph node and record
   *        each input's position in TRTParam::inputs_to_idx.
   */
  void ConnectSubgraphInputs(const nnvm::ObjectPtr subgraph_node,
                             std::vector<nnvm::NodeEntry*>* input_entries,
                             std::vector<nnvm::NodeEntry>* orig_input_entries) const override {
    TRTParam& _params = nnvm::get<TRTParam>(subgraph_node->attrs.parsed);
    subgraph_node->inputs.clear();
    subgraph_node->inputs.resize(orig_input_entries->size());
    for (size_t i = 0; i < orig_input_entries->size(); ++i) {
      subgraph_node->inputs[i]                                     = orig_input_entries->at(i);
      _params.inputs_to_idx[input_entries->at(i)->node->attrs.name] = i;
    }
    subgraph_node->attrs.parsed = std::move(_params);
  }

  // Populated by PrePartition; pointers are owned by the caller's executor.
  std::unordered_map<std::string, NDArray*> in_args_dict, in_aux_dict;
};
} // namespace op
} // namespace mxnet
#endif // MXNET_USE_TENSORRT
#endif // MXNET_OPERATOR_SUBGRAPH_TENSORRT_TENSORRT_INL_H_