blob: 8a421d7b6b4f0b8b3489ea5fe2fc87777599a321 [file] [log] [blame]
/*!
* Copyright (c) 2016 by Contributors
* \file operator.hpp
* \brief implementation of operator
* \author Chuntao Hong, Zhang Chen
*/
#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_OPERATOR_HPP_
#define CPP_PACKAGE_INCLUDE_MXNET_CPP_OPERATOR_HPP_
#include <algorithm>
#include <string>
#include <vector>
#include <iterator>
#include "mxnet-cpp/base.h"
#include "mxnet-cpp/op_map.h"
#include "mxnet-cpp/operator.h"
namespace mxnet {
namespace cpp {
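/*
 * SetParam is fully specialized for NDArray and Symbol so that these
 * arguments are pushed onto the operator's input lists (input_ndarrays_ /
 * input_symbols_). Full specialization is used because a partial
 * specialization such as PushInput<NDArray, Args..., N> is not allowed for
 * function templates in C++.
 */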
template <>
inline Operator& Operator::SetParam<NDArray>(int /*pos*/,
                                             const NDArray &value) {
  // NDArray arguments become imperative inputs for Invoke(). The position
  // argument is ignored: handles are appended in call order.
  // NOTE(review): only the raw handle is kept, so `value` (or a copy sharing
  // its handle) presumably must stay alive until Invoke() runs — confirm
  // against NDArray's handle ownership.
  input_ndarrays_.emplace_back(value.GetHandle());
  return *this;
}
template <>
inline Operator& Operator::SetParam<Symbol>(int /*pos*/,
                                            const Symbol &value) {
  // Symbol arguments become inputs for CreateSymbol(). The position argument
  // is ignored: handles are appended in call order.
  // NOTE(review): only the raw handle is kept, so `value` (or a copy sharing
  // its handle) presumably must stay alive until CreateSymbol() runs —
  // confirm against Symbol's handle ownership.
  input_symbols_.emplace_back(value.GetHandle());
  return *this;
}
inline OpMap*& Operator::op_map() {
  // Single OpMap shared by every Operator, built on first call (function-local
  // static initialization is thread-safe since C++11). It is never deleted
  // here — presumably intentional, to sidestep static destruction order.
  // A reference to the pointer itself is returned, so callers can also
  // reassign it.
  static OpMap *shared_map = new OpMap();
  return shared_map;
}
/*!
 * \brief Look up the atomic symbol creator for `operator_name` and record the
 *        names of its arguments in arg_names_.
 * \param operator_name registered operator name passed to OpMap.
 *
 * Fails a CHECK if no creator handle is returned or if the C API query for
 * the operator's argument info reports an error.
 */
inline Operator::Operator(const std::string &operator_name) {
  handle_ = op_map()->GetSymbolCreator(operator_name);
  // Fail fast instead of handing a null creator to the C API below.
  CHECK(handle_ != nullptr) << "Operator not found: " << operator_name;

  // Out-parameters required by MXSymbolGetAtomicSymbolInfo; only num_args
  // and arg_names are used afterwards.
  const char *name = nullptr;
  const char *description = nullptr;
  mx_uint num_args = 0;
  const char **arg_names = nullptr;
  const char **arg_type_infos = nullptr;
  const char **arg_descriptions = nullptr;
  const char *key_var_num_args = nullptr;
  // A non-zero return means the query failed and the outputs are not valid;
  // reading num_args/arg_names in that case would be undefined behavior.
  CHECK_EQ(MXSymbolGetAtomicSymbolInfo(handle_,
                                       &name,
                                       &description,
                                       &num_args,
                                       &arg_names,
                                       &arg_type_infos,
                                       &arg_descriptions,
                                       &key_var_num_args), 0)
      << "MXSymbolGetAtomicSymbolInfo failed for operator " << operator_name;

  // Append the argument names in the order reported by the C API.
  arg_names_.insert(arg_names_.end(), arg_names, arg_names + num_args);
}
/*!
 * \brief Build a symbolic node from the collected params and input symbols.
 * \param name node name; an empty string is passed to the C API as nullptr.
 * \return the composed Symbol, which owns the created handle.
 *
 * If any input keys were set, there must be exactly one key per input symbol.
 * Fails a CHECK if either C API call reports an error.
 */
inline Symbol Operator::CreateSymbol(const std::string &name) {
  if (!input_keys_.empty()) {
    CHECK_EQ(input_keys_.size(), input_symbols_.size());
  }
  const char *pname = name.empty() ? nullptr : name.c_str();

  // Parallel key/value C-string arrays; they borrow from params_, which
  // outlives the C API calls below.
  std::vector<const char *> param_keys;
  std::vector<const char *> param_values;
  param_keys.reserve(params_.size());
  param_values.reserve(params_.size());
  for (const auto &data : params_) {
    param_keys.push_back(data.first.c_str());
    param_values.push_back(data.second.c_str());
  }

  std::vector<const char *> input_keys;
  input_keys.reserve(input_keys_.size());
  for (const auto &key : input_keys_) {
    input_keys.push_back(key.c_str());
  }
  // No keys => positional composition (the C API receives nullptr).
  const char **input_keys_p = input_keys.empty() ? nullptr : input_keys.data();

  SymbolHandle symbol_handle = nullptr;
  CHECK_EQ(MXSymbolCreateAtomicSymbol(handle_, param_keys.size(),
                                      param_keys.data(), param_values.data(),
                                      &symbol_handle), 0)
      << "MXSymbolCreateAtomicSymbol failed";
  // Hand the handle to its owning Symbol before composing. The handle is never
  // freed here and is returned inside a Symbol, so if the compose check fails
  // the wrapper releases it instead of leaking it.
  Symbol symbol(symbol_handle);
  CHECK_EQ(MXSymbolCompose(symbol_handle, pname, input_symbols_.size(),
                           input_keys_p, input_symbols_.data()), 0)
      << "MXSymbolCompose failed";
  return symbol;
}
inline void Operator::Invoke(std::vector<NDArray> &outputs) {
if (input_keys_.size() > 0) {
CHECK_EQ(input_keys_.size(), input_ndarrays_.size());
}
std::vector<const char *> input_keys;
std::vector<const char *> param_keys;
std::vector<const char *> param_values;
for (auto &data : params_) {
param_keys.push_back(data.first.c_str());
param_values.push_back(data.second.c_str());
}
int num_inputs = input_ndarrays_.size();
int num_outputs = outputs.size();
std::vector<NDArrayHandle> output_handles;
std::transform(outputs.begin(), outputs.end(),
std::back_inserter(output_handles), [](NDArray& a) {
return a.GetHandle();
});
NDArrayHandle *outputs_receiver = nullptr;
if (num_outputs > 0) {
outputs_receiver = output_handles.data();
}
MXImperativeInvoke(handle_, num_inputs, input_ndarrays_.data(),
&num_outputs, &outputs_receiver,
param_keys.size(), param_keys.data(), param_values.data());
if (outputs.size() > 0)
return;
std::transform(outputs_receiver, outputs_receiver+num_outputs,
std::back_inserter(outputs), [](const NDArrayHandle& handle) {
return NDArray(handle);
});
}
/*!
 * \brief Run the operator and return the outputs it creates.
 * \return NDArrays wrapping the output handles produced by the call.
 */
inline std::vector<NDArray> Operator::Invoke() {
  // Starting from an empty vector makes Invoke(std::vector<NDArray>&) append
  // freshly created outputs instead of writing into existing ones.
  std::vector<NDArray> created;
  Invoke(created);
  return created;
}
/*!
 * \brief Run the operator, writing its single result into `output`.
 * \param output pre-existing array whose handle receives the result.
 */
inline void Operator::Invoke(NDArray &output) {
  // A one-element vector holding a copy of `output`. The result is written
  // through that copy's handle; this presumably relies on NDArray copies
  // sharing the same handle — confirm against ndarray.h.
  std::vector<NDArray> single_output(1, output);
  Invoke(single_output);
}
/*!
 * \brief Add a named input for CreateSymbol().
 * \param name keyword under which the input is composed.
 * \param symbol input symbol; only its handle is stored.
 * \return *this, for chaining.
 *
 * NOTE(review): `symbol` is a by-value copy destroyed on return, and only its
 * raw handle is kept. If Symbol frees its handle when the last copy dies, the
 * caller must keep its own Symbol alive until CreateSymbol() — confirm.
 */
inline Operator &Operator::SetInput(const std::string &name, Symbol symbol) {
  // Store the std::string directly. Going through name.c_str() rescanned the
  // string and would truncate it at an embedded '\0'.
  input_keys_.push_back(name);
  input_symbols_.push_back(symbol.GetHandle());
  return *this;
}
/*!
 * \brief Add a named input for Invoke().
 * \param name input key; Invoke() only checks the key count and passes inputs
 *        positionally, in the order they were added.
 * \param ndarray input array; only its handle is stored.
 * \return *this, for chaining.
 *
 * NOTE(review): `ndarray` is a by-value copy destroyed on return, and only its
 * raw handle is kept. If NDArray frees its handle when the last copy dies, the
 * caller must keep its own NDArray alive until Invoke() — confirm.
 */
inline Operator &Operator::SetInput(const std::string &name, NDArray ndarray) {
  // Store the std::string directly. Going through name.c_str() rescanned the
  // string and would truncate it at an embedded '\0'.
  input_keys_.push_back(name);
  input_ndarrays_.push_back(ndarray.GetHandle());
  return *this;
}
} // namespace cpp
} // namespace mxnet
#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_OPERATOR_HPP_