| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file operator_util.cc |
| * Implementation of operator util. |
| */ |
| #include <mxnet/operator_util.h> |
| #include <mxnet/operator.h> |
| #include <mxnet/ndarray.h> |
| #include <mxnet/engine.h> |
| #include <vector> |
| #include <mutex> |
| #include "./operator_common.h" |
| |
| namespace mxnet { |
| namespace op { |
| |
| class SimpleOpPropBase; |
| class SimpleSourceOpProp; |
| class SimpleUnaryOpProp; |
| class SimpleBinaryOpProp; |
| |
| class SimpleOpRegEntryImpl : public SimpleOpRegEntry { |
| public: |
| TSelf& set_symbol_op_name(char const* symbol_name_str) override { |
| std::lock_guard<std::mutex> lock(mutex_); |
| std::string symbol_name(symbol_name_str); |
| CHECK(op_reg_ == nullptr || symbol_name == symbol_name_) |
| << " operator " << this->name << " need to call set_symbol_op_name " << symbol_name |
| << "before all other calls"; |
| symbol_name_ = symbol_name; |
| return *this; |
| } |
| |
| TSelf& set_enable_scalar(bool enable_scalar, SimpleOpScalarOption type_mask) override { |
| std::lock_guard<std::mutex> lock(mutex_); |
| enable_scalar_ = enable_scalar; |
| scalar_type_mask_ = type_mask; |
| CHECK(!enable_kwargs_ || !enable_scalar_) << "Cannot have both kwargs and scalar arguments"; |
| return *this; |
| } |
| |
| TSelf& set_enable_kwargs(bool enable_kwargs) override { |
| std::lock_guard<std::mutex> lock(mutex_); |
| enable_kwargs_ = enable_kwargs; |
| CHECK(!enable_kwargs_ || !enable_scalar_) << "Cannot have both kwargs and scalar arguments"; |
| return *this; |
| } |
| |
| TSelf& set_resource_request(const std::vector<ResourceRequest>& reqs) override { |
| std::lock_guard<std::mutex> lock(mutex_); |
| resource_requests_ = reqs; |
| return *this; |
| } |
| |
| TSelf& set_resource_request(ResourceRequest req) override { |
| std::lock_guard<std::mutex> lock(mutex_); |
| resource_requests_ = {req}; |
| return *this; |
| } |
| |
| TSelf& set_shape_function(SourceShapeFunction fshapeinfer) override { |
| std::lock_guard<std::mutex> lock(mutex_); |
| source_shape_ = fshapeinfer; |
| return *this; |
| } |
| |
| TSelf& set_shape_function(UnaryShapeFunction fshapeinfer) override { |
| std::lock_guard<std::mutex> lock(mutex_); |
| unary_shape_ = fshapeinfer; |
| return *this; |
| } |
| |
| TSelf& set_shape_function(BinaryShapeFunction fshapeinfer) override { |
| std::lock_guard<std::mutex> lock(mutex_); |
| binary_shape_ = fshapeinfer; |
| return *this; |
| } |
| |
| TSelf& set_function(int dev_mask, |
| SourceFunction fsource, |
| SimpleOpRegOption register_symbolic) override { |
| std::lock_guard<std::mutex> lock(mutex_); |
| SetFunction(&fsource_, dev_mask, fsource, "SourceFunction"); |
| if (++reg_counter_ == 1) { |
| this->RegisterSourceImperative(); |
| register_symbolic_ = (register_symbolic == kRegisterSymbolic); |
| if (register_symbolic_) { |
| this->RegisterSourceSymbolic(); |
| } |
| } |
| return *this; |
| } |
| |
| TSelf& set_function(int dev_mask, |
| UnaryFunction funary, |
| SimpleOpInplaceOption inplace_in_out, |
| SimpleOpRegOption register_symbolic) override { |
| std::lock_guard<std::mutex> lock(mutex_); |
| SetFunction(&funary_, dev_mask, funary, "UnaryFunction"); |
| unary_forward_inplace_in_out_ = (inplace_in_out == kInplaceInOut); |
| if (++reg_counter_ == 1) { |
| this->RegisterUnaryImperative(); |
| register_symbolic_ = (register_symbolic == kRegisterSymbolic); |
| if (register_symbolic_) { |
| this->RegisterUnarySymbolic(); |
| } |
| } |
| return *this; |
| } |
| |
| TSelf& set_function(int dev_mask, |
| BinaryFunction fbinary, |
| SimpleOpInplaceOption inplace_lhs_out, |
| SimpleOpRegOption register_symbolic) override { |
| std::lock_guard<std::mutex> lock(mutex_); |
| SetFunction(&fbinary_, dev_mask, fbinary, "BinaryFunction"); |
| binary_forward_inplace_lhs_out_ = (inplace_lhs_out == kInplaceLhsOut); |
| if (++reg_counter_ == 1) { |
| this->RegisterBinaryImperative(); |
| register_symbolic_ = (register_symbolic == kRegisterSymbolic); |
| if (register_symbolic_) { |
| this->RegisterBinarySymbolic(); |
| } |
| } |
| return *this; |
| } |
| |
| TSelf& set_gradient(int dev_mask, |
| UnaryGradFunctionT0 fgrad, |
| SimpleOpInplaceOption inplace_out_in_grad) override { |
| std::lock_guard<std::mutex> lock(mutex_); |
| SetFunction(&funary_grad_t0_, dev_mask, fgrad, "UnaryGradFunctionT0"); |
| unary_backward_inplace_out_in_ = (inplace_out_in_grad == kInplaceOutIn); |
| return *this; |
| } |
| |
| TSelf& set_gradient(int dev_mask, |
| UnaryGradFunctionT1 fgrad, |
| SimpleOpInplaceOption inplace_out_in_grad) override { |
| std::lock_guard<std::mutex> lock(mutex_); |
| SetFunction(&funary_grad_t1_, dev_mask, fgrad, "UnaryGradFunctionT1"); |
| unary_backward_inplace_out_in_ = (inplace_out_in_grad == kInplaceOutIn); |
| return *this; |
| } |
| |
| TSelf& set_gradient(int dev_mask, |
| UnaryGradFunctionT2 fgrad, |
| SimpleOpInplaceOption inplace_out_in_grad) override { |
| std::lock_guard<std::mutex> lock(mutex_); |
| SetFunction(&funary_grad_t2_, dev_mask, fgrad, "UnaryGradFunctionT2"); |
| unary_backward_inplace_out_in_ = (inplace_out_in_grad == kInplaceOutIn); |
| return *this; |
| } |
| |
| TSelf& set_gradient(int dev_mask, |
| BinaryGradFunctionT0 fgrad, |
| SimpleOpInplaceOption inplace_out_lhs_grad) override { |
| std::lock_guard<std::mutex> lock(mutex_); |
| SetFunction(&fbinary_grad_t0_, dev_mask, fgrad, "BinaryGradFunctionT0"); |
| binary_backward_inplace_out_lhs_ = (inplace_out_lhs_grad == kInplaceLhsOut); |
| return *this; |
| } |
| |
| TSelf& set_gradient(int dev_mask, |
| BinaryGradFunctionT1 fgrad, |
| SimpleOpInplaceOption inplace_out_lhs_grad) override { |
| std::lock_guard<std::mutex> lock(mutex_); |
| SetFunction(&fbinary_grad_t1_, dev_mask, fgrad, "BinaryGradFunctionT1"); |
| binary_backward_inplace_out_lhs_ = (inplace_out_lhs_grad == kInplaceLhsOut); |
| return *this; |
| } |
| |
| TSelf& describe(const std::string& description) override { |
| std::lock_guard<std::mutex> lock(mutex_); |
| if (reg_counter_ != 1) |
| return *this; |
| NDArrayReg().describe(description); |
| if (register_symbolic_) { |
| OpReg().describe(description); |
| } |
| return *this; |
| } |
| |
| TSelf& add_arguments(const std::vector<dmlc::ParamFieldInfo>& args) override { |
| std::lock_guard<std::mutex> lock(mutex_); |
| if (reg_counter_ != 1) |
| return *this; |
| NDArrayReg().add_arguments(args); |
| if (register_symbolic_) { |
| OpReg().add_arguments(args); |
| } |
| return *this; |
| } |
| |
| protected: |
| // make friend with unary op |
| friend class SimpleOpPropBase; |
| friend class SimpleSourceOpProp; |
| friend class SimpleUnaryOpProp; |
| friend class SimpleBinaryOpProp; |
| // internal mutex |
| std::mutex mutex_; |
| // registration counter |
| int reg_counter_{0}; |
| // whether register symbolic function. |
| bool register_symbolic_{true}; |
| // name of symbolic operator. |
| std::string symbol_name_; |
| // number of scalar arguments |
| bool enable_scalar_{false}; |
| // type mask of scalar arguments in imperative API. |
| SimpleOpScalarOption scalar_type_mask_{kArrayBeforeScalar}; |
| // whether kwargs is enabled in the function. |
| bool enable_kwargs_{false}; |
| // resource requirements |
| std::vector<ResourceRequest> resource_requests_; |
| // ------ source functions ---- |
| // source shape inference information. |
| SourceShapeFunction source_shape_{nullptr}; |
| // source functions on each device mask |
| std::vector<SourceFunction> fsource_; |
| // ------ unary functions ----- |
| // unary shape inference information. |
| UnaryShapeFunction unary_shape_{nullptr}; |
| // unary functions on each device mask |
| std::vector<UnaryFunction> funary_; |
| // type 1 gradient function |
| std::vector<UnaryGradFunctionT0> funary_grad_t0_; |
| // type 2 gradient function |
| std::vector<UnaryGradFunctionT1> funary_grad_t1_; |
| // type 2 gradient function |
| std::vector<UnaryGradFunctionT2> funary_grad_t2_; |
| // whether do inplace optimization of in 0 and output |
| bool unary_forward_inplace_in_out_{false}; |
| // whether do inplace optimization of out_grad and in_grad0 |
| bool unary_backward_inplace_out_in_{false}; |
| // ------ binary functions ----- |
| // binary shape inference information. |
| BinaryShapeFunction binary_shape_{nullptr}; |
| // unary functions on each device mask |
| std::vector<BinaryFunction> fbinary_; |
| // type 1 gradient function |
| std::vector<BinaryGradFunctionT0> fbinary_grad_t0_; |
| // type 2 gradient function |
| std::vector<BinaryGradFunctionT1> fbinary_grad_t1_; |
| // whether do inplace optimization of in 0 and output |
| bool binary_forward_inplace_lhs_out_{false}; |
| // whether do inplace optimization of out_grad and in_grad0 |
| bool binary_backward_inplace_out_lhs_{false}; |
| |
| template <typename TFunction> |
| inline void SetFunction(std::vector<TFunction>* vfunc, |
| int dev_mask, |
| TFunction func, |
| const char* type) { |
| if (vfunc->size() <= static_cast<size_t>(dev_mask)) { |
| vfunc->resize(dev_mask + 1, nullptr); |
| } |
| if (vfunc->at(dev_mask) != nullptr) { |
| LOG(FATAL) << "Device " << type << " function " << this->name |
| << " already registerd for device " << dev_mask; |
| } |
| vfunc->at(dev_mask) = func; |
| } |
| |
| private: |
| // internal reference to NDArray registry |
| NDArrayFunctionReg* ndarray_reg_{nullptr}; |
| // internal reference to operator registry |
| OperatorPropertyReg* op_reg_{nullptr}; |
| // internal function to register NDArray function. |
| inline NDArrayFunctionReg& NDArrayReg() { |
| if (ndarray_reg_ == nullptr) { |
| NDArrayFunctionReg& reg = |
| ::dmlc::Registry<NDArrayFunctionReg>::Get()->__REGISTER__(this->name); |
| ndarray_reg_ = ® |
| } |
| return *ndarray_reg_; |
| } |
| // internal function to register NDArray function. |
| inline OperatorPropertyReg& OpReg() { |
| if (op_reg_ == nullptr) { |
| if (symbol_name_.length() == 0) { |
| symbol_name_ = this->name; |
| } |
| OperatorPropertyReg& reg = |
| ::dmlc::Registry<OperatorPropertyReg>::Get()->__REGISTER__(symbol_name_); |
| op_reg_ = ® |
| } |
| return *op_reg_; |
| } |
| // register source function. |
| void RegisterSourceImperative(); |
| // register source symbolic function. |
| void RegisterSourceSymbolic(); |
| // register unary function. |
| void RegisterUnaryImperative(); |
| // register unary symbolic function. |
| void RegisterUnarySymbolic(); |
| // register unary function. |
| void RegisterBinaryImperative(); |
| // register unary symbolic function. |
| void RegisterBinarySymbolic(); |
| }; |
| |
| SimpleOpRegEntry& SimpleOpRegistry::__REGISTER_OR_FIND__(char const* name_str) { |
| std::string name(name_str); |
| if (fmap_.count(name) != 0) |
| return *fmap_.at(name); |
| SimpleOpRegEntry* e = new SimpleOpRegEntryImpl(); |
| e->name = name; |
| fmap_[name] = e; |
| return *e; |
| } |
| |
| SimpleOpRegistry* SimpleOpRegistry::Get() { |
| static SimpleOpRegistry inst; |
| return &inst; |
| } |
| |
| SimpleOpRegistry::~SimpleOpRegistry() { |
| for (auto kv : fmap_) { |
| delete kv.second; |
| } |
| } |
| |
| // base class |
| struct SimpleOpScalarParam : public dmlc::Parameter<SimpleOpScalarParam> { |
| float scalar; |
| DMLC_DECLARE_PARAMETER(SimpleOpScalarParam) { |
| DMLC_DECLARE_FIELD(scalar).describe("scalar value."); |
| } |
| }; |
| |
| DMLC_REGISTER_PARAMETER(SimpleOpScalarParam); |
| |
| class SimpleOpPropBase : public OperatorProperty { |
| public: |
| std::string name; |
| EnvArguments env; |
| SimpleOpRegEntryImpl* source; |
| |
| void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override { |
| if (source->enable_kwargs_) { |
| env.kwargs = kwargs; |
| } else if (source->enable_scalar_) { |
| SimpleOpScalarParam param; |
| param.Init(kwargs); |
| env.scalar = param.scalar; |
| } else { |
| CHECK_EQ(kwargs.size(), 0) << "Operator " << source->symbol_name_ |
| << " donot accept any keyword arguments"; |
| } |
| } |
| |
| std::map<std::string, std::string> GetParams() const override { |
| if (source->enable_kwargs_) { |
| return std::map<std::string, std::string>(env.kwargs.begin(), env.kwargs.end()); |
| } else if (source->enable_scalar_) { |
| SimpleOpScalarParam param; |
| param.scalar = env.scalar; |
| return param.__DICT__(); |
| } else { |
| return std::map<std::string, std::string>(); |
| } |
| } |
| |
| std::vector<ResourceRequest> ForwardResource(const mxnet::ShapeVector& in_shape) const override { |
| return source->resource_requests_; |
| } |
| |
| std::vector<ResourceRequest> BackwardResource(const mxnet::ShapeVector& in_shape) const override { |
| return source->resource_requests_; |
| } |
| |
| bool InferType(std::vector<int>* in_type, |
| std::vector<int>* out_type, |
| std::vector<int>* aux_type) const override { |
| CHECK_LE(in_type->size(), this->ListArguments().size()); |
| int dtype = -1; |
| // reduce dtype to a common one. |
| for (int i : *in_type) { |
| if (dtype == -1) { |
| dtype = i; |
| } else { |
| CHECK(i == -1 || i == dtype) |
| << "Non-uniform input data type. Expected " << dtype << "got " << i; |
| } |
| } |
| |
| if (dtype == -1) { |
| LOG(FATAL) << "At least one input type needs to be specified."; |
| return false; |
| } |
| |
| int n_in = this->ListArguments().size(); |
| in_type->clear(); |
| for (int i = 0; i < n_in; ++i) |
| in_type->push_back(dtype); |
| |
| int n_out = this->ListOutputs().size(); |
| out_type->clear(); |
| for (int i = 0; i < n_out; ++i) |
| out_type->push_back(dtype); |
| |
| int n_aux = this->ListAuxiliaryStates().size(); |
| aux_type->clear(); |
| for (int i = 0; i < n_aux; ++i) |
| aux_type->push_back(dtype); |
| return true; |
| } |
| |
| std::string TypeString() const override { |
| return name; |
| } |
| }; |
| |
| //------------------------------------- |
| // source function Implementation |
| //------------------------------------- |
| void SimpleOpRegEntryImpl::RegisterSourceImperative() { |
| CHECK_EQ(reg_counter_, 1); |
| // The body to be registered |
| auto body = [this](NDArray** used_vars, |
| real_t* s, |
| NDArray** mutate_vars, |
| int num_params, |
| char** param_keys, |
| char** param_vals) { |
| NDArray* out = mutate_vars[0]; |
| // setup env. |
| EnvArguments env; |
| if (enable_scalar_) |
| env.scalar = s[0]; |
| if (enable_kwargs_) { |
| for (int i = 0; i < num_params; ++i) { |
| env.kwargs.emplace_back( |
| std::make_pair(std::string(param_keys[i]), std::string(param_vals[i]))); |
| } |
| } else { |
| CHECK_EQ(num_params, 0) << "operator " << this->name << " do not take keyword arguments"; |
| } |
| // shape inference. |
| CHECK(source_shape_ != nullptr); |
| mxnet::TShape dshape = source_shape_(env); |
| // check output shape. |
| CHECK(!out->is_none()); |
| CHECK(out->shape() == dshape) << "target shape mismatch " << out->shape() << " vs. " << dshape; |
| |
| // important: callback must always capture by value |
| NDArray ret = *out; |
| // request resources. |
| std::vector<Engine::VarHandle> write_vars = {ret.var()}; |
| for (ResourceRequest req : resource_requests_) { |
| env.resource.push_back(ResourceManager::Get()->Request(ret.ctx(), req)); |
| write_vars.push_back(env.resource.back().var); |
| } |
| // check if the function exist |
| int dev_mask = ret.ctx().dev_mask(); |
| // error message |
| if (static_cast<size_t>(dev_mask) >= fsource_.size() || fsource_[dev_mask] == nullptr) { |
| if (dev_mask == gpu::kDevMask) { |
| LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; |
| } |
| LOG(FATAL) << "Function " << this->name << "not registered for device " << dev_mask; |
| } |
| // invoke the function |
| SourceFunction fun = fsource_[dev_mask]; |
| OpReqType req = kWriteTo; |
| |
| Engine::Get()->PushSync( |
| [ret, fun, dev_mask, req, env](RunContext ctx) { |
| TBlob tmp = ret.data(); |
| (*fun)(env, &tmp, req, ctx); |
| }, |
| ret.ctx(), |
| {}, |
| write_vars, |
| FnProperty::kNormal, |
| 0, |
| "RegisterSourceImperative"); |
| }; |
| // register the function. |
| NDArrayReg().set_body(body).set_num_use_vars(0).set_num_mutate_vars(1); |
| if (enable_scalar_) { |
| NDArrayReg().set_num_scalars(1).add_argument("scalar", "float", "scalar input to the function"); |
| } |
| } |
| |
| // operator to invoke unary function. |
| struct SimpleSourceOperator : public Operator { |
| EnvArguments env; |
| SourceFunction forward; |
| |
| void Forward(const OpContext& ctx, |
| const std::vector<TBlob>& in_data, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& out_data, |
| const std::vector<TBlob>& aux_args) override { |
| if (ctx.requested.size() != 0) |
| env.resource = ctx.requested; |
| CHECK_EQ(in_data.size(), 0); |
| CHECK_EQ(out_data.size(), 1); |
| TBlob out = out_data[0]; |
| (*forward)(env, &out, req[0], ctx.run_ctx); |
| } |
| |
| void Backward(const OpContext& ctx, |
| const std::vector<TBlob>& out_grad, |
| const std::vector<TBlob>& in_data, |
| const std::vector<TBlob>& out_data, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& in_grad, |
| const std::vector<TBlob>& aux_args) override { |
| LOG(FATAL) << "no gradient can be done"; |
| // no nothing. |
| } |
| }; // class SimpleUnaryOperator |
| |
| class SimpleSourceOpProp : public SimpleOpPropBase { |
| public: |
| bool InferShape(mxnet::ShapeVector* in_shape, |
| mxnet::ShapeVector* out_shape, |
| mxnet::ShapeVector* aux_shape) const override { |
| CHECK_EQ(in_shape->size(), 0) << in_shape->size(); |
| CHECK(source->source_shape_ != nullptr); |
| out_shape->clear(); |
| out_shape->push_back((*(source->source_shape_))(env)); |
| return true; |
| } |
| |
| OperatorProperty* Copy() const override { |
| auto ptr = new SimpleSourceOpProp(); |
| ptr->source = source; |
| ptr->name = name; |
| ptr->env = env; |
| return ptr; |
| } |
| |
| Operator* CreateOperator(Context ctx) const override { |
| size_t dev_mask = ctx.dev_mask(); |
| SimpleSourceOperator* op = new SimpleSourceOperator(); |
| CHECK(dev_mask < source->fsource_.size() && source->fsource_[dev_mask] != nullptr); |
| op->forward = source->fsource_[dev_mask]; |
| op->env = this->env; |
| return op; |
| } |
| |
| std::vector<std::string> ListArguments() const override { |
| return {}; |
| } |
| |
| bool InferType(std::vector<int>* in_type, |
| std::vector<int>* out_type, |
| std::vector<int>* aux_type) const override { |
| out_type->clear(); |
| out_type->push_back(mshadow::kFloat32); |
| return true; |
| } |
| }; |
| |
| void SimpleOpRegEntryImpl::RegisterSourceSymbolic() { |
| // register the operator |
| auto op_factory = [this]() { |
| SimpleSourceOpProp* prop = new SimpleSourceOpProp(); |
| prop->name = this->symbol_name_; |
| prop->source = this; |
| return prop; |
| }; |
| OpReg().set_body(op_factory); |
| } |
| |
| //------------------------------------- |
| // unary function Implementation |
| //------------------------------------- |
| void SimpleOpRegEntryImpl::RegisterUnaryImperative() { |
| CHECK_EQ(reg_counter_, 1); |
| // The body to be registered |
| auto body = [this](NDArray** used_vars, |
| real_t* s, |
| NDArray** mutate_vars, |
| int num_params, |
| char** param_keys, |
| char** param_vals) { |
| NDArray& src = *used_vars[0]; |
| NDArray* out = mutate_vars[0]; |
| // setup env. |
| EnvArguments env; |
| if (enable_scalar_) |
| env.scalar = s[0]; |
| if (enable_kwargs_) { |
| for (int i = 0; i < num_params; ++i) { |
| env.kwargs.emplace_back( |
| std::make_pair(std::string(param_keys[i]), std::string(param_vals[i]))); |
| } |
| } else { |
| CHECK_EQ(num_params, 0) << "operator " << this->name << " do not take keyword arguments"; |
| } |
| // shape inference. |
| mxnet::TShape dshape; |
| if (unary_shape_ != nullptr) { |
| dshape = unary_shape_(src.shape(), env); |
| } else { |
| dshape = src.shape(); |
| } |
| // check output shape. |
| if (out->is_none()) { |
| *out = NDArray(dshape, src.ctx(), true, src.dtype()); |
| } else { |
| CHECK(out->ctx() == src.ctx()) << "target context mismatch"; |
| CHECK(out->dtype() == src.dtype()) << "target data type mismatch"; |
| CHECK(out->shape() == dshape) |
| << "target shape mismatch " << out->shape() << " vs. " << dshape; |
| } |
| // important: callback must always capture by value |
| NDArray ret = *out; |
| // get the const variables |
| std::vector<Engine::VarHandle> const_vars; |
| if (src.var() != ret.var()) { |
| const_vars.push_back(src.var()); |
| } |
| |
| // request resources. |
| std::vector<Engine::VarHandle> write_vars = {ret.var()}; |
| for (ResourceRequest req : resource_requests_) { |
| env.resource.push_back(ResourceManager::Get()->Request(src.ctx(), req)); |
| write_vars.push_back(env.resource.back().var); |
| } |
| |
| // check if the function exist |
| int dev_mask = src.ctx().dev_mask(); |
| // error message |
| if (static_cast<size_t>(dev_mask) >= funary_.size() || funary_[dev_mask] == nullptr) { |
| if (dev_mask == gpu::kDevMask) { |
| LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; |
| } |
| LOG(FATAL) << "Function " << this->name << "not registered for device " << dev_mask; |
| } |
| // invoke the function |
| UnaryFunction fun = funary_[dev_mask]; |
| OpReqType req = kWriteTo; |
| if (src.var() == ret.var()) { |
| req = kWriteInplace; |
| CHECK(unary_forward_inplace_in_out_) |
| << "inplace operation is not enabled for operator " << name; |
| } |
| |
| Engine::Get()->PushSync( |
| [src, ret, fun, dev_mask, req, env](RunContext ctx) { |
| TBlob tmp = ret.data(); |
| (*fun)(src.data(), env, &tmp, req, ctx); |
| }, |
| src.ctx(), |
| const_vars, |
| write_vars, |
| FnProperty::kNormal, |
| 0, |
| "RegisterUnaryImperative"); |
| }; |
| // register the function. |
| NDArrayReg().set_body(body).set_num_use_vars(1).set_num_mutate_vars(1); |
| if (enable_scalar_) { |
| if (scalar_type_mask_ == kArrayBeforeScalar) { |
| NDArrayReg() |
| .set_num_scalars(1) |
| .set_type_mask(kNDArrayArgBeforeScalar | kAcceptEmptyMutateTarget) |
| .add_argument("src", "NDArray-or-Symbol", "Source input to the function") |
| .add_argument("scalar", "float", "scalar input to the function"); |
| } else { |
| NDArrayReg() |
| .set_num_scalars(1) |
| .set_type_mask(kScalarArgBeforeNDArray | kAcceptEmptyMutateTarget) |
| .add_argument("scalar", "float", "scalar input to the function") |
| .add_argument("src", "NDArray-or-Symbol", "Source input to the function"); |
| } |
| } else { |
| NDArrayReg() |
| .set_type_mask(kNDArrayArgBeforeScalar | kAcceptEmptyMutateTarget) |
| .add_argument("src", "NDArray-or-Symbol", "Source input to the function"); |
| } |
| } |
| |
| // operator to invoke unary function. |
| struct SimpleUnaryOperator : public Operator { |
| EnvArguments env; |
| UnaryFunction forward; |
| UnaryGradFunctionT0 backward0{nullptr}; |
| UnaryGradFunctionT1 backward1{nullptr}; |
| UnaryGradFunctionT2 backward2{nullptr}; |
| |
| void Forward(const OpContext& ctx, |
| const std::vector<TBlob>& in_data, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& out_data, |
| const std::vector<TBlob>& aux_args) override { |
| if (ctx.requested.size() != 0) |
| env.resource = ctx.requested; |
| CHECK_EQ(in_data.size(), 1); |
| CHECK_EQ(out_data.size(), 1); |
| TBlob out = out_data[0]; |
| (*forward)(in_data[0], env, &out, req[0], ctx.run_ctx); |
| } |
| |
| void Backward(const OpContext& ctx, |
| const std::vector<TBlob>& out_grad, |
| const std::vector<TBlob>& in_data, |
| const std::vector<TBlob>& out_data, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& in_grad, |
| const std::vector<TBlob>& aux_args) override { |
| if (ctx.requested.size() != 0) |
| env.resource = ctx.requested; |
| CHECK_EQ(out_grad.size(), 1); |
| CHECK(in_data.size() == 1 && in_grad.size() == 1); |
| CHECK_EQ(req.size(), 1); |
| OutputGrad ograd; |
| ograd.data = out_grad[0]; |
| TBlob igrad = in_grad[0]; |
| |
| if (backward0 != nullptr) { |
| (*backward0)(ograd, env, &igrad, req[0], ctx.run_ctx); |
| } else if (backward1 != nullptr) { |
| OutputValue out_value; |
| out_value.data = out_data[0]; |
| (*backward1)(ograd, out_value, env, &igrad, req[0], ctx.run_ctx); |
| } else if (backward2 != nullptr) { |
| Input0 in0; |
| in0.data = in_data[0]; |
| (*backward2)(ograd, in0, env, &igrad, req[0], ctx.run_ctx); |
| } else { |
| LOG(FATAL) << "Backward is not supported"; |
| } |
| } |
| }; // class SimpleUnaryOperator |
| |
| class SimpleUnaryOpProp : public SimpleOpPropBase { |
| public: |
| bool InferShape(mxnet::ShapeVector* in_shape, |
| mxnet::ShapeVector* out_shape, |
| mxnet::ShapeVector* aux_shape) const override { |
| using namespace mshadow; |
| CHECK_EQ(in_shape->size(), 1) << "Input:[data]"; |
| const mxnet::TShape& dshape = in_shape->at(0); |
| if (!shape_is_known(dshape)) |
| return false; |
| out_shape->clear(); |
| if (source->unary_shape_ == nullptr) { |
| out_shape->push_back(dshape); |
| } else { |
| out_shape->push_back((*(source->unary_shape_))(dshape, env)); |
| } |
| return true; |
| } |
| |
| OperatorProperty* Copy() const override { |
| auto ptr = new SimpleUnaryOpProp(); |
| ptr->source = source; |
| ptr->name = name; |
| ptr->env = env; |
| return ptr; |
| } |
| |
| // decalre dependency and inplace optimization options |
| std::vector<int> DeclareBackwardDependency(const std::vector<int>& out_grad, |
| const std::vector<int>& in_data, |
| const std::vector<int>& out_data) const override { |
| if (source->funary_grad_t0_.size() != 0) { |
| return {out_grad[0]}; |
| } else if (source->funary_grad_t1_.size() != 0) { |
| return {out_grad[0], out_data[0]}; |
| } else if (source->funary_grad_t2_.size() != 0) { |
| return {out_grad[0], in_data[0]}; |
| } else { |
| LOG(FATAL) << "Backward of " << name << " is not decalred"; |
| return {}; |
| } |
| } |
| |
| std::vector<std::pair<int, void*> > BackwardInplaceOption( |
| const std::vector<int>& out_grad, |
| const std::vector<int>& in_data, |
| const std::vector<int>& out_data, |
| const std::vector<void*>& in_grad) const override { |
| if (source->unary_backward_inplace_out_in_) { |
| return {{out_grad[0], in_grad[0]}}; |
| } else { |
| return {}; |
| } |
| } |
| |
| std::vector<std::pair<int, void*> > ForwardInplaceOption( |
| const std::vector<int>& in_data, |
| const std::vector<void*>& out_data) const override { |
| if (source->unary_forward_inplace_in_out_) { |
| return {{in_data[0], out_data[0]}}; |
| } else { |
| return {}; |
| } |
| } |
| |
| Operator* CreateOperator(Context ctx) const override { |
| size_t dev_mask = ctx.dev_mask(); |
| SimpleUnaryOperator* op = new SimpleUnaryOperator(); |
| CHECK(dev_mask < source->funary_.size() && source->funary_[dev_mask] != nullptr); |
| op->forward = source->funary_[dev_mask]; |
| op->env = this->env; |
| if (dev_mask < source->funary_grad_t0_.size()) { |
| op->backward0 = source->funary_grad_t0_[dev_mask]; |
| } |
| if (dev_mask < source->funary_grad_t1_.size()) { |
| op->backward1 = source->funary_grad_t1_[dev_mask]; |
| } |
| if (dev_mask < source->funary_grad_t2_.size()) { |
| op->backward2 = source->funary_grad_t2_[dev_mask]; |
| } |
| return op; |
| } |
| }; |
| |
| void SimpleOpRegEntryImpl::RegisterUnarySymbolic() { |
| // register the operator |
| auto op_factory = [this]() { |
| SimpleUnaryOpProp* prop = new SimpleUnaryOpProp(); |
| prop->name = this->symbol_name_; |
| prop->source = this; |
| return prop; |
| }; |
| OpReg() |
| .set_body(op_factory) |
| .add_argument("src", "NDArray-or-Symbol", "Left symbolic input to the function"); |
| } |
| |
| //------------------------------------- |
| // binary function Implementation |
| //------------------------------------- |
| void SimpleOpRegEntryImpl::RegisterBinaryImperative() { |
| CHECK_EQ(reg_counter_, 1); |
| // The body to be registered |
| auto body = [this](NDArray** used_vars, |
| real_t* s, |
| NDArray** mutate_vars, |
| int num_params, |
| char** param_keys, |
| char** param_vals) { |
| NDArray& lhs = *used_vars[0]; |
| NDArray& rhs = *used_vars[1]; |
| NDArray* out = mutate_vars[0]; |
| // setup env. |
| EnvArguments env; |
| if (enable_scalar_) |
| env.scalar = s[0]; |
| if (enable_kwargs_) { |
| for (int i = 0; i < num_params; ++i) { |
| env.kwargs.emplace_back( |
| std::make_pair(std::string(param_keys[i]), std::string(param_vals[i]))); |
| } |
| } else { |
| CHECK_EQ(num_params, 0) << "operator " << this->name << " do not take keyword arguments"; |
| } |
| |
| // shape inference. |
| mxnet::TShape dshape; |
| if (binary_shape_ != nullptr) { |
| dshape = binary_shape_(lhs.shape(), rhs.shape(), env); |
| } else { |
| CHECK_EQ(lhs.shape(), rhs.shape()) << "operands shape mismatch"; |
| dshape = lhs.shape(); |
| } |
| |
| // no check if all of them are on cpu |
| if (lhs.ctx().dev_mask() != cpu::kDevMask || rhs.ctx().dev_mask() != cpu::kDevMask) { |
| CHECK(lhs.ctx() == rhs.ctx()) |
| << "operands context mismatch " << lhs.ctx().dev_type << " " << lhs.ctx().dev_id |
| << " vs. " << rhs.ctx().dev_type << " " << rhs.ctx().dev_id; |
| } |
| CHECK_EQ(lhs.dtype(), rhs.dtype()) << "operands type mismatch"; |
| |
| // check output shape. |
| if (out->is_none()) { |
| *out = NDArray(dshape, lhs.ctx(), true, lhs.dtype()); |
| } else { |
| CHECK(out->ctx() == lhs.ctx()) << "target context mismatch"; |
| CHECK(out->dtype() == lhs.dtype()) << "target data type mismatch"; |
| CHECK(out->shape() == dshape) |
| << "target shape mismatch " << out->shape() << " vs. " << dshape; |
| } |
| // important: callback must always capture by value |
| NDArray ret = *out; |
| // get the const variables |
| std::vector<Engine::VarHandle> const_vars; |
| if (lhs.var() != ret.var()) |
| const_vars.push_back(lhs.var()); |
| if (rhs.var() != ret.var()) |
| const_vars.push_back(rhs.var()); |
| |
| // request resources. |
| std::vector<Engine::VarHandle> write_vars = {ret.var()}; |
| for (ResourceRequest req : resource_requests_) { |
| env.resource.push_back(ResourceManager::Get()->Request(lhs.ctx(), req)); |
| write_vars.push_back(env.resource.back().var); |
| } |
| |
| // check if the function exist |
| int dev_mask = lhs.ctx().dev_mask(); |
| // error message |
| if (static_cast<size_t>(dev_mask) >= fbinary_.size() || fbinary_[dev_mask] == nullptr) { |
| if (dev_mask == gpu::kDevMask) { |
| LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; |
| } |
| LOG(FATAL) << "Function " << this->name << "not registered for device " << dev_mask; |
| } |
| // invoke the function |
| BinaryFunction fun = fbinary_[dev_mask]; |
| OpReqType req = kWriteTo; |
| if (lhs.var() == ret.var()) { |
| req = kWriteInplace; |
| CHECK(binary_forward_inplace_lhs_out_) |
| << "inplace operation is not enabled for operator " << name; |
| } |
| if (rhs.var() == ret.var()) { |
| LOG(ERROR) << " operation " << this->name |
| << " warning, perform inplace operation with right operand, may not be supported"; |
| } |
| |
| Engine::Get()->PushSync( |
| [lhs, rhs, ret, fun, dev_mask, req, env](RunContext ctx) { |
| TBlob tmp = ret.data(); |
| (*fun)(lhs.data(), rhs.data(), env, &tmp, req, ctx); |
| }, |
| lhs.ctx(), |
| const_vars, |
| write_vars, |
| FnProperty::kNormal, |
| 0, |
| "RegisterBinaryImperative"); |
| }; |
| // register the function. |
| NDArrayReg().set_body(body).set_num_use_vars(2).set_num_mutate_vars(1); |
| if (enable_scalar_) { |
| if (scalar_type_mask_ == kArrayBeforeScalar) { |
| NDArrayReg() |
| .set_num_scalars(1) |
| .set_type_mask(kNDArrayArgBeforeScalar | kAcceptEmptyMutateTarget) |
| .add_argument("lhs", "NDArray-or-Symbol", "Left operand to the function") |
| .add_argument("rhs", "NDArray-or-Symbol", "Right operand to the function") |
| .add_argument("scalar", "float", "scalar input to the function"); |
| } else { |
| NDArrayReg() |
| .set_num_scalars(1) |
| .set_type_mask(kScalarArgBeforeNDArray | kAcceptEmptyMutateTarget) |
| .add_argument("scalar", "float", "scalar input to the function") |
| .add_argument("src", "NDArray-or-Symbol", "Source input to the function") |
| .add_argument("lhs", "NDArray-or-Symbol", "Left operand to the function") |
| .add_argument("rhs", "NDArray-or-Symbol", "Right operand to the function"); |
| } |
| } else { |
| NDArrayReg() |
| .set_type_mask(kNDArrayArgBeforeScalar | kAcceptEmptyMutateTarget) |
| .add_argument("lhs", "NDArray-or-Symbol", "Left operand to the function") |
| .add_argument("rhs", "NDArray-or-Symbol", "Right operand to the function"); |
| } |
| } |
| |
| struct SimpleBinaryOperator : public Operator { |
| EnvArguments env; |
| BinaryFunction forward; |
| BinaryGradFunctionT0 backward0{nullptr}; |
| BinaryGradFunctionT1 backward1{nullptr}; |
| |
| void Forward(const OpContext& ctx, |
| const std::vector<TBlob>& in_data, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& out_data, |
| const std::vector<TBlob>& aux_args) override { |
| if (ctx.requested.size() != 0) |
| env.resource = ctx.requested; |
| CHECK_EQ(in_data.size(), 2); |
| CHECK_EQ(out_data.size(), 1); |
| TBlob out = out_data[0]; |
| (*forward)(in_data[0], in_data[1], env, &out, req[0], ctx.run_ctx); |
| } |
| |
| void Backward(const OpContext& ctx, |
| const std::vector<TBlob>& out_grad, |
| const std::vector<TBlob>& in_data, |
| const std::vector<TBlob>& out_data, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& in_grad, |
| const std::vector<TBlob>& aux_args) override { |
| if (ctx.requested.size() != 0) |
| env.resource = ctx.requested; |
| CHECK_EQ(out_grad.size(), 1); |
| CHECK(in_data.size() == 2 && in_grad.size() == 2); |
| CHECK_EQ(req.size(), 2); |
| OutputGrad ograd; |
| ograd.data = out_grad[0]; |
| TBlob lgrad = in_grad[0]; |
| TBlob rgrad = in_grad[1]; |
| |
| if (backward0 != nullptr) { |
| (*backward0)(ograd, env, &lgrad, &rgrad, req[0], req[1], ctx.run_ctx); |
| } else if (backward1 != nullptr) { |
| Input0 in0; |
| in0.data = in_data[0]; |
| Input1 in1; |
| in1.data = in_data[1]; |
| (*backward1)(ograd, in0, in1, env, &lgrad, &rgrad, req[0], req[1], ctx.run_ctx); |
| } else { |
| LOG(FATAL) << "Backward is not supported"; |
| } |
| } |
| }; // class SimpleBinaryOperator |
| |
| class SimpleBinaryOpProp : public SimpleOpPropBase { |
| public: |
| bool InferShape(mxnet::ShapeVector* in_shape, |
| mxnet::ShapeVector* out_shape, |
| mxnet::ShapeVector* aux_shape) const override { |
| using namespace mshadow; |
| CHECK_EQ(in_shape->size(), 2) << "Input:[lhs, rhs]"; |
| const mxnet::TShape& lshape = in_shape->at(0); |
| const mxnet::TShape& rshape = in_shape->at(1); |
| out_shape->clear(); |
| if (source->binary_shape_ == nullptr) { |
| if (in_shape->at(0).ndim() != 0) { |
| SHAPE_ASSIGN_CHECK(*in_shape, 1, in_shape->at(0)); |
| } else if (in_shape->at(1).ndim() != 0) { |
| in_shape->at(0) = in_shape->at(1); |
| } else { |
| return false; |
| } |
| out_shape->push_back(lshape); |
| } else { |
| if (lshape.ndim() == 0) |
| return false; |
| if (rshape.ndim() == 0) |
| return false; |
| out_shape->push_back((*(source->binary_shape_))(lshape, rshape, env)); |
| } |
| return true; |
| } |
| |
| std::vector<std::string> ListArguments() const override { |
| return {"lhs", "rhs"}; |
| } |
| |
| OperatorProperty* Copy() const override { |
| auto ptr = new SimpleBinaryOpProp(); |
| ptr->source = source; |
| ptr->name = name; |
| ptr->env = env; |
| return ptr; |
| } |
| |
| // decalre dependency and inplace optimization options |
| std::vector<int> DeclareBackwardDependency(const std::vector<int>& out_grad, |
| const std::vector<int>& in_data, |
| const std::vector<int>& out_data) const override { |
| if (source->fbinary_grad_t0_.size() != 0) { |
| return {out_grad[0]}; |
| } else if (source->fbinary_grad_t1_.size() != 0) { |
| return {out_grad[0], in_data[0], in_data[1]}; |
| } else { |
| LOG(FATAL) << "Backward of " << name << " is not decalred"; |
| return {}; |
| } |
| } |
| |
| std::vector<std::pair<int, void*> > BackwardInplaceOption( |
| const std::vector<int>& out_grad, |
| const std::vector<int>& in_data, |
| const std::vector<int>& out_data, |
| const std::vector<void*>& in_grad) const override { |
| if (source->binary_backward_inplace_out_lhs_) { |
| return {{out_grad[0], in_grad[0]}}; |
| } else { |
| return {}; |
| } |
| } |
| |
| std::vector<std::pair<int, void*> > ForwardInplaceOption( |
| const std::vector<int>& in_data, |
| const std::vector<void*>& out_data) const override { |
| if (source->binary_forward_inplace_lhs_out_) { |
| return {{in_data[0], out_data[0]}}; |
| } else { |
| return {}; |
| } |
| } |
| |
| Operator* CreateOperator(Context ctx) const override { |
| size_t dev_mask = ctx.dev_mask(); |
| SimpleBinaryOperator* op = new SimpleBinaryOperator(); |
| CHECK(dev_mask < source->fbinary_.size() && source->fbinary_[dev_mask] != nullptr); |
| op->forward = source->fbinary_[dev_mask]; |
| op->env = this->env; |
| if (dev_mask < source->fbinary_grad_t0_.size()) { |
| op->backward0 = source->fbinary_grad_t0_[dev_mask]; |
| } |
| if (dev_mask < source->fbinary_grad_t1_.size()) { |
| op->backward1 = source->fbinary_grad_t1_[dev_mask]; |
| } |
| return op; |
| } |
| }; |
| |
| void SimpleOpRegEntryImpl::RegisterBinarySymbolic() { |
| // register the operator |
| auto op_factory = [this]() { |
| SimpleBinaryOpProp* prop = new SimpleBinaryOpProp(); |
| prop->name = symbol_name_; |
| prop->source = this; |
| return prop; |
| }; |
| OpReg() |
| .set_body(op_factory) |
| .add_argument("lhs", "NDArray-or-Symbol", "Left symbolic input to the function") |
| .add_argument("rhs", "NDArray-or-Symbol", "Right symbolic input to the function"); |
| } |
| |
| } // namespace op |
| } // namespace mxnet |