| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file nnvm_to_onnx-inl.h |
| * \brief Conversion from NNVM to ONNX for TensorRT |
| * \author Marek Kolodziej, Clement Fuji Tsang |
| */ |
| |
| #ifndef MXNET_OPERATOR_SUBGRAPH_TENSORRT_NNVM_TO_ONNX_INL_H_ |
| #define MXNET_OPERATOR_SUBGRAPH_TENSORRT_NNVM_TO_ONNX_INL_H_ |
| |
| #if MXNET_USE_TENSORRT |
| |
| #include <mxnet/operator.h> |
| #include <nnvm/pass_functions.h> |
| |
| #include <onnx/onnx_pb.h> |
| |
| #include <unordered_map> |
| #include <vector> |
| #include <string> |
| |
| namespace mxnet { |
| namespace op { |
| namespace nnvm_to_onnx { |
| |
/*!
 * \brief Discriminator passed to ConvDeconvConvertHelper telling it whether the
 *        node being converted is a Convolution or a Deconvolution.
 *
 * The explicit values reproduce the implicit numbering (0, 1) of a plain enum.
 */
enum ConvDeconvType {
  Convolution = 0,
  Deconvolution = 1
};
| |
| using namespace nnvm; |
| using namespace ::onnx; |
| using int64 = ::google::protobuf::int64; |
| |
| std::unordered_map<std::string, mxnet::TShape> GetPlaceholderShapes(const ShapeVector& shape_inputs, |
| const nnvm::IndexedGraph& ig); |
| |
| std::unordered_map<std::string, int> GetPlaceholderDTypes(const DTypeVector& dtype_inputs, |
| const nnvm::IndexedGraph& ig); |
| |
| std::unordered_map<std::string, uint32_t> GetOutputLookup(const nnvm::IndexedGraph& ig); |
| |
| void ConvertPlaceholder(const std::string& node_name, |
| const std::unordered_map<std::string, TShape>& placeholder_shapes, |
| const std::unordered_map<std::string, int>& placeholder_dtypes, |
| GraphProto* graph_proto); |
| |
| void ConvertConstant(GraphProto* graph_proto, |
| const std::string& node_name, |
| const std::unordered_map<std::string, NDArray>* const params_map); |
| |
| void ConvertOutput(GraphProto* graph_proto, |
| const std::unordered_map<std::string, uint32_t>::iterator& out_iter, |
| const std::string& node_name, |
| const ShapeVector& shapes, |
| const DTypeVector& dtypes, |
| const nnvm::IndexedGraph& ig); |
| |
| void DefaultConnectInputsOutputs(const array_view<IndexedGraph::NodeEntry>& inputs, |
| const nnvm::IndexedGraph& ig, |
| const std::string& node_name); |
| |
| typedef void (*ConverterFunction)(GraphProto* graph_proto, |
| const std::string& node_name, |
| const NodeAttrs& attrs, |
| const nnvm::IndexedGraph& ig, |
| const array_view<IndexedGraph::NodeEntry>& inputs); |
| |
| template <class ConvDeconvParam> |
| void ConvDeconvConvertHelper(NodeProto* node_proto, |
| const NodeAttrs& attrs, |
| const nnvm::IndexedGraph& ig, |
| const array_view<IndexedGraph::NodeEntry>& inputs, |
| const ConvDeconvParam& param, |
| ConvDeconvType type); |
| |
| // Forward declarations |
| void ConvertIdentity(GraphProto* graph_proto, |
| const std::string& node_name, |
| const NodeAttrs& attrs, |
| const nnvm::IndexedGraph& ig, |
| const array_view<IndexedGraph::NodeEntry>& inputs); |
| |
| void ConvertConvolution(GraphProto* graph_proto, |
| const std::string& node_name, |
| const NodeAttrs& attrs, |
| const nnvm::IndexedGraph& ig, |
| const array_view<IndexedGraph::NodeEntry>& inputs); |
| |
| void ConvertDeconvolution(GraphProto* graph_proto, |
| const std::string& node_name, |
| const NodeAttrs& attrs, |
| const nnvm::IndexedGraph& ig, |
| const array_view<IndexedGraph::NodeEntry>& inputs); |
| |
| void ConvertPooling(GraphProto* graph_proto, |
| const std::string& node_name, |
| const NodeAttrs& attrs, |
| const nnvm::IndexedGraph& ig, |
| const array_view<IndexedGraph::NodeEntry>& inputs); |
| |
| void ConvertRelu(GraphProto* graph_proto, |
| const std::string& node_name, |
| const NodeAttrs& attrs, |
| const nnvm::IndexedGraph& ig, |
| const array_view<IndexedGraph::NodeEntry>& inputs); |
| |
| void ConvertActivation(GraphProto* graph_proto, |
| const std::string& node_name, |
| const NodeAttrs& attrs, |
| const nnvm::IndexedGraph& ig, |
| const array_view<IndexedGraph::NodeEntry>& inputs); |
| |
| void ConvertFullyConnected(GraphProto* graph_proto, |
| const std::string& node_name, |
| const NodeAttrs& attrs, |
| const nnvm::IndexedGraph& ig, |
| const array_view<IndexedGraph::NodeEntry>& inputs); |
| |
| void ConvertSlice(GraphProto* graph_proto, |
| const std::string& node_name, |
| const NodeAttrs& attrs, |
| const nnvm::IndexedGraph& ig, |
| const array_view<IndexedGraph::NodeEntry>& inputs); |
| |
| void ConvertSoftmaxOutput(GraphProto* graph_proto, |
| const std::string& node_name, |
| const NodeAttrs& attrs, |
| const nnvm::IndexedGraph& ig, |
| const array_view<IndexedGraph::NodeEntry>& inputs); |
| |
| void ConvertFlatten(GraphProto* graph_proto, |
| const std::string& node_name, |
| const NodeAttrs& attrs, |
| const nnvm::IndexedGraph& ig, |
| const array_view<IndexedGraph::NodeEntry>& inputs); |
| |
| void ConvertDropout(GraphProto* graph_proto, |
| const std::string& node_name, |
| const NodeAttrs& attrs, |
| const nnvm::IndexedGraph& ig, |
| const array_view<IndexedGraph::NodeEntry>& inputs); |
| |
| void ConvertBatchNorm(GraphProto* graph_proto, |
| const std::string& node_name, |
| const NodeAttrs& attrs, |
| const nnvm::IndexedGraph& ig, |
| const array_view<IndexedGraph::NodeEntry>& inputs); |
| |
| void ConvertElementwiseAdd(GraphProto* graph_proto, |
| const std::string& node_name, |
| const NodeAttrs& attrs, |
| const nnvm::IndexedGraph& ig, |
| const array_view<IndexedGraph::NodeEntry>& inputs); |
| |
| void ConvertElementwiseMul(GraphProto* graph_proto, |
| const std::string& node_name, |
| const NodeAttrs& attrs, |
| const nnvm::IndexedGraph& ig, |
| const array_view<IndexedGraph::NodeEntry>& inputs); |
| |
| void ConvertElementwiseSub(GraphProto* graph_proto, |
| const std::string& node_name, |
| const NodeAttrs& attrs, |
| const nnvm::IndexedGraph& ig, |
| const array_view<IndexedGraph::NodeEntry>& inputs); |
| |
| void ConvertConcatenate(GraphProto* graph_proto, |
| const std::string& node_name, |
| const NodeAttrs& attrs, |
| const nnvm::IndexedGraph& ig, |
| const array_view<IndexedGraph::NodeEntry>& inputs); |
| |
| void ConvertClip(GraphProto* graph_proto, |
| const std::string& node_name, |
| const NodeAttrs& attrs, |
| const nnvm::IndexedGraph& ig, |
| const array_view<IndexedGraph::NodeEntry>& inputs); |
| |
| void ConvertPad(GraphProto* graph_proto, |
| const std::string& node_name, |
| const NodeAttrs& attrs, |
| const nnvm::IndexedGraph& ig, |
| const array_view<IndexedGraph::NodeEntry>& inputs); |
| |
| std::string ConvertNnvmGraphToOnnx(const nnvm::Graph& g, |
| std::unordered_map<std::string, NDArray>* params_map); |
| |
| static const std::unordered_map<std::string, ConverterFunction> converter_map = { |
| {"_copy", ConvertIdentity}, |
| {"Activation", ConvertActivation}, |
| {"BatchNorm", ConvertBatchNorm}, |
| {"clip", ConvertClip}, |
| {"Convolution", ConvertConvolution}, |
| {"Deconvolution", ConvertDeconvolution}, |
| {"Concat", ConvertConcatenate}, |
| {"Dropout", ConvertDropout}, |
| {"elemwise_add", ConvertElementwiseAdd}, |
| {"elemwise_mul", ConvertElementwiseMul}, |
| {"elemwise_sub", ConvertElementwiseSub}, |
| {"Flatten", ConvertFlatten}, |
| {"FullyConnected", ConvertFullyConnected}, |
| {"Pad", ConvertPad}, |
| {"Pooling", ConvertPooling}, |
| {"relu", ConvertRelu}, |
| {"slice", ConvertSlice}}; |
| |
| typedef void (*PreprocessFunction)(const NodeAttrs& attrs, |
| const std::vector<nnvm::NodeEntry>& inputs, |
| std::unordered_map<std::string, NDArray>* params_map); |
| |
| void PreprocessBatchNorm(const NodeAttrs& attrs, |
| const std::vector<nnvm::NodeEntry>& inputs, |
| std::unordered_map<std::string, NDArray>* params_map); |
| |
| static const std::unordered_map<std::string, PreprocessFunction> preprocess_map = { |
| {"BatchNorm", PreprocessBatchNorm}}; |
| |
| } // namespace nnvm_to_onnx |
| } // namespace op |
| } // namespace mxnet |
| |
| #endif // MXNET_USE_TENSORRT |
| |
| #endif // MXNET_OPERATOR_SUBGRAPH_TENSORRT_NNVM_TO_ONNX_INL_H_ |