/*!
 * Copyright (c) 2015 by Contributors
 * \file custom.cc
 * \brief implementation of custom operator
 * \author Junyuan Xie
 */
#include "./custom-inl.h"
#include <mxnet/base.h>
#include <mxnet/ndarray.h>
namespace mxnet {
namespace op {
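// Static registry mapping op_type names to the creator callbacks that
// frontends register for their custom operators.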
std::map<std::string, CustomOpPropCreator> CustomOpProp::registry_;
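// On CPU the wrapped NDArrays live in the default CPU context.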
template<>
Context CustomOp<cpu>::get_ctx() {
  return Context::CPU();
}
template<>
Operator *CreateOp<cpu>(CustomOpInfo *op_info) {
  return new CustomOp<cpu>(op_info);
}
#if MXNET_USE_CUDA
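// On GPU, query the current CUDA device so the wrapped NDArrays are
// created on the same device the operator runs on.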
template<>
Context CustomOp<gpu>::get_ctx() {
  int dev_id;
  CHECK_EQ(cudaGetDevice(&dev_id), cudaSuccess);
  return Context::GPU(dev_id);
}
template<>
Operator* CreateOp<gpu>(CustomOpInfo *op_info) {
  return new CustomOp<gpu>(op_info);
}
#endif  // MXNET_USE_CUDA
template<typename xpu>
void CustomOp<xpu>::Forward(const OpContext &ctx,
                            const std::vector<TBlob> &in_data,
                            const std::vector<OpReqType> &req,
                            const std::vector<TBlob> &out_data,
                            const std::vector<TBlob> &aux_args) {
  using namespace mshadow;
  Context ndctx = get_ctx();
  std::vector<void*> ptrs;
  std::vector<NDArray> ndcpy;
  std::vector<Engine::VarHandle> ndvar;
  std::vector<int> tags;
  std::vector<int> reqs(req.begin(), req.end());
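  // Wrap each TBlob in a heap-allocated NDArray so it can be handed to the
  // frontend through the C interface. The parallel tags array tells the
  // frontend what each pointer is: 0 = input, 1 = output, 2 = input gradient,
  // 3 = output gradient, 4 = auxiliary state.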
  for (auto& blob : in_data) {
    ptrs.push_back(reinterpret_cast<void*>(new NDArray(blob, ndctx.dev_id)));
    tags.push_back(0);
  }
  for (auto& blob : out_data) {
    NDArray* nd = new NDArray(blob, ndctx.dev_id);
    ptrs.push_back(reinterpret_cast<void*>(nd));
    ndcpy.push_back(*nd);
    ndvar.push_back(nd->var());
    tags.push_back(1);
  }
  for (auto& blob : aux_args) {
    NDArray* nd = new NDArray(blob, ndctx.dev_id);
    ptrs.push_back(reinterpret_cast<void*>(nd));
    ndcpy.push_back(*nd);
    ndvar.push_back(nd->var());
    tags.push_back(4);
  }
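  // Outputs and auxiliary states may share engine variables, so sort and
  // deduplicate ndvar before it is used as a dependency list in PushSync below.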
  std::sort(ndvar.begin(), ndvar.end());
  ndvar.resize(std::unique(ndvar.begin(), ndvar.end()) - ndvar.begin());
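  // The frontend callback. Depending on sync_mode_ it either runs inline or
  // is queued for the callback thread (see below).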
  auto compute = [=]() mutable {
    CHECK(op_info_->forward(ptrs.size(),
                            ptrs.data(), tags.data(),
                            reqs.data(),
                            ctx.is_train,
                            op_info_->p_forward));
    // The NDArray* pointers in ptrs are freed on the frontend side. We keep
    // copies in ndcpy so the engine variables in ndvar stay alive until the
    // async operation is marked complete.
    Engine::Get()->PushSync([ndcpy, ctx](RunContext rctx) {
        ctx.async_on_complete();
      }, ndctx, ndvar, {},
      FnProperty::kNormal, 0, PROFILER_MESSAGE("CustomOpForward"));
  };
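  // Run the callback inline in sync mode; otherwise enqueue it and wake the
  // worker thread that drains q_.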
  if (sync_mode_) {
    compute();
  } else {
    std::unique_lock<std::mutex> lock(mtx_);
    q_.push(compute);
    cv_.notify_all();
  }
}
template<typename xpu>
void CustomOp<xpu>::Backward(const OpContext &ctx,
                             const std::vector<TBlob> &out_grad,
                             const std::vector<TBlob> &in_data,
                             const std::vector<TBlob> &out_data,
                             const std::vector<OpReqType> &req,
                             const std::vector<TBlob> &in_grad,
                             const std::vector<TBlob> &aux_args) {
  using namespace mshadow;
  Context ndctx = get_ctx();
  std::vector<void*> ptrs;
  std::vector<NDArray> ndcpy;
  std::vector<Engine::VarHandle> ndvar;
  std::vector<int> tags;
  std::vector<int> reqs(req.begin(), req.end());
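  // As in Forward, wrap every TBlob in an NDArray and tag its role; only the
  // arrays written here (in_grad, tag 2, and aux_args, tag 4) need their
  // engine variables tracked.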
  for (auto& blob : in_data) {
    ptrs.push_back(reinterpret_cast<void*>(new NDArray(blob, ndctx.dev_id)));
    tags.push_back(0);
  }
  for (auto& blob : out_data) {
    ptrs.push_back(reinterpret_cast<void*>(new NDArray(blob, ndctx.dev_id)));
    tags.push_back(1);
  }
  for (auto& blob : in_grad) {
    NDArray* nd = new NDArray(blob, ndctx.dev_id);
    ptrs.push_back(reinterpret_cast<void*>(nd));
    ndcpy.push_back(*nd);
    ndvar.push_back(nd->var());
    tags.push_back(2);
  }
  for (auto& blob : aux_args) {
    NDArray* nd = new NDArray(blob, ndctx.dev_id);
    ptrs.push_back(reinterpret_cast<void*>(nd));
    ndcpy.push_back(*nd);
    ndvar.push_back(nd->var());
    tags.push_back(4);
  }
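  // Sort and deduplicate the engine variables before they are used as a
  // dependency list.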
  std::sort(ndvar.begin(), ndvar.end());
  ndvar.resize(std::unique(ndvar.begin(), ndvar.end()) - ndvar.begin());
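  // Output gradients (tag 3) are read-only inputs to the backward callback,
  // so no engine variables are recorded for them.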
  for (auto& blob : out_grad) {
    ptrs.push_back(reinterpret_cast<void*>(new NDArray(blob, ndctx.dev_id)));
    tags.push_back(3);
  }
  auto compute = [=]() mutable {
    CHECK(op_info_->backward(ptrs.size(),
                             ptrs.data(),
                             tags.data(),
                             reqs.data(),
                             true,  // backward is only invoked during training
                             op_info_->p_backward));
    // The NDArray* pointers in ptrs are freed on the frontend side. We keep
    // copies in ndcpy so the engine variables in ndvar stay alive until the
    // async operation is marked complete.
    Engine::Get()->PushSync([ndcpy, ctx](RunContext rctx) {
        ctx.async_on_complete();
      }, ndctx, ndvar, {},
      FnProperty::kNormal, 0, PROFILER_MESSAGE("CustomOpBackward"));
  };
  if (sync_mode_) {
    compute();
  } else {
    std::unique_lock<std::mutex> lock(mtx_);
    q_.push(compute);
    cv_.notify_all();
  }
}
Operator* CustomOpProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
                                         std::vector<int> *in_type) const {
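  // Flatten the input shapes into raw pointer/ndim arrays for the C
  // callback interface.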
  std::vector<unsigned*> shapes;
  std::vector<int> ndims;
  for (auto iter = in_shape->begin(); iter != in_shape->end(); ++iter) {
    shapes.push_back(iter->data());
    ndims.push_back(iter->ndim());
  }
  std::string str_ctx;
  if (ctx.dev_mask() == cpu::kDevMask) {
    str_ctx = "cpu";
  } else {
    str_ctx = "gpu";
  }
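  // Ask the frontend to create the operator state for this context, then
  // dispatch to the CPU or GPU CreateOp specialization.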
  CustomOpInfo *op_info = new CustomOpInfo;
  CHECK(info_->create_operator(str_ctx.c_str(), shapes.size(), shapes.data(),
                               ndims.data(), in_type->data(), op_info,
                               info_->p_create_operator));
  DO_BIND_DISPATCH(CreateOp, op_info);
}
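// Register the generic "Custom" operator; the concrete computation is
// supplied by the frontend implementation selected via op_type.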
MXNET_REGISTER_OP_PROPERTY(Custom, CustomOpProp)
.describe("Custom operator implemented in frontend.")
.add_argument("op_type", "string", "Type of custom operator. Must be registered first.");
} // namespace op
} // namespace mxnet