blob: 5feefd32efdf9c9a1d93dd49653146d4454bc60e [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file contrib.h
* \brief utility function to enable some contrib features
* \author Haohuan Wang
*/
#ifndef MXNET_CPP_CONTRIB_H_
#define MXNET_CPP_CONTRIB_H_
#include <iostream>
#include <string>
#include <map>
#include <vector>
#include "mxnet-cpp/symbol.h"
namespace mxnet {
namespace cpp {
namespace details {
/*!
* split a string with the given delimiter
* @param str string to be parsed
* @param delimiter delimiter
* @return delimited list of string
*/
inline std::vector<std::string> split(const std::string& str, const std::string& delimiter) {
std::vector<std::string> splitted;
size_t last = 0;
size_t next = 0;
while ((next = str.find(delimiter, last)) != std::string::npos) {
splitted.push_back(str.substr(last, next - last));
last = next + 1;
}
splitted.push_back(str.substr(last));
return splitted;
}
} // namespace details
namespace contrib {
// needs to be same with
// https://github.com/apache/mxnet/blob/1c874cfc807cee755c38f6486e8e0f4d94416cd8/src/operator/subgraph/tensorrt/tensorrt-inl.h#L190
static const std::string TENSORRT_SUBGRAPH_PARAM_IDENTIFIER = "subgraph_params_names"; // NOLINT
// needs to be same with
// https://github.com/apache/mxnet/blob/master/src/operator/subgraph/tensorrt/tensorrt.cc#L244
static const std::string TENSORRT_SUBGRAPH_PARAM_PREFIX = "subgraph_param_"; // NOLINT
/*!
* this is a mimic to
* https://github.com/apache/mxnet/blob/master/python/mxnet/contrib/tensorrt.py#L37
* @param symbol symbol that already called subgraph api
* @param argParams original arg params, params needed by tensorrt will be removed after calling
* this function
* @param auxParams original aux params, params needed by tensorrt will be removed after calling
* this function
*/
inline void InitTensorRTParams(const mxnet::cpp::Symbol& symbol,
std::map<std::string, mxnet::cpp::NDArray>* argParams,
std::map<std::string, mxnet::cpp::NDArray>* auxParams) {
mxnet::cpp::Symbol internals = symbol.GetInternals();
mx_uint numSymbol = internals.GetNumOutputs();
for (mx_uint i = 0; i < numSymbol; ++i) {
std::map<std::string, std::string> attrs = internals[i].ListAttributes();
if (attrs.find(TENSORRT_SUBGRAPH_PARAM_IDENTIFIER) != attrs.end()) {
std::string new_params_names;
std::map<std::string, mxnet::cpp::NDArray> tensorrtParams;
std::vector<std::string> keys =
details::split(attrs[TENSORRT_SUBGRAPH_PARAM_IDENTIFIER], ";");
for (const auto& key : keys) {
if (argParams->find(key) != argParams->end()) {
new_params_names += key + ";";
tensorrtParams[TENSORRT_SUBGRAPH_PARAM_PREFIX + key] = (*argParams)[key];
argParams->erase(key);
} else if (auxParams->find(key) != auxParams->end()) {
new_params_names += key + ";";
tensorrtParams[TENSORRT_SUBGRAPH_PARAM_PREFIX + key] = (*auxParams)[key];
auxParams->erase(key);
}
}
std::map<std::string, std::string> new_attrs = {};
for (const auto& kv : tensorrtParams) {
// passing the ndarray address into TRT node attributes to get the weight
uint64_t address = reinterpret_cast<uint64_t>(kv.second.GetHandle());
new_attrs[kv.first] = std::to_string(address);
}
if (!new_attrs.empty()) {
internals[i].SetAttributes(new_attrs);
internals[i].SetAttribute(TENSORRT_SUBGRAPH_PARAM_IDENTIFIER,
new_params_names.substr(0, new_params_names.length() - 1));
}
}
}
}
} // namespace contrib
} // namespace cpp
} // namespace mxnet
#endif // MXNET_CPP_CONTRIB_H_