blob: 773fe7753930f9c4e2a1089a72ed84f59ec70b09 [file] [log] [blame]
/*!
* Copyright (c) 2015 by Contributors
* \file ndarray_op.cc
* \brief NDArrayOp: an operator whose forward/backward are implemented in a native frontend language via NDArray callbacks.
* \author Junyuan Xie
*/
#include "./ndarray_op-inl.h"
#include <mxnet/base.h>
#include <mxnet/ndarray.h>
namespace mxnet {
namespace op {
// CPU specialization: native-frontend callbacks for this op always run on the CPU context.
template<>
Context NDArrayOp<cpu>::get_ctx() {
return Context::CPU();
}
// CPU factory used by DO_BIND_DISPATCH in NDArrayOpProp::CreateOperator.
// Ownership of the returned Operator passes to the caller (framework convention).
template<>
Operator *CreateOp<cpu>(NDArrayOpParam param) {
return new NDArrayOp<cpu>(param);
}
#if MXNET_USE_CUDA
// GPU specialization: bind the op to whichever CUDA device is current on this thread.
template<>
Context NDArrayOp<gpu>::get_ctx() {
int dev_id;
// Fail fast if the CUDA runtime cannot report the active device.
CHECK_EQ(cudaGetDevice(&dev_id), cudaSuccess);
return Context::GPU(dev_id);
}
// GPU factory, mirror of the CPU one above; caller takes ownership of the Operator.
template<>
Operator* CreateOp<gpu>(NDArrayOpParam param) {
return new NDArrayOp<gpu>(param);
}
#endif // MXNET_USE_CUDA
/*!
 * \brief Run the frontend-supplied forward callback.
 *
 * Wraps each TBlob in a heap-allocated NDArray and passes the raw pointers,
 * with per-pointer tags (0 = input, 1 = output), to param_.pinfo->forward.
 * Completion is signalled asynchronously once the engine has processed the
 * writes to the output arrays.
 *
 * \param ctx      operation context; ctx.async_on_complete is invoked when done
 * \param in_data  input blobs (tag 0)
 * \param req      write request types; kAddTo is not supported by this op
 * \param out_data output blobs (tag 1)
 * \param aux_args auxiliary states (unused here)
 */
template<typename xpu>
void NDArrayOp<xpu>::Forward(const OpContext &ctx,
const std::vector<TBlob> &in_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &out_data,
const std::vector<TBlob> &aux_args) {
using namespace mshadow;
Context ndctx = get_ctx();
// Raw NDArray* handed to the frontend callback; tags[i] classifies ptrs[i].
std::vector<void*> ptrs;
// Engine variables written by this op (outputs only).
std::vector<Engine::VarHandle> ndvar;
std::vector<int> tags;
// The frontend contract cannot accumulate into existing memory.
for (auto& i : req) CHECK_NE(i, kAddTo);
// NOTE(review): these NDArray wrappers are allocated with `new` and never
// deleted here — presumably the frontend takes ownership through the callback
// and frees them via the C API; confirm, otherwise this leaks per call.
for (auto& blob : in_data) {
ptrs.push_back(reinterpret_cast<void*>(new NDArray(blob, ndctx.dev_id)));
tags.push_back(0);
}
for (auto& blob : out_data) {
NDArray* nd = new NDArray(blob, ndctx.dev_id);
ptrs.push_back(reinterpret_cast<void*>(nd));
ndvar.push_back(nd->var());
tags.push_back(1);
}
// Deduplicate write vars so the same engine variable is not requested twice.
std::sort(ndvar.begin(), ndvar.end());
ndvar.resize(std::unique(ndvar.begin(), ndvar.end()) - ndvar.begin());
// Value copies keep the underlying data chunks alive until the async
// completion lambda below has run.
std::vector<NDArray> ndcpy;
for (auto& i : ptrs) {
ndcpy.push_back(*reinterpret_cast<NDArray*>(i));
}
// Invoke the user's forward implementation; a false return aborts.
CHECK(param_.pinfo->forward(ptrs.size(), ptrs.data(), tags.data(), param_.pinfo->p_forward));
// Signal completion only after the engine has acquired the output vars;
// capturing ndcpy extends the arrays' lifetime until that point.
Engine::Get()->PushSync([ndcpy, ctx](RunContext rctx) {ctx.async_on_complete(); },
ndctx, ndvar, {}, FnProperty::kNormal, 0,
PROFILER_MESSAGE("NDArrayOpForward"));
}
/*!
 * \brief Run the frontend-supplied backward callback.
 *
 * Same pattern as Forward: each TBlob is wrapped in a heap-allocated NDArray
 * and the raw pointers are passed to param_.pinfo->backward with tags
 * 0 = in_data, 1 = out_data, 2 = in_grad, 3 = out_grad.
 *
 * \param ctx      operation context; ctx.async_on_complete is invoked when done
 * \param out_grad gradients w.r.t. outputs (tag 3)
 * \param in_data  forward inputs (tag 0)
 * \param out_data forward outputs (tag 1)
 * \param req      write request types; kAddTo is not supported by this op
 * \param in_grad  gradients w.r.t. inputs, written by the callback (tag 2)
 * \param aux_args auxiliary states (unused here)
 */
template<typename xpu>
void NDArrayOp<xpu>::Backward(const OpContext &ctx,
const std::vector<TBlob> &out_grad,
const std::vector<TBlob> &in_data,
const std::vector<TBlob> &out_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &in_grad,
const std::vector<TBlob> &aux_args) {
using namespace mshadow;
Context ndctx = get_ctx();
// Raw NDArray* handed to the frontend callback; tags[i] classifies ptrs[i].
std::vector<void*> ptrs;
// Engine variables written by this op (the in_grad arrays only).
std::vector<Engine::VarHandle> ndvar;
std::vector<int> tags;
// The frontend contract cannot accumulate into existing memory.
for (auto& i : req) CHECK_NE(i, kAddTo);
// NOTE(review): as in Forward, the `new` NDArray wrappers are never deleted
// here — presumably ownership transfers to the frontend; confirm.
for (auto& blob : in_data) {
ptrs.push_back(reinterpret_cast<void*>(new NDArray(blob, ndctx.dev_id)));
tags.push_back(0);
}
for (auto& blob : out_data) {
ptrs.push_back(reinterpret_cast<void*>(new NDArray(blob, ndctx.dev_id)));
tags.push_back(1);
}
for (auto& blob : in_grad) {
NDArray* nd = new NDArray(blob, ndctx.dev_id);
ptrs.push_back(reinterpret_cast<void*>(nd));
ndvar.push_back(nd->var());
tags.push_back(2);
}
// Deduplicate write vars so the same engine variable is not requested twice.
std::sort(ndvar.begin(), ndvar.end());
ndvar.resize(std::unique(ndvar.begin(), ndvar.end()) - ndvar.begin());
// out_grad is appended after the dedup on purpose: its arrays are read-only
// here, so their vars are never collected into ndvar.
for (auto& blob : out_grad) {
ptrs.push_back(reinterpret_cast<void*>(new NDArray(blob, ndctx.dev_id)));
tags.push_back(3);
}
// Value copies keep the underlying data chunks alive until the async
// completion lambda below has run.
std::vector<NDArray> ndcpy;
for (auto& i : ptrs) {
ndcpy.push_back(*reinterpret_cast<NDArray*>(i));
}
// Invoke the user's backward implementation; a false return aborts.
CHECK(param_.pinfo->backward(ptrs.size(), ptrs.data(), tags.data(), param_.pinfo->p_backward));
// Signal completion only after the engine has acquired the in_grad vars;
// capturing ndcpy extends the arrays' lifetime until that point.
Engine::Get()->PushSync([ndcpy, ctx](RunContext rctx){ ctx.async_on_complete(); },
ndctx, ndvar, {}, FnProperty::kNormal, 0,
PROFILER_MESSAGE("NDArrayOpBackward"));
}
// Dispatch to CreateOp<cpu> or CreateOp<gpu> based on ctx.dev_mask
// (macro expands to the device switch); caller owns the returned Operator.
Operator* NDArrayOpProp::CreateOperator(Context ctx) const {
DO_BIND_DISPATCH(CreateOp, param_);
}
// Register the parameter struct so its fields can be parsed/validated by dmlc.
DMLC_REGISTER_PARAMETER(NDArrayOpParam);
// Register the operator under the internal name "_NDArray".
MXNET_REGISTER_OP_PROPERTY(_NDArray, NDArrayOpProp)
.describe("Stub for implementing an operator implemented in native frontend language with ndarray.")
.add_arguments(NDArrayOpParam::__FIELDS__());
} // namespace op
} // namespace mxnet