/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
/*!
 * Copyright (c) 2015 by Contributors
 * \file custom.cc
 * \brief Custom operator implementation: bridges NNVM-registered ops to
 *        callbacks supplied by a frontend (e.g. Python).
 * \author Junyuan Xie
 */
#include "./custom-inl.h"
#include <mxnet/base.h>
#include <mxnet/ndarray.h>
#include <sstream>
#include <unordered_set>
#include "../elemwise_op_common.h"
#include "../operator_common.h"
namespace mxnet {
namespace op {
namespace custom {
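/*! \brief Parsed attributes of a Custom node: the registered op name, the
 * argument/output/auxiliary-state counts reported by the frontend, the
 * backward dependency indices, and the frontend callback table.
 */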
struct CustomParam {
  std::string op_type;
  size_t num_args, num_outs, num_auxs;
  std::vector<int> bwd_idx;
  std::shared_ptr<MXCallbackList> info;
};
/*! \brief Allocate a new NDArray that shares the memory of an existing one,
 * preserving its storage type (dense, row-sparse, or CSR).
 */
inline void AllocateNDArrayCopy(NDArray** nd,
                                const std::vector<NDArray>& inputs,
                                size_t idx, int dev_id) {
  std::vector<TBlob> aux;
  NDArrayStorageType stype = inputs[idx].storage_type();
  switch (stype) {
    case kUndefinedStorage:
    case kDefaultStorage:
      *nd = new NDArray(inputs[idx].data(), dev_id);
      break;
    case kRowSparseStorage:
      // Row-sparse arrays carry one auxiliary array: the row indices.
      aux.push_back(inputs[idx].aux_data(rowsparse::kIdx));
      *nd = new NDArray(stype, inputs[idx].shape(), inputs[idx].data(), aux,
                        dev_id);
      break;
    case kCSRStorage:
      // CSR arrays carry two auxiliary arrays: row pointers and column indices.
      aux.push_back(inputs[idx].aux_data(csr::kIndPtr));
      aux.push_back(inputs[idx].aux_data(csr::kIdx));
      *nd = new NDArray(stype, inputs[idx].shape(), inputs[idx].data(), aux,
                        dev_id);
      break;
  }
}
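/*! \brief Query a list of names (arguments, outputs, or auxiliary states,
 * depending on the callback \p Type) from the frontend. The callback returns
 * a nullptr-terminated array of C strings.
 */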
template<CustomOpPropCallbacks Type>
std::vector<std::string> List(const NodeAttrs& attrs) {
  const CustomParam& params = nnvm::get<CustomParam>(attrs.parsed);
  char** args = nullptr;
  CHECK(reinterpret_cast<CustomOpListFunc>(
      params.info->callbacks[Type])(
          &args, params.info->contexts[Type]));
  std::vector<std::string> ret;
  for (int i = 0; args[i] != nullptr; ++i) {
    ret.emplace_back(args[i]);
  }
  return ret;
}
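/*! \brief Parse node attributes into a CustomParam: extract `op_type`, forward
 * the remaining key/value pairs to the frontend creator, and record which
 * inputs the backward pass depends on.
 */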
void AttrParser(NodeAttrs* attrs) {
  attrs->parsed = CustomParam();
  CustomParam& params = nnvm::get<CustomParam>(attrs->parsed);
  std::vector<const char*> keys, vals;
  for (auto& p : attrs->dict) {
    if (p.first == "op_type") {
      params.op_type = p.second;
    } else {
      keys.push_back(p.first.c_str());
      vals.push_back(p.second.c_str());
    }
  }
  CHECK(!params.op_type.empty()) << "Required argument `op_type` is missing.";
  CustomOpPropCreator creator = CustomOperator::Get()->Find(params.op_type);
  CHECK(creator != nullptr)
      << "Cannot find custom operator " << params.op_type;
  params.info.reset(new MXCallbackList, [](MXCallbackList* ptr) {
    reinterpret_cast<CustomOpDelFunc>(ptr->callbacks[kCustomOpPropDelete])(
        ptr->contexts[kCustomOpPropDelete]);
    delete ptr;
  });
  CHECK(creator(params.op_type.c_str(), keys.size(), keys.data(),
                vals.data(), params.info.get()));
  params.num_args = List<kCustomOpPropListArguments>(*attrs).size();
  params.num_outs = List<kCustomOpPropListOutputs>(*attrs).size();
  params.num_auxs = List<kCustomOpPropListAuxiliaryStates>(*attrs).size();

  // Ask the frontend which of [out_grad | in_data | out_data] the backward
  // pass actually needs; indices are laid out in that order.
  int num_dep, *rdeps, counter = 0;
  std::vector<int> out_grad, in_data, out_data;
  for (size_t i = 0; i < params.num_outs; ++i) out_grad.push_back(counter++);
  for (size_t i = 0; i < params.num_args; ++i) in_data.push_back(counter++);
  for (size_t i = 0; i < params.num_outs; ++i) out_data.push_back(counter++);
  CHECK(reinterpret_cast<CustomOpBwdDepFunc>(
      params.info->callbacks[kCustomOpPropDeclareBackwardDependency])(
          out_grad.data(), in_data.data(), out_data.data(), &num_dep,
          &rdeps, params.info->contexts[kCustomOpPropDeclareBackwardDependency]));
  params.bwd_idx.insert(params.bwd_idx.end(), rdeps, rdeps + num_dep);
}
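/*! \brief Shape inference: flatten the known input shapes into a uint32
 * buffer, let the frontend fill in the rest, then copy the results back to
 * the argument, output, and auxiliary-state slots.
 */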
bool InferShape(const NodeAttrs& attrs,
                mxnet::ShapeVector* in_shape,
                mxnet::ShapeVector* out_shape) {
  const CustomParam& params = nnvm::get<CustomParam>(attrs.parsed);

  size_t total = params.num_args + params.num_outs + params.num_auxs;
  std::vector<uint32_t*> shapes(total);
  std::vector<int> ndims(total);
  size_t buff_size = 0;
  for (const auto& i : *in_shape) buff_size += i.ndim();
  std::vector<uint32_t> buff(buff_size);
  uint32_t* ptr = buff.data();
  for (size_t i = 0; i < in_shape->size(); ++i) {
    shapes[i] = ptr;
    ndims[i] = (*in_shape)[i].ndim();
    for (int j = 0; j < (*in_shape)[i].ndim(); ++j, ++ptr) {
      *ptr = static_cast<uint32_t>((*in_shape)[i][j]);
    }
  }
  CHECK(reinterpret_cast<CustomOpInferShapeFunc>(
      params.info->callbacks[kCustomOpPropInferShape])(
          shapes.size(), ndims.data(), shapes.data(),
          params.info->contexts[kCustomOpPropInferShape]));
  for (size_t i = 0; i < params.num_args; ++i) {
    SHAPE_ASSIGN_CHECK(*in_shape, i, mxnet::TShape(shapes[i], shapes[i] + ndims[i]));
  }
  size_t base = params.num_args;
  for (size_t i = 0; i < params.num_outs; ++i) {
    SHAPE_ASSIGN_CHECK(*out_shape, i,
                       mxnet::TShape(shapes[base + i], shapes[base + i] + ndims[base + i]));
  }
  base = params.num_args + params.num_outs;
  for (size_t i = 0; i < params.num_auxs; ++i) {
    SHAPE_ASSIGN_CHECK(*in_shape, params.num_args + i,
                       mxnet::TShape(shapes[base + i], shapes[base + i] + ndims[base + i]));
  }
  return true;
}
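/*! \brief Type inference: pack argument, output, and aux dtypes into one
 * vector ordered [args | outs | auxs], hand it to the frontend callback, and
 * copy the inferred dtypes back. Falls back to elementwise type propagation
 * when the frontend registers no infer-type callback.
 */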
bool InferType(const NodeAttrs& attrs,
               std::vector<int>* in_type,
               std::vector<int>* out_type) {
  const CustomParam& params = nnvm::get<CustomParam>(attrs.parsed);

  if (params.info->num_callbacks <= kCustomOpPropInferType) {
    return ElemwiseAttr<int, type_is_none, type_assign, true, type_string>(
        attrs, in_type, out_type, -1);
  }

  std::vector<int> types;
  types.reserve(params.num_args + params.num_outs + params.num_auxs);
  for (size_t i = 0; i < params.num_args; ++i) {
    types.push_back((*in_type)[i]);
  }
  for (const auto& i : *out_type) {
    types.push_back(i);
  }
  for (size_t i = 0; i < params.num_auxs; ++i) {
    types.push_back((*in_type)[params.num_args + i]);
  }
  CHECK(reinterpret_cast<CustomOpInferTypeFunc>(
      params.info->callbacks[kCustomOpPropInferType])(
          types.size(), types.data(), params.info->contexts[kCustomOpPropInferType]));
  for (size_t i = 0; i < params.num_args; ++i) {
    TYPE_ASSIGN_CHECK(*in_type, i, types[i]);
  }
  for (size_t i = 0; i < params.num_outs; ++i) {
    TYPE_ASSIGN_CHECK(*out_type, i, types[params.num_args + i]);
  }
  for (size_t i = 0; i < params.num_auxs; ++i) {
    TYPE_ASSIGN_CHECK(*in_type, params.num_args + i,
                      types[params.num_args + params.num_outs + i]);
  }
  return true;
}
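/*! \brief Build the backward graph: a `_backward_Custom` node whose inputs
 * are the subset of [out_grads | in_data | out_data] recorded in bwd_idx,
 * followed by the auxiliary states. Auxiliary states receive no gradient.
 */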
std::vector<nnvm::NodeEntry> Gradient(
    const nnvm::NodePtr& n,
    const std::vector<nnvm::NodeEntry>& out_grads) {
  const CustomParam& params = nnvm::get<CustomParam>(n->attrs.parsed);

  nnvm::NodePtr g = nnvm::Node::Create();
  g->attrs.op = nnvm::Op::Get("_backward_Custom");
  g->attrs.name = n->attrs.name;
  g->attrs.parsed = params;
  g->control_deps.emplace_back(n);

  // bwd_idx indexes the concatenation [out_grads | in_data | out_data].
  g->inputs.reserve(params.bwd_idx.size());
  for (const int& t : params.bwd_idx) {
    size_t i = static_cast<size_t>(t);
    if (i >= params.num_outs + params.num_args) {
      // out_data: an output of the forward node itself.
      uint32_t idx = static_cast<uint32_t>(i - params.num_outs - params.num_args);
      g->inputs.push_back(nnvm::NodeEntry{n, idx, 0});
    } else if (i >= params.num_outs) {
      // in_data: an input of the forward node.
      g->inputs.push_back(n->inputs[i - params.num_outs]);
    } else {
      // out_grad.
      g->inputs.push_back(out_grads[i]);
    }
  }
  for (size_t i = 0; i < params.num_auxs; ++i) {
    g->inputs.push_back(n->inputs[i + params.num_args]);
  }

  std::vector<nnvm::NodeEntry> ret;
  for (size_t i = 0; i < params.num_args; ++i) {
    ret.emplace_back(nnvm::NodeEntry{g, static_cast<uint32_t>(i), 0});
  }
  if (params.num_auxs) {
    nnvm::NodePtr ng = nnvm::Node::Create();
    ng->attrs.op = nnvm::Op::Get("_NoGradient");
    ng->attrs.name = "NoGradient";
    for (size_t i = 0; i < params.num_auxs; ++i) {
      ret.emplace_back(nnvm::NodeEntry{ng, 0, 0});
    }
  }
  return ret;
}
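/*! \brief Create per-instance operator state by invoking the frontend's
 * create-operator callback with the context string, input shapes, and input
 * dtypes. The returned callback list is owned by the state and torn down
 * via kCustomOpDelete.
 */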
OpStatePtr CreateState(const NodeAttrs& attrs, Context ctx,
                       const mxnet::ShapeVector& in_shape,
                       const std::vector<int>& in_type) {
  const CustomParam& params = nnvm::get<CustomParam>(attrs.parsed);

  std::vector<uint32_t*> shapes(in_shape.size());
  std::vector<int> ndims(in_shape.size());
  size_t buff_size = 0;
  for (const auto& i : in_shape) buff_size += i.ndim();
  std::vector<uint32_t> buff(buff_size);
  uint32_t* ptr = buff.data();
  for (size_t i = 0; i < in_shape.size(); ++i) {
    shapes[i] = ptr;
    ndims[i] = in_shape[i].ndim();
    for (int j = 0; j < in_shape[i].ndim(); ++j, ++ptr) {
      *ptr = static_cast<uint32_t>(in_shape[i][j]);
    }
  }

  std::ostringstream os;
  os << ctx;

  MXCallbackList* op_info = new MXCallbackList;
  CHECK(reinterpret_cast<CustomOpCreateFunc>(
      params.info->callbacks[kCustomOpPropCreateOperator])(
          os.str().c_str(), shapes.size(), shapes.data(), ndims.data(), in_type.data(),
          op_info, params.info->contexts[kCustomOpPropCreateOperator]));

  CustomParam state = params;
  state.info.reset(op_info, [](MXCallbackList* ptr) {
    reinterpret_cast<CustomOpDelFunc>(ptr->callbacks[kCustomOpDelete])(
        ptr->contexts[kCustomOpDelete]);
    delete ptr;
  });
  return OpStatePtr::Create<CustomParam>(state);
}
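/*! \brief Forward execution: wrap inputs, outputs, and aux states in
 * heap-allocated NDArray views, tag each position, and push the frontend's
 * forward callback onto the CustomOperator worker queue (the op runs
 * asynchronously, hence ExecType::kAsync below).
 */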
void ForwardEx(const OpStatePtr& state, const OpContext& ctx,
               const std::vector<NDArray>& inputs,
               const std::vector<OpReqType>& req,
               const std::vector<NDArray>& outputs) {
  const CustomParam& params = state.get_state<CustomParam>();
  std::vector<void*> ptrs;
  std::vector<int> tags;
  std::vector<NDArray> cpys;

  // Tags are passed to the callback to tell the frontend what kind of
  // ndarray is at each position in the pointer vector:
  //   0 - input
  //   1 - output
  //   4 - aux
  std::unordered_set<int> input_tags({0, 4});
  std::unordered_set<int> output_tags({1});

  auto dev_id = ctx.run_ctx.ctx.dev_id;
  for (size_t i = 0; i < params.num_args; ++i) {
    NDArray* nd;
    AllocateNDArrayCopy(&nd, inputs, i, dev_id);
    cpys.push_back(*nd);
    ptrs.push_back(reinterpret_cast<void*>(nd));
    tags.push_back(0);
  }
  for (size_t i = 0; i < params.num_outs; ++i) {
    NDArray* nd;
    AllocateNDArrayCopy(&nd, outputs, i, dev_id);
    cpys.push_back(*nd);
    ptrs.push_back(reinterpret_cast<void*>(nd));
    tags.push_back(1);
  }
  for (size_t i = 0; i < params.num_auxs; ++i) {
    size_t idx = i + params.num_args;
    NDArray* nd;
    AllocateNDArrayCopy(&nd, inputs, idx, dev_id);
    cpys.push_back(*nd);
    ptrs.push_back(reinterpret_cast<void*>(nd));
    tags.push_back(4);
  }
  CustomOperator::Get()->Push(
      [=]() {
        CHECK(reinterpret_cast<CustomOpFBFunc>(
            params.info->callbacks[kCustomOpForward])(
                ptrs.size(), const_cast<void**>(ptrs.data()),
                const_cast<int*>(tags.data()),
                reinterpret_cast<const int*>(req.data()),
                static_cast<int>(ctx.is_train),
                params.info->contexts[kCustomOpForward]));
      },
      ctx, false, ctx.is_train, cpys, tags, output_tags, outputs);
}
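/*! \brief Backward execution: assemble the pointer vector in the fixed
 * layout [out_grads | in_data | out_data | in_grads | auxs], filling only
 * the slots named in bwd_idx (the rest become empty placeholder NDArrays),
 * then push the frontend's backward callback onto the worker queue.
 */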
void BackwardEx(const OpStatePtr& state, const OpContext& ctx,
                const std::vector<NDArray>& inputs,
                const std::vector<OpReqType>& req,
                const std::vector<NDArray>& outputs) {
  const CustomParam& params = state.get_state<CustomParam>();

  size_t total = 2 * params.num_args + 2 * params.num_outs + params.num_auxs;
  std::vector<void*> ptrs(params.num_args + 2 * params.num_outs, nullptr);
  std::vector<int> tags;
  std::vector<NDArray> cpys;
  ptrs.reserve(total);
  tags.reserve(total);
  cpys.reserve(total);

  // Tags tell the frontend what kind of ndarray is at each position in the
  // pointer vector:
  //   3 - out grads
  //   0 - inputs
  //   1 - outputs
  //   4 - auxs
  //   2 - in grads
  std::unordered_set<int> input_tags({3, 0, 1, 4});
  std::unordered_set<int> output_tags({2});
  for (size_t i = 0; i < params.num_outs; ++i) tags.push_back(3);
  for (size_t i = 0; i < params.num_args; ++i) tags.push_back(0);
  for (size_t i = 0; i < params.num_outs; ++i) tags.push_back(1);

  auto dev_id = ctx.run_ctx.ctx.dev_id;
  // Scatter the available backward dependencies into their declared slots.
  for (size_t i = 0; i < params.bwd_idx.size(); ++i) {
    NDArray* nd;
    AllocateNDArrayCopy(&nd, inputs, i, dev_id);
    cpys.push_back(*nd);
    ptrs[params.bwd_idx[i]] = reinterpret_cast<void*>(nd);
  }
  // Slots the backward pass declared no dependency on get empty placeholders.
  for (auto& ptr : ptrs) {
    if (ptr == nullptr) {
      ptr = reinterpret_cast<void*>(new NDArray());
    }
  }
  for (size_t i = 0; i < outputs.size(); ++i) {
    NDArray* nd;
    AllocateNDArrayCopy(&nd, outputs, i, dev_id);
    cpys.push_back(*nd);
    ptrs.push_back(reinterpret_cast<void*>(nd));
    tags.push_back(2);
  }
  for (size_t i = 0; i < params.num_auxs; ++i) {
    size_t idx = inputs.size() - params.num_auxs + i;
    NDArray* nd;
    AllocateNDArrayCopy(&nd, inputs, idx, dev_id);
    cpys.push_back(*nd);
    ptrs.push_back(reinterpret_cast<void*>(nd));
    tags.push_back(4);
  }
  CustomOperator::Get()->Push(
      [=]() {
        CHECK(reinterpret_cast<CustomOpFBFunc>(params.info->callbacks[kCustomOpBackward])(
            ptrs.size(), const_cast<void**>(ptrs.data()), const_cast<int*>(tags.data()),
            reinterpret_cast<const int*>(req.data()), static_cast<int>(ctx.is_train),
            params.info->contexts[kCustomOpBackward]));
      },
      ctx, false, ctx.is_train, cpys, tags, output_tags, outputs);
}
// Storage-type inference for the backward pass of a custom op. When the
// frontend registers no callback, it assigns kDefaultStorage to all
// undefined stypes and dispatches via DispatchMode::kFComputeEx.
inline bool BackwardInferStorageType(const nnvm::NodeAttrs& attrs,
                                     const int dev_mask,
                                     DispatchMode* dispatch_mode,
                                     std::vector<int>* iattr,
                                     std::vector<int>* oattr) {
  const CustomParam& params = nnvm::get<CustomParam>(attrs.parsed);

  if (params.info->num_callbacks <= kCustomOpPropBackwardInferStorageType) {
    for (size_t i = 0; i < iattr->size(); i++) {
      STORAGE_TYPE_ASSIGN_CHECK(*iattr, i, kDefaultStorage);
    }
    for (size_t i = 0; i < oattr->size(); i++) {
      STORAGE_TYPE_ASSIGN_CHECK(*oattr, i, kDefaultStorage);
    }
    DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
    return true;
  }

  // Pack stypes as [bwd deps | in grads | auxs] and tag each entry so the
  // frontend knows whether it is an out grad (3), input (0), output (1),
  // in grad (2), or aux (4).
  size_t total = 2 * params.num_args + 2 * params.num_outs + params.num_auxs;
  size_t bwd_deps_size = params.bwd_idx.size();
  std::vector<int> stypes(bwd_deps_size, -1);
  std::vector<int> tags;
  stypes.reserve(total);
  tags.reserve(total);
  for (size_t i = 0; i < bwd_deps_size; i++) {
    if (params.bwd_idx[i] < static_cast<int>(params.num_outs))
      tags.push_back(3);
    else if (params.bwd_idx[i] <
             static_cast<int>(params.num_outs + params.num_args))
      tags.push_back(0);
    else
      tags.push_back(1);
    stypes[i] = (*iattr)[i];
  }
  for (int i : *oattr) {
    stypes.push_back(i);
    tags.push_back(2);
  }
  for (size_t i = (iattr->size() - params.num_auxs); i < iattr->size(); i++) {
    stypes.push_back((*iattr)[i]);
    tags.push_back(4);
  }
  CHECK(reinterpret_cast<CustomOpBackwardInferStorageTypeFunc>(
      params.info->callbacks[kCustomOpPropBackwardInferStorageType])(
          stypes.size(), stypes.data(), tags.data(),
          params.info->contexts[kCustomOpPropBackwardInferStorageType]));
  for (size_t i = 0; i < bwd_deps_size; ++i) {
    STORAGE_TYPE_ASSIGN_CHECK(*iattr, i, stypes[i]);
  }
  for (size_t i = 0; i < oattr->size(); ++i) {
    STORAGE_TYPE_ASSIGN_CHECK(*oattr, i, stypes[i + bwd_deps_size]);
  }
  // Aux entries were appended after the oattr->size() in-grad entries above,
  // so read them back at that same offset (not params.num_outs).
  for (size_t i = 0; i < params.num_auxs; ++i) {
    STORAGE_TYPE_ASSIGN_CHECK(*iattr, (i + iattr->size() - params.num_auxs),
                              stypes[i + oattr->size() + bwd_deps_size]);
  }
  DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
  return true;
}
// Storage-type inference for the forward pass of a custom op. When the
// frontend registers no callback, it assigns kDefaultStorage to all
// undefined stypes and dispatches via DispatchMode::kFComputeEx.
inline bool InferStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask,
                             DispatchMode* dispatch_mode,
                             std::vector<int>* iattr, std::vector<int>* oattr) {
  const CustomParam& params = nnvm::get<CustomParam>(attrs.parsed);

  if (params.info->num_callbacks <= kCustomOpPropInferStorageType) {
    for (size_t i = 0; i < iattr->size(); i++) {
      STORAGE_TYPE_ASSIGN_CHECK(*iattr, i, kDefaultStorage);
    }
    for (size_t i = 0; i < oattr->size(); i++) {
      STORAGE_TYPE_ASSIGN_CHECK(*oattr, i, kDefaultStorage);
    }
    DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
    return true;
  }

  // Pack stypes as [args | outs | auxs], let the frontend fill in the
  // undefined entries, and copy the results back.
  std::vector<int> stypes;
  stypes.reserve(params.num_args + params.num_outs + params.num_auxs);
  for (size_t i = 0; i < params.num_args; ++i) {
    stypes.push_back((*iattr)[i]);
  }
  for (const auto& i : *oattr) {
    stypes.push_back(i);
  }
  for (size_t i = 0; i < params.num_auxs; ++i) {
    stypes.push_back((*iattr)[params.num_args + i]);
  }
  CHECK(reinterpret_cast<CustomOpInferStorageTypeFunc>(
      params.info->callbacks[kCustomOpPropInferStorageType])(
          stypes.size(), stypes.data(),
          params.info->contexts[kCustomOpPropInferStorageType]));
  for (size_t i = 0; i < params.num_args; ++i) {
    STORAGE_TYPE_ASSIGN_CHECK(*iattr, i, stypes[i]);
  }
  for (size_t i = 0; i < params.num_outs; ++i) {
    STORAGE_TYPE_ASSIGN_CHECK(*oattr, i, stypes[params.num_args + i]);
  }
  for (size_t i = 0; i < params.num_auxs; ++i) {
    STORAGE_TYPE_ASSIGN_CHECK(*iattr, params.num_args + i,
                              stypes[params.num_args + params.num_outs + i]);
  }
  DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
  return true;
}
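// Registration of the forward op. For illustration only (frontend side, not
// part of this file): a Python operator registered on a CustomOpProp subclass
// with @mx.operator.register("my_op") is then invoked as
//   mx.nd.Custom(data, op_type='my_op')
// where "my_op" is a hypothetical name chosen at registration time.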
NNVM_REGISTER_OP(Custom)
.describe(R"code(Apply a custom operator implemented in a frontend language (like Python).
Custom operators should override required methods like `forward` and `backward`.
The custom operator must be registered before it can be used.
Please check the tutorial here: http://mxnet.io/faq/new_op.html.
)code" ADD_FILELINE)
.set_num_inputs([](const NodeAttrs& attrs) {
    const CustomParam& params = nnvm::get<CustomParam>(attrs.parsed);
    return params.num_args + params.num_auxs;
  })
.set_num_outputs([](const NodeAttrs& attrs) {
    const CustomParam& params = nnvm::get<CustomParam>(attrs.parsed);
    return params.num_outs;
  })
.set_attr_parser(AttrParser)
.set_attr<mxnet::FInferShape>("FInferShape", InferShape)
.set_attr<nnvm::FInferType>("FInferType", InferType)
.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
    std::vector<std::string> args = List<kCustomOpPropListArguments>(attrs);
    std::vector<std::string> auxs = List<kCustomOpPropListAuxiliaryStates>(attrs);
    args.insert(args.end(), auxs.begin(), auxs.end());
    return args;
  })
.set_attr<nnvm::FListOutputNames>("FListOutputNames", List<kCustomOpPropListOutputs>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs", [](const NodeAttrs& attrs) {
    const CustomParam& params = nnvm::get<CustomParam>(attrs.parsed);
    std::vector<uint32_t> ret;
    for (size_t i = 0; i < params.num_auxs; ++i) ret.push_back(i + params.num_args);
    return ret;
  })
.set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
    return ExecType::kAsync;
  })
.set_attr<nnvm::FGradient>("FGradient", Gradient)
.set_attr<FCreateOpState>("FCreateOpState", CreateState)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", ForwardEx)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", ForwardEx)
.set_attr<FInferStorageType>("FInferStorageType", InferStorageType)
.add_argument("data", "NDArray-or-Symbol[]", "Input data for the custom operator.")
.add_argument("op_type", "string", "Name of the custom operator. "
              "This is the name that is passed to `mx.operator.register` "
              "to register the operator.");
NNVM_REGISTER_OP(_backward_Custom)
.set_num_inputs([](const NodeAttrs& attrs) {
    const CustomParam& params = nnvm::get<CustomParam>(attrs.parsed);
    return params.bwd_idx.size();
  })
.set_num_outputs([](const NodeAttrs& attrs) {
    const CustomParam& params = nnvm::get<CustomParam>(attrs.parsed);
    return params.num_args;
  })
.set_attr<bool>("TIsLayerOpBackward", true)
.set_attr<bool>("TIsBackward", true)
.set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
    return ExecType::kAsync;
  })
.set_attr<FInferStorageType>("FInferStorageType", BackwardInferStorageType)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", BackwardEx)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", BackwardEx);
} // namespace custom
} // namespace op
} // namespace mxnet