| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file contrib.h |
| * \brief utility function to enable some contrib features |
| * \author Haohuan Wang |
| */ |
| #ifndef MXNET_CPP_CONTRIB_H_ |
| #define MXNET_CPP_CONTRIB_H_ |
| |
| #include <iostream> |
| #include <string> |
| #include <map> |
| #include <vector> |
| #include "mxnet-cpp/symbol.h" |
| |
| namespace mxnet { |
| namespace cpp { |
| namespace details { |
| |
| /*! |
| * split a string with the given delimiter |
| * @param str string to be parsed |
| * @param delimiter delimiter |
| * @return delimited list of string |
| */ |
| inline std::vector<std::string> split(const std::string& str, const std::string& delimiter) { |
| std::vector<std::string> splitted; |
| size_t last = 0; |
| size_t next = 0; |
| while ((next = str.find(delimiter, last)) != std::string::npos) { |
| splitted.push_back(str.substr(last, next - last)); |
| last = next + 1; |
| } |
| splitted.push_back(str.substr(last)); |
| return splitted; |
| } |
| |
| } // namespace details |
| |
| namespace contrib { |
| |
| // needs to be same with |
| // https://github.com/apache/mxnet/blob/1c874cfc807cee755c38f6486e8e0f4d94416cd8/src/operator/subgraph/tensorrt/tensorrt-inl.h#L190 |
| static const std::string TENSORRT_SUBGRAPH_PARAM_IDENTIFIER = "subgraph_params_names"; // NOLINT |
| // needs to be same with |
| // https://github.com/apache/mxnet/blob/master/src/operator/subgraph/tensorrt/tensorrt.cc#L244 |
| static const std::string TENSORRT_SUBGRAPH_PARAM_PREFIX = "subgraph_param_"; // NOLINT |
| /*! |
| * this is a mimic to |
| * https://github.com/apache/mxnet/blob/master/python/mxnet/contrib/tensorrt.py#L37 |
| * @param symbol symbol that already called subgraph api |
| * @param argParams original arg params, params needed by tensorrt will be removed after calling |
| * this function |
| * @param auxParams original aux params, params needed by tensorrt will be removed after calling |
| * this function |
| */ |
| inline void InitTensorRTParams(const mxnet::cpp::Symbol& symbol, |
| std::map<std::string, mxnet::cpp::NDArray>* argParams, |
| std::map<std::string, mxnet::cpp::NDArray>* auxParams) { |
| mxnet::cpp::Symbol internals = symbol.GetInternals(); |
| mx_uint numSymbol = internals.GetNumOutputs(); |
| for (mx_uint i = 0; i < numSymbol; ++i) { |
| std::map<std::string, std::string> attrs = internals[i].ListAttributes(); |
| if (attrs.find(TENSORRT_SUBGRAPH_PARAM_IDENTIFIER) != attrs.end()) { |
| std::string new_params_names; |
| std::map<std::string, mxnet::cpp::NDArray> tensorrtParams; |
| std::vector<std::string> keys = |
| details::split(attrs[TENSORRT_SUBGRAPH_PARAM_IDENTIFIER], ";"); |
| for (const auto& key : keys) { |
| if (argParams->find(key) != argParams->end()) { |
| new_params_names += key + ";"; |
| tensorrtParams[TENSORRT_SUBGRAPH_PARAM_PREFIX + key] = (*argParams)[key]; |
| argParams->erase(key); |
| } else if (auxParams->find(key) != auxParams->end()) { |
| new_params_names += key + ";"; |
| tensorrtParams[TENSORRT_SUBGRAPH_PARAM_PREFIX + key] = (*auxParams)[key]; |
| auxParams->erase(key); |
| } |
| } |
| std::map<std::string, std::string> new_attrs = {}; |
| for (const auto& kv : tensorrtParams) { |
| // passing the ndarray address into TRT node attributes to get the weight |
| uint64_t address = reinterpret_cast<uint64_t>(kv.second.GetHandle()); |
| new_attrs[kv.first] = std::to_string(address); |
| } |
| if (!new_attrs.empty()) { |
| internals[i].SetAttributes(new_attrs); |
| internals[i].SetAttribute(TENSORRT_SUBGRAPH_PARAM_IDENTIFIER, |
| new_params_names.substr(0, new_params_names.length() - 1)); |
| } |
| } |
| } |
| } |
| |
| } // namespace contrib |
| } // namespace cpp |
| } // namespace mxnet |
| |
| #endif // MXNET_CPP_CONTRIB_H_ |