/*!
 * Copyright (c) 2015 by Contributors
 * \file custom-inl.h
 * \brief Custom operator that delegates computation to a frontend-defined implementation
 * \author Junyuan Xie
 */
#ifndef MXNET_OPERATOR_CUSTOM_INL_H_
#define MXNET_OPERATOR_CUSTOM_INL_H_
#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <mxnet/c_api.h>
#include <map>
#include <memory>
#include <vector>
#include <string>
#include <utility>
#include <sstream>
#include <thread>
#include <mutex>
#include <functional>
#include <condition_variable>
#include <queue>
#include "./operator_common.h"
namespace mxnet {
namespace op {
struct CustomOpParam {
  std::string op_type;
  std::vector<std::pair<std::string, std::string> > kwargs;
};
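// Illustrative example (not part of this file): a frontend call such as
// mx.symbol.Custom(data, op_type='my_op', alpha='0.1') would arrive here as
//   op_type = "my_op"
//   kwargs  = {{"alpha", "0.1"}}
// op_type selects the registered creator; the remaining key/value pairs are
// forwarded to it unchanged (see CustomOpProp::Init below).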
template<typename xpu>
class CustomOp : public Operator {
 public:
  explicit CustomOp(CustomOpInfo* op_info) {
    // Take ownership of the callback table; the deleter invokes the
    // frontend-provided destructor before freeing the struct.
    op_info_.reset(op_info, [](CustomOpInfo *ptr) {
        ptr->del(ptr->p_del);
        delete ptr;
      });
    if (std::string("NaiveEngine") == dmlc::GetEnv("MXNET_ENGINE_TYPE", std::string())) {
      // NaiveEngine executes synchronously, so callbacks can run inline.
      sync_mode_ = true;
    } else {
      // Otherwise frontend callbacks are dispatched on a dedicated worker
      // thread that drains q_ until the operator is destructed.
      sync_mode_ = false;
      destructing_ = false;
      worker_ = std::thread([&]() {
          std::unique_lock<std::mutex> lock(mtx_);
          while (!q_.empty() || !destructing_) {
            cv_.wait(lock, [&] { return !q_.empty() || destructing_; });
            while (!q_.empty()) {
              q_.front()();
              q_.pop();
            }
          }
        });
    }
  }
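  // Sketch (assumption, the definitions are not in this header): Forward and
  // Backward are expected to enqueue their work roughly as
  //   std::unique_lock<std::mutex> lock(mtx_);
  //   q_.push([=]() { /* call the op_info_ forward/backward callback, then signal completion */ });
  //   cv_.notify_all();
  // while in sync_mode_ the callback runs directly on the calling thread.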
  ~CustomOp() {
    if (!sync_mode_) {
      // Wake the worker so it can observe destructing_ and exit.
      {
        std::unique_lock<std::mutex> lock(mtx_);
        destructing_ = true;
        cv_.notify_all();
      }
      worker_.join();
    }
  }
  virtual void Forward(const OpContext &ctx,
                       const std::vector<TBlob> &in_data,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &out_data,
                       const std::vector<TBlob> &aux_args);
  virtual void Backward(const OpContext &ctx,
                        const std::vector<TBlob> &out_grad,
                        const std::vector<TBlob> &in_data,
                        const std::vector<TBlob> &out_data,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &in_grad,
                        const std::vector<TBlob> &aux_args);
  virtual ExecType exec_type() const {
    return kAsync;
  }
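  // kAsync tells the execution engine that Forward/Backward may return before
  // the computation finishes; the implementation is then expected to call
  // ctx.async_on_complete() once the frontend callback is done (assumption
  // based on the OpContext contract; the definitions live outside this header).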

 private:
  Context get_ctx();
  std::shared_ptr<CustomOpInfo> op_info_;
  std::mutex mtx_;
  std::condition_variable cv_;
  std::thread worker_;
  std::queue<std::function<void(void)> > q_;
  bool destructing_;
  bool sync_mode_;
};  // CustomOp
template<typename xpu>
Operator* CreateOp(CustomOpInfo *op_info);
class CustomOpProp : public OperatorProperty {
 public:
  static void Register(const std::string &op_type, CustomOpPropCreator creator) {
    if (registry_.find(op_type) != registry_.end()) {
      LOG(WARNING) << "New registration is overriding existing custom operator " << op_type;
    }
    registry_[op_type] = creator;
  }
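  // Illustrative registration path (assumption): a frontend reaches this
  // method through the C API, roughly
  //   extern "C" bool MyCreator(const char* op_type, const int num_kwargs,
  //                             const char** keys, const char** values,
  //                             CustomOpPropInfo* ret);
  //   MXCustomOpRegister("my_op", MyCreator);
  // after which Init() looks the creator up by its "op_type" kwarg.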
  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
    kwargs_ = kwargs;
    param_.op_type = "";
    param_.kwargs.clear();
    std::vector<const char*> keys, vals;
    for (auto &p : kwargs) {
      if (p.first == "op_type") {
        param_.op_type = p.second;
      } else {
        param_.kwargs.push_back(p);
        keys.push_back(p.first.c_str());
        vals.push_back(p.second.c_str());
      }
    }
    CHECK_NE(param_.op_type, "") << "Custom operator type missing";
    CHECK(registry_.find(param_.op_type) != registry_.end())
      << "Cannot find custom operator type " << param_.op_type;
    CustomOpPropCreator creator = registry_[param_.op_type];
    info_.reset(new CustomOpPropInfo, [](CustomOpPropInfo* ptr) { ptr->del(ptr->p_del); });
    CHECK(creator(param_.op_type.c_str(), keys.size(), keys.data(), vals.data(), info_.get()));
    num_inputs_ = ListArguments().size();
    num_outputs_ = ListOutputs().size();
    num_auxs_ = ListAuxiliaryStates().size();
  }
  std::vector<std::string> ListArguments() const override {
    char **args = NULL;
    CHECK(info_->list_arguments(&args, info_->p_list_arguments));
    std::vector<std::string> ret;
    for (int i = 0; args[i] != NULL; ++i) {
      ret.push_back(args[i]);
    }
    return ret;
  }
  std::vector<std::string> ListOutputs() const override {
    char **args = NULL;
    CHECK(info_->list_outputs(&args, info_->p_list_outputs));
    std::vector<std::string> ret;
    for (int i = 0; args[i] != NULL; ++i) {
      ret.push_back(args[i]);
    }
    return ret;
  }
  std::vector<std::string> ListAuxiliaryStates() const override {
    char **args = NULL;
    CHECK(info_->list_auxiliary_states(&args, info_->p_list_auxiliary_states));
    std::vector<std::string> ret;
    for (int i = 0; args[i] != NULL; ++i) {
      ret.push_back(args[i]);
    }
    return ret;
  }
  int NumOutputs() const override {
    return ListOutputs().size();
  }
  std::map<std::string, std::string> GetParams() const override {
    return std::map<std::string, std::string>(kwargs_.begin(), kwargs_.end());
  }
  bool InferShape(std::vector<TShape> *in_shape,
                  std::vector<TShape> *out_shape,
                  std::vector<TShape> *aux_shape) const override {
    std::vector<unsigned*> shapes;
    std::vector<int> ndims;
    for (auto iter = in_shape->begin(); iter != in_shape->end(); ++iter) {
      shapes.push_back(iter->data());
      ndims.push_back(iter->ndim());
    }
    shapes.resize(num_inputs_+num_outputs_+num_auxs_);
    ndims.resize(num_inputs_+num_outputs_+num_auxs_);
    CHECK(info_->infer_shape(shapes.size(), ndims.data(), shapes.data(), info_->p_infer_shape));
    for (unsigned i = 0; i < in_shape->size(); ++i) {
      SHAPE_ASSIGN_CHECK(*in_shape, i, TShape(shapes[i], shapes[i]+ndims[i]));
    }
    out_shape->clear();
    for (unsigned i = num_inputs_; i < num_inputs_+num_outputs_; ++i) {
      out_shape->push_back(TShape(shapes[i], shapes[i]+ndims[i]));
    }
    aux_shape->clear();
    for (unsigned i = num_inputs_+num_outputs_; i < shapes.size(); ++i) {
      aux_shape->push_back(TShape(shapes[i], shapes[i]+ndims[i]));
    }
    return true;
  }
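  // Worked example (illustrative): with num_inputs_ = 1, num_outputs_ = 1 and
  // num_auxs_ = 0, the frontend callback receives shapes/ndims of length 2 laid
  // out as [inputs | outputs | aux]; slot 0 carries the known input shape,
  // e.g. {64, 100}, and the callback is expected to fill slot 1 with the output
  // shape before InferShape copies it back into out_shape above.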
  OperatorProperty* Copy() const override {
    CustomOpProp *prop_sym = new CustomOpProp();
    prop_sym->Init(kwargs_);
    return prop_sym;
  }
  std::string TypeString() const override {
    return "Custom";
  }
  std::vector<int> DeclareBackwardDependency(
      const std::vector<int> &out_grad,
      const std::vector<int> &in_data,
      const std::vector<int> &out_data) const override {
    int num_dep;
    int *rdeps;
    CHECK(info_->declare_backward_dependency(out_grad.data(), in_data.data(),
                                             out_data.data(), &num_dep, &rdeps,
                                             info_->p_declare_backward_dependency));
    std::vector<int> deps;
    deps.insert(deps.end(), rdeps, rdeps+num_dep);
    return deps;
  }
  std::vector<std::pair<int, void*> > BackwardInplaceOption(
      const std::vector<int> &out_grad,
      const std::vector<int> &in_data,
      const std::vector<int> &out_data,
      const std::vector<void*> &in_grad) const override {
    return {};
  }
  // Custom operators need shape information to construct the frontend state,
  // so only the extended CreateOperatorEx entry point is supported.
  Operator* CreateOperator(Context ctx) const override {
    LOG(FATAL) << "Not Implemented.";
    return NULL;
  }
  Operator* CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
                             std::vector<int> *in_type) const override;

 private:
  static std::map<std::string, CustomOpPropCreator> registry_;
  CustomOpParam param_;
  std::shared_ptr<CustomOpPropInfo> info_;
  std::vector<std::pair<std::string, std::string> > kwargs_;
  unsigned num_inputs_, num_outputs_, num_auxs_;
};  // class CustomOpProp
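// Note (assumption): the definitions of CustomOp<xpu>::Forward/Backward,
// CreateOp<xpu>, CustomOpProp::CreateOperatorEx and the static registry_ are
// expected to live in the accompanying source files (e.g. custom.cc/custom.cu),
// where the operator is registered under the "Custom" type string.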
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_CUSTOM_INL_H_