/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file elemwise_unary_op_basic.cc
 * \brief CPU implementation of elementwise unary functions.
 */
#include <mxnet/base.h>

#include "../nn/dnnl/dnnl_ops-inl.h"
#include "./elemwise_binary_op-inl.h"
#include "elemwise_unary_op.h"

namespace mxnet {
namespace op {

// Storage-type inference function for the _identity_with_attr_like_rhs op.
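// A rough sketch of the dispatch below: if rhs has row_sparse storage while
// lhs and out are still undefined, type_assign() first propagates row_sparse
// to both, the dense branch is skipped, and the second branch dispatches
// DispatchMode::kFComputeEx with a row_sparse output.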
static bool IdentityAttrLikeRhsStorageType(const nnvm::NodeAttrs& attrs,
                                           const int dev_mask,
                                           DispatchMode* dispatch_mode,
                                           std::vector<int>* in_attrs,
                                           std::vector<int>* out_attrs) {
  CHECK_EQ(in_attrs->size(), 2U);
  CHECK_EQ(out_attrs->size(), 1U);
  const auto& rhs_stype = in_attrs->at(1);
  auto& lhs_stype       = in_attrs->at(0);
  auto& out_stype       = out_attrs->at(0);
  bool dispatched       = false;

  CHECK_NE(rhs_stype, kUndefinedStorage);
  type_assign(&out_stype, rhs_stype);
  type_assign(&lhs_stype, rhs_stype);
  if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kDefaultStorage &&
      out_stype == kDefaultStorage) {
    // dns, dns -> dns
    dispatched =
        storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFCompute);
  }
  if (!dispatched && (lhs_stype == kRowSparseStorage || lhs_stype == kCSRStorage) &&
      (lhs_stype == out_stype)) {
    // rsp, _ -> rsp, or csr, _ -> csr
    dispatched = storage_type_assign(&out_stype,
                                     static_cast<NDArrayStorageType>(out_stype),
                                     dispatch_mode,
                                     DispatchMode::kFComputeEx);
  }
  if (!dispatched && (lhs_stype == kRowSparseStorage || lhs_stype == kCSRStorage) &&
      (out_stype == kDefaultStorage)) {
    // rsp/csr, _ -> dns
    dispatched = storage_type_assign(&out_stype,
                                     static_cast<NDArrayStorageType>(out_stype),
                                     dispatch_mode,
                                     DispatchMode::kFComputeEx);
  }
  if (!dispatched) {
    dispatched = dispatch_fallback(out_attrs, dispatch_mode);
  }
  return dispatched;
}

// relu
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(relu, cpu, mshadow_op::relu)
    .describe(R"code(Computes rectified linear activation.

.. math::
   max(features, 0)

The storage type of ``relu`` output depends upon the input storage type:

- relu(default) = default
- relu(row_sparse) = row_sparse
- relu(csr) = csr

)code" ADD_FILELINE)
    .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_relu"});

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_relu, unary_bwd<mshadow_op::relu_grad>)
    .set_attr<nnvm::FGradient>(
        "FGradient",
        [](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
          std::vector<nnvm::NodeEntry> ret;
          // ograds[0]: dL/dxgrad
          // inputs[0]: dL/dy
          // inputs[1]: y
          // f(x) -> relu(x)
          // f'(x) = 1 if x > 0 else 0
          // f''(x) = 0
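          // Sketch of the graph below: since y = relu(x) > 0 exactly where
          // x > 0, the first derivative is recovered as (y > 0), built via
          // _greater(y, zeros_like(y)).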
          auto dydx =
              MakeNode("_greater",
                       n->attrs.name + "_dydx",
                       {n->inputs[1],
                        nnvm::NodeEntry{MakeNode(
                            "zeros_like", n->attrs.name + "tmp", {n->inputs[1]}, nullptr, &n)}},
                       nullptr,
                       &n);
          ret.emplace_back(MakeNode("elemwise_mul",
                                    n->attrs.name + "_backward_grad_grad",
                                    {ograds[0], nnvm::NodeEntry(dydx)},
                                    nullptr,
                                    &n));
          ret.emplace_back(MakeNode(
              "zeros_like", n->attrs.name + "_backward_grad_grad_in", {n->inputs[1]}, nullptr, &n));
          return ret;
        });

// sigmoid
MXNET_OPERATOR_REGISTER_UNARY(sigmoid)
MXNET_ADD_SPARSE_OP_ALIAS(sigmoid)
    .describe(R"code(Computes sigmoid of x element-wise.

.. math::
   y = 1 / (1 + exp(-x))

The storage type of ``sigmoid`` output is always dense

)code" ADD_FILELINE)
    .set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::sigmoid>)
    .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_sigmoid"});

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_sigmoid,
                                               unary_bwd<mshadow_op::sigmoid_grad>)
    .set_attr<nnvm::FGradient>(
        "FGradient",
        [](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
          // n->inputs[0] : y_grad
          // n->inputs[1] : f(x) = sigmoid(x)
          // ograds[0] : head_grads
          // f''(x) = f'(x) * (1 - 2*f(x))
          // NodeEntry{n} : y_grad * f'(x)
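          // Derivation, as a sanity check: with y = sigmoid(x),
          // f'(x) = y * (1 - y), hence f''(x) = f'(x) * (1 - 2y). The graph
          // below forms (1 - 2y) explicitly and recovers f'(x) by dividing
          // NodeEntry{n} = y_grad * f'(x) by y_grad.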
          auto ones =
              MakeNode("ones_like", n->attrs.name + "_grad_ones", {n->inputs[1]}, nullptr, &n);
          const std::unordered_map<std::string, std::string> args = {{"scalar", "2.0"}};
          auto two_y =
              MakeNode("_mul_scalar", n->attrs.name + "_mul_two", {n->inputs[1]}, &args, &n);
          auto one_minus_two_y = MakeNode("elemwise_sub",
                                          n->attrs.name + "_grad_sub",
                                          {nnvm::NodeEntry{ones}, nnvm::NodeEntry{two_y}},
                                          nullptr,
                                          &n);
          auto grad_grad_mid = MakeNode("elemwise_mul",
                                        n->attrs.name + "_grad_mul",
                                        {n->inputs[0], nnvm::NodeEntry{one_minus_two_y}},
                                        nullptr,
                                        &n);
          auto dydx = MakeNode("elemwise_div",
                               n->attrs.name + "_grad_div",
                               {nnvm::NodeEntry{n}, n->inputs[0]},
                               nullptr,
                               &n);

          // When building the gradient graph, the backward node of n->inputs[1]
          // will be added to the graph again, therefore f'(x) will be multiplied in.
          std::vector<nnvm::NodeEntry> ret;
          ret.emplace_back(MakeNode("elemwise_mul",
                                    n->attrs.name + "backward_grad_grad",
                                    {ograds[0], nnvm::NodeEntry{dydx}},
                                    nullptr,
                                    &n));
          ret.emplace_back(MakeNode("elemwise_mul",
                                    n->attrs.name + "backward_grad_grad_in",
                                    {ograds[0], nnvm::NodeEntry{grad_grad_mid}},
                                    nullptr,
                                    &n));
          return ret;
        });

// log_sigmoid
MXNET_OPERATOR_REGISTER_UNARY(log_sigmoid)
MXNET_ADD_SPARSE_OP_ALIAS(log_sigmoid)
    .describe(R"code(Computes log_sigmoid of x element-wise.

.. math::
   y = log(1 / (1 + exp(-x)))

The storage type of ``log_sigmoid`` output is always dense

)code" ADD_FILELINE)
    .set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::log_sigmoid>)
    .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_log_sigmoid"});

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_log_sigmoid,
                                               unary_bwd<mshadow_op::log_sigmoid_grad>)
    .set_attr<nnvm::FGradient>(
        "FGradient",
        [](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
          // n->inputs[0] : y_grad
          // n->inputs[1] : f(x) = log_sigmoid(x)
          // ograds[0] : head_grads
          // f''(x) = f'(x) * (f'(x) - 1)
          // NodeEntry{n} : y_grad * f'(x)
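          // Derivation, as a sanity check: with s = sigmoid(x), f(x) = log(s)
          // and f'(x) = 1 - s, so f''(x) = -s * (1 - s) = f'(x) * (f'(x) - 1),
          // matching the expression above.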
          auto ones =
              MakeNode("ones_like", n->attrs.name + "_grad_ones", {n->inputs[1]}, nullptr, &n);
          auto grad_minus_one = MakeNode("elemwise_sub",
                                         n->attrs.name + "_grad_sub",
                                         {n->inputs[0], nnvm::NodeEntry{ones}},
                                         nullptr,
                                         &n);
          auto grad_grad_mid = MakeNode("elemwise_mul",
                                        n->attrs.name + "_grad_mul",
                                        {n->inputs[0], nnvm::NodeEntry{grad_minus_one}},
                                        nullptr,
                                        &n);
          auto dydx = MakeNode("elemwise_div",
                               n->attrs.name + "_grad_div",
                               {nnvm::NodeEntry{n}, n->inputs[0]},
                               nullptr,
                               &n);

          // When building the gradient graph, the backward node of n->inputs[1]
          // will be added to the graph again, therefore f'(x) will be multiplied in.
          std::vector<nnvm::NodeEntry> ret;
          ret.emplace_back(MakeNode("elemwise_mul",
                                    n->attrs.name + "backward_grad_grad",
                                    {ograds[0], nnvm::NodeEntry{dydx}},
                                    nullptr,
                                    &n));
          ret.emplace_back(MakeNode("elemwise_mul",
                                    n->attrs.name + "backward_grad_grad_in",
                                    {ograds[0], nnvm::NodeEntry{grad_grad_mid}},
                                    nullptr,
                                    &n));
          return ret;
        });

// mish
MXNET_OPERATOR_REGISTER_UNARY(mish)
MXNET_ADD_SPARSE_OP_ALIAS(mish)
    .describe(R"code(Computes mish of x element-wise.

.. math::
   y = x * tanh(log(1 + exp(x)))

The storage type of ``mish`` output is always dense

)code" ADD_FILELINE)
    .set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::mish>)
    .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_mish"});

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_mish, unary_bwd<mshadow_op::mish_grad>);

DMLC_REGISTER_PARAMETER(HardSigmoidParam);
MXNET_OPERATOR_REGISTER_UNARY(hard_sigmoid)
    .describe(R"code(Computes hard sigmoid of x element-wise.

.. math::
   y = max(0, min(1, alpha * x + beta))

)code" ADD_FILELINE)
    .set_attr_parser(ParamParser<HardSigmoidParam>)
    .set_attr<FCompute>("FCompute<cpu>", HardSigmoidForward<cpu>)
    .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_hard_sigmoid"})
    .add_arguments(HardSigmoidParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_hard_sigmoid)
    .set_attr_parser(ParamParser<HardSigmoidParam>)
    .set_num_inputs(2)
    .set_num_outputs(1)
    .set_attr<nnvm::TIsBackward>("TIsBackward", true)
    .set_attr<nnvm::FInplaceOption>("FInplaceOption",
                                    [](const NodeAttrs& attrs) {
                                      return std::vector<std::pair<int, int>>{{0, 0}};
                                    })
    .set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity",
                                      [](const NodeAttrs& attrs) {
                                        return std::vector<bool>{true};
                                      })
    .set_attr<FCompute>("FCompute<cpu>", HardSigmoidBackward<cpu>);

// softsign
MXNET_OPERATOR_REGISTER_UNARY(softsign)
    .describe(R"code(Computes softsign of x element-wise.

.. math::
   y = x / (1 + abs(x))

The storage type of ``softsign`` output is always dense

)code" ADD_FILELINE)
    .set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::softsign>)
    .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_softsign"});

NNVM_REGISTER_OP(softsign).add_alias("_npx_softsign");

MXNET_OPERATOR_REGISTER_BINARY(_backward_softsign)
    .set_attr<FCompute>("FCompute<cpu>",
                        ElemwiseBinaryOp::Compute<cpu, unary_bwd<mshadow_op::softsign_grad>>);

// copy
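// A rough map of the three paths CopyEx can take: DNNL-formatted inputs go
// through DNNLCopy, plain dense tensors fall back to the standard
// IdentityCompute kernel (skipping kNullOp and kWriteInplace requests, which
// need no work), and sparse inputs are handled by IdentityComputeEx.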
static void CopyEx(const nnvm::NodeAttrs& attrs,
                   const OpContext& ctx,
                   const std::vector<NDArray>& inputs,
                   const std::vector<OpReqType>& req,
                   const std::vector<NDArray>& outputs) {
  CHECK_EQ(inputs.size(), 1U);
  CHECK_EQ(outputs.size(), 1U);
#if MXNET_USE_ONEDNN == 1
  const auto in_stype  = inputs[0].storage_type();
  const auto out_stype = outputs[0].storage_type();
  if (inputs[0].IsDNNLData()) {
    DNNLRun(DNNLCopy, attrs, ctx, inputs[0], req[0], outputs[0]);
    return;
  } else if (in_stype == kDefaultStorage && out_stype == kDefaultStorage) {
    if (req[0] != kNullOp && req[0] != kWriteInplace)
      FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, inputs, req, outputs);
    return;
  }
#endif  // MXNET_USE_ONEDNN == 1
  UnaryOp::IdentityComputeEx<cpu>(attrs, ctx, inputs, req, outputs);
}

static inline bool CopyStorageType(const nnvm::NodeAttrs& attrs,
                                   const int dev_mask,
                                   DispatchMode* dispatch_mode,
                                   std::vector<int>* in_attrs,
                                   std::vector<int>* out_attrs) {
  CHECK_EQ(in_attrs->size(), 1);
  CHECK_EQ(out_attrs->size(), 1);
  bool ret = ElemwiseStorageType<1, 1, false, true, true>(
      attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
#if MXNET_USE_ONEDNN == 1
  // We have to make sure all inputs use the default layout. Otherwise, we
  // might want to fall back.
  if (dev_mask == mshadow::cpu::kDevMask && in_attrs->at(0) == kDefaultStorage &&
      out_attrs->at(0) == kDefaultStorage) {
    *dispatch_mode = DispatchMode::kFComputeEx;
  }
#endif  // MXNET_USE_ONEDNN == 1
  return ret;
}

MXNET_OPERATOR_REGISTER_UNARY(_copy)
    .MXNET_DESCRIBE("Returns a copy of the input.")
    .add_alias("identity")
    .set_attr<FInferStorageType>("FInferStorageType", CopyStorageType)
    .set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
    .set_attr<FComputeEx>("FComputeEx<cpu>", CopyEx)
#if MXNET_USE_ONEDNN == 1
    .set_attr<FResourceRequest>("FResourceRequest",
                                [](const NodeAttrs& n) {
                                  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
                                })
    .set_attr<bool>("TIsDNNL", true)
#endif  // MXNET_USE_ONEDNN == 1
    .set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity",
                                      [](const NodeAttrs& attrs) {
                                        return std::vector<bool>{true};
                                      })
    .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_copy"});

NNVM_REGISTER_OP(_backward_copy)
    .set_num_inputs(1)
    .set_num_outputs(1)
    .set_attr<nnvm::TIsBackward>("TIsBackward", true)
    .set_attr<nnvm::FInplaceOption>("FInplaceOption",
                                    [](const NodeAttrs& attrs) {
                                      return std::vector<std::pair<int, int>>{{0, 0}};
                                    })
    .set_attr<FInferStorageType>("FInferStorageType", CopyStorageType)
    .set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
    .set_attr<FComputeEx>("FComputeEx<cpu>", CopyEx)
#if MXNET_USE_ONEDNN == 1
    .set_attr<bool>("TIsDNNL", true)
    .set_attr<FResourceRequest>("FResourceRequest",
                                [](const NodeAttrs& n) {
                                  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
                                })
#endif  // MXNET_USE_ONEDNN == 1
    .set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity", [](const NodeAttrs& attrs) {
      return std::vector<bool>{true};
    });

NNVM_REGISTER_OP(_backward_reshape)
    .set_num_inputs(1)
    .set_num_outputs(1)
    .set_attr<nnvm::TIsBackward>("TIsBackward", true)
    .set_attr<nnvm::FInplaceOption>("FInplaceOption",
                                    [](const NodeAttrs& attrs) {
                                      return std::vector<std::pair<int, int>>{{0, 0}};
                                    })
    .set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
    .set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity", [](const NodeAttrs& attrs) {
      return std::vector<bool>{true};
    });

MXNET_OPERATOR_REGISTER_UNARY(BlockGrad)
MXNET_ADD_SPARSE_OP_ALIAS(stop_gradient)
    .add_alias("_npx_stop_gradient")
    .add_alias("stop_gradient")
    .describe(R"code(Stops gradient computation.

Stops the accumulated gradient of the inputs from flowing through this operator
in the backward direction. In other words, this operator prevents the contribution
of its inputs from being taken into account when computing gradients.

Example::

  v1 = [1, 2]
  v2 = [0, 1]
  a = Variable('a')
  b = Variable('b')
  b_stop_grad = stop_gradient(3 * b)
  loss = MakeLoss(b_stop_grad + a)

  executor = loss.simple_bind(ctx=cpu(), a=(1,2), b=(1,2))
  executor.forward(is_train=True, a=v1, b=v2)
  executor.outputs
  [ 1. 5.]

  executor.backward()
  executor.grad_arrays
  [ 0. 0.]
  [ 1. 1.]

)code" ADD_FILELINE)
    .set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<1, 1, false, true, true>)
    .set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
    .set_attr<FComputeEx>("FComputeEx<cpu>", UnaryOp::IdentityComputeEx<cpu>)
    .set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity",
                                      [](const NodeAttrs& attrs) {
                                        return std::vector<bool>{true};
                                      })
    .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);

MXNET_OPERATOR_REGISTER_UNARY(make_loss)
MXNET_ADD_SPARSE_OP_ALIAS(make_loss)
    .describe(R"code(Make your own loss function in network construction.

This operator accepts a customized loss function symbol as a terminal loss, and
the symbol should be an operator with no backward dependency.
The output of this function is the gradient of loss with respect to the input data.

For example, if you are making a cross entropy loss function, assume ``out`` is the
predicted output and ``label`` is the true label; then the cross entropy can be defined as::

  cross_entropy = label * log(out) + (1 - label) * log(1 - out)
  loss = make_loss(cross_entropy)

We will need to use ``make_loss`` when we are creating our own loss function or when we
want to combine multiple loss functions. Also, we may want to stop some variables' gradients
from backpropagation. See more details in ``BlockGrad`` or ``stop_gradient``.

The storage type of ``make_loss`` output depends upon the input storage type:

- make_loss(default) = default
- make_loss(row_sparse) = row_sparse

)code" ADD_FILELINE)
    .set_attr<nnvm::FListOutputNames>("FListOutputNames",
                                      [](const NodeAttrs& attrs) {
                                        return std::vector<std::string>{"loss"};
                                      })
    .set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<1, 1, false, true, true>)
    .set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
    .set_attr<FComputeEx>("FComputeEx<cpu>", UnaryOp::IdentityComputeEx<cpu>)
    .set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity",
                                      [](const NodeAttrs& attrs) {
                                        return std::vector<bool>{true};
                                      })
    .set_attr<nnvm::FGradient>(
        "FGradient",
        [](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
          std::vector<nnvm::NodeEntry> ret;
          ret.emplace_back(
              MakeNode("ones_like", n->attrs.name + "_backward", &(n->inputs), nullptr, &n));
          return ret;
        });

// Identity of the first input, but with output attributes (shape and type)
// constrained to match rhs. The storage type attribute is not constrained to
// match rhs if it is already defined. For internal use only.
NNVM_REGISTER_OP(_identity_with_attr_like_rhs)
    .set_num_inputs(2)
    .set_attr<nnvm::FListInputNames>("FListInputNames",
                                     [](const NodeAttrs& attrs) {
                                       return std::vector<std::string>{"lhs", "rhs"};
                                     })
    .set_attr<nnvm::FInplaceOption>("FInplaceOption",
                                    [](const NodeAttrs& attrs) {
                                      return std::vector<std::pair<int, int>>{{0, 0}};
                                    })
    .set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity",
                                      [](const NodeAttrs& attrs) {
                                        return std::vector<bool>{true};
                                      })
    .set_attr<nnvm::FIgnoreInputs>("FIgnoreInputs",
                                   [](const NodeAttrs& attrs) {
                                     return std::vector<uint32_t>(1, 1);
                                   })
    .set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
    .set_attr<FComputeEx>("FComputeEx<cpu>", UnaryOp::IdentityComputeFirstItemEx<cpu>)
    .set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<2, 1>)
    .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
    .set_attr<FInferStorageType>("FInferStorageType", IdentityAttrLikeRhsStorageType)
    .set_attr<nnvm::FGradient>(
        "FGradient",
        [](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
          if (CheckGradAllZero(ograds))
            return MakeZeroGradNodes(n, ograds);
          std::vector<nnvm::NodeEntry> lhs = MakeGradNode(
              "_backward_copy", n, ograds, std::unordered_map<std::string, std::string>());
          lhs.emplace_back(
              MakeNode("zeros_like", n->attrs.name + "_rhs_backward", {n->inputs[1]}, nullptr, &n));
          return lhs;
        })
    .add_argument("lhs", "NDArray-or-Symbol", "First input.")
    .add_argument("rhs", "NDArray-or-Symbol", "Second input.");

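// Canonicalizes a [begin, end) dimension range against ndims, resolving
// negative and missing indices. As a worked example: with ndims = 3,
// begin = -2 and end = None, the result is cbegin = -2 + 3 = 1 and cend = 3.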
void ReshapeLikeRangeCanonicalize(int ndims,
                                  const char* side,
                                  const dmlc::optional<int>& begin,
                                  const dmlc::optional<int>& end,
                                  int* cbegin,
                                  int* cend) {
  *cbegin = begin.has_value() ? begin.value() : 0;
  if (*cbegin < 0)
    *cbegin += ndims;

  if (!end.has_value()) {
    *cend = ndims;
  } else {
    *cend = end.value();
    if (*cend < 0) {
      *cend += ndims;
    }
  }
  CHECK(*cend <= ndims) << "Invalid end for " << side << "_end=" << end
                        << " as dimension number is " << ndims;
  CHECK((*cbegin < *cend)) << "Invalid begin/end: got " << side << "_begin=" << begin << ", "
                           << side << "_end=" << end;

  CHECK(*cend >= 0) << "Invalid end for " << side << "_end=" << end;
  CHECK(*cbegin >= 0) << "Invalid begin for " << side << "_begin=" << begin;
}

void GetReshapeLikeParams(const ReshapeLikeParam& param,
                          const mxnet::TShape& lshape,
                          const mxnet::TShape& rshape,
                          int* lhs_begin,
                          int* lhs_end,
                          int* rhs_begin,
                          int* rhs_end) {
  // LHS params
  ReshapeLikeRangeCanonicalize(
      lshape.ndim(), "lhs", param.lhs_begin, param.lhs_end, lhs_begin, lhs_end);
  // RHS params
  ReshapeLikeRangeCanonicalize(
      rshape.ndim(), "rhs", param.rhs_begin, param.rhs_end, rhs_begin, rhs_end);
}

bool ReshapeLikeShapeCompute(const nnvm::NodeAttrs& attrs,
                             mxnet::ShapeVector* in_attrs,
                             mxnet::ShapeVector* out_attrs) {
  const ReshapeLikeParam& param = nnvm::get<ReshapeLikeParam>(attrs.parsed);
  const mxnet::TShape& lshape   = (*in_attrs)[0];
  const mxnet::TShape& rshape   = (*in_attrs)[1];
  int lhs_begin, lhs_end, rhs_begin, rhs_end;
  GetReshapeLikeParams(param, lshape, rshape, &lhs_begin, &lhs_end, &rhs_begin, &rhs_end);

  int lhsrank = lshape.ndim();
  int orank   = lhsrank + (rhs_end - rhs_begin) - (lhs_end - lhs_begin);
  mxnet::TShape oshape(orank, -1);

  for (int i = 0; i < lhs_begin; ++i)
    oshape[i] = lshape[i];

  int opos = lhs_begin;
  for (int i = rhs_begin; i < rhs_end; ++i) {
    oshape[opos] = rshape[i];
    opos += 1;
  }

  for (int i = lhs_end; i < lhsrank; ++i) {
    oshape[opos] = lshape[i];
    opos += 1;
  }

  CHECK_EQ((*in_attrs)[0].Size(), oshape.Size())
      << "Cannot reshape lhs with shape " << (*in_attrs)[0] << " to new "
      << "shape " << oshape << " because they have different "
      << "sizes.";
  SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
  return shape_is_known(oshape);
}
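
// A quick trace of ReshapeLikeShapeCompute, as a sanity check: with
// lhs = (30, 7), rhs = (15, 2, 4), lhs_[begin, end) = [0, 1) and
// rhs_[begin, end) = [0, 2), orank = 2 + 2 - 1 = 3 and oshape = (15, 2, 7);
// both sides cover 210 elements, so the size check passes.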

DMLC_REGISTER_PARAMETER(ReshapeLikeParam);
NNVM_REGISTER_OP(reshape_like)
    .describe(
        R"code(Reshape some or all dimensions of `lhs` to have the same shape as some or all dimensions of `rhs`.

Returns a **view** of the `lhs` array with a new shape without altering any data.

Example::

  x = [1, 2, 3, 4, 5, 6]
  y = [[0, -4], [3, 2], [2, 2]]
  reshape_like(x, y) = [[1, 2], [3, 4], [5, 6]]

More precise control over how dimensions are inherited is achieved by specifying \
slices over the `lhs` and `rhs` array dimensions. Only the sliced `lhs` dimensions \
are reshaped to the `rhs` sliced dimensions, with the non-sliced `lhs` dimensions staying the same.

Examples::

  - lhs shape = (30,7), rhs shape = (15,2,4), lhs_begin=0, lhs_end=1, rhs_begin=0, rhs_end=2, output shape = (15,2,7)
  - lhs shape = (3, 5), rhs shape = (1,15,4), lhs_begin=0, lhs_end=2, rhs_begin=1, rhs_end=2, output shape = (15,)

Negative indices are supported, and `None` can be used for either `lhs_end` or `rhs_end` to indicate the end of the range.

Example::

  - lhs shape = (30, 12), rhs shape = (4, 2, 2, 3), lhs_begin=-1, lhs_end=None, rhs_begin=1, rhs_end=None, output shape = (30, 2, 2, 3)

)code" ADD_FILELINE)
    .add_alias("_npx_reshape_like")
    .set_num_inputs(2)
    .set_attr_parser(ParamParser<ReshapeLikeParam>)
    .set_attr<nnvm::FListInputNames>("FListInputNames",
                                     [](const NodeAttrs& attrs) {
                                       return std::vector<std::string>{"lhs", "rhs"};
                                     })
    .set_attr<nnvm::FInplaceOption>("FInplaceOption",
                                    [](const NodeAttrs& attrs) {
                                      return std::vector<std::pair<int, int>>{{0, 0}};
                                    })
    .set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity",
                                      [](const NodeAttrs& attrs) {
                                        return std::vector<bool>{true};
                                      })
    .set_attr<nnvm::FIgnoreInputs>("FIgnoreInputs",
                                   [](const NodeAttrs& attrs) {
                                     return std::vector<uint32_t>(1, 1);
                                   })
    .set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
    .set_attr<mxnet::FInferShape>("FInferShape", ReshapeLikeShapeCompute)
    .set_attr<nnvm::FInferType>(
        "FInferType",
        [](const nnvm::NodeAttrs& attrs, std::vector<int>* in_attrs, std::vector<int>* out_attrs) {
          CHECK_EQ(in_attrs->size(), 2) << " in operator " << attrs.name;
          std::vector<int> checked_in_attrs = {(*in_attrs)[0]};
          bool ret = !type_is_none((*in_attrs)[1]) &&
                     ElemwiseType<1, 1>(attrs, &checked_in_attrs, out_attrs);
          (*in_attrs)[0] = checked_in_attrs[0];
          return ret;
        })
    .set_attr<nnvm::FGradient>(
        "FGradient",
        [](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
          if (CheckGradAllZero(ograds))
            return MakeZeroGradNodes(n, ograds);
          std::vector<nnvm::NodeEntry> lhs = MakeGradNode(
              "_backward_copy", n, ograds, std::unordered_map<std::string, std::string>());
          lhs.emplace_back(
              MakeNode("zeros_like", n->attrs.name + "_rhs_backward", {n->inputs[1]}, nullptr, &n));
          return lhs;
        })
    .add_argument("lhs", "NDArray-or-Symbol", "First input.")
    .add_argument("rhs", "NDArray-or-Symbol", "Second input.")
    .add_arguments(ReshapeLikeParam::__FIELDS__());

void ShapeComputeCPU(const nnvm::NodeAttrs& attrs,
                     const OpContext& ctx,
                     const std::vector<TBlob>& inputs,
                     const std::vector<OpReqType>& req,
                     const std::vector<TBlob>& outputs) {
  CHECK_EQ(inputs.size(), 1U);
  CHECK_EQ(outputs.size(), 1U);
  CHECK_EQ(req.size(), 1U);
  const TBlob& in_data  = inputs[0];
  const TBlob& out_data = outputs[0];
  size_t type_size      = mshadow::mshadow_sizeof(out_data.type_flag_);
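  // The raw copy below writes ndim() int64 values; e.g. a (2, 4) input
  // produces {2, 4}. This assumes the shape's dim_t width matches the int64
  // output element size, which holds for the default 64-bit TShape build.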
  memcpy(out_data.dptr_, in_data.shape_.data(), in_data.ndim() * type_size);
}

NNVM_REGISTER_OP(shape_array)
    .add_alias("_npx_shape_array")
    .describe(R"code(Returns a 1D int64 array containing the shape of data.

Example::

  shape_array([[1,2,3,4], [5,6,7,8]]) = [2,4]

)code" ADD_FILELINE)
    .set_num_inputs(1)
    .set_num_outputs(1)
    .set_attr<FCompute>("FCompute<cpu>", ShapeComputeCPU)
    .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
    .set_attr<mxnet::FInferShape>("FInferShape",
                                  [](const nnvm::NodeAttrs& attrs,
                                     mxnet::ShapeVector* in_attrs,
                                     mxnet::ShapeVector* out_attrs) {
                                    CHECK_EQ(in_attrs->size(), 1U);
                                    CHECK_EQ(out_attrs->size(), 1U);
                                    mxnet::TShape target_shape(1, -1);
                                    target_shape[0] = in_attrs->at(0).ndim();
                                    SHAPE_ASSIGN_CHECK(*out_attrs, 0, target_shape);
                                    return !shape_is_none(out_attrs->at(0));
                                  })
    .set_attr<nnvm::FInferType>("FInferType",
                                [](const nnvm::NodeAttrs& attrs,
                                   std::vector<int>* in_attrs,
                                   std::vector<int>* out_attrs) {
                                  CHECK_EQ(in_attrs->size(), 1U);
                                  CHECK_EQ(out_attrs->size(), 1U);
                                  TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64);
                                  return out_attrs->at(0) != -1;
                                })
    .add_argument("data", "NDArray-or-Symbol", "Input Array.");

void SizeComputeCPU(const nnvm::NodeAttrs& attrs,
                    const OpContext& ctx,
                    const std::vector<TBlob>& inputs,
                    const std::vector<OpReqType>& req,
                    const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mxnet_op;
  CHECK_EQ(inputs.size(), 1U);
  CHECK_EQ(outputs.size(), 1U);
  CHECK_EQ(req.size(), 1U);
  const TBlob& in_data   = inputs[0];
  const TBlob& out_data  = outputs[0];
  size_t type_size       = mshadow::mshadow_sizeof(out_data.type_flag_);
  const index_t size_var = in_data.Size();
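  // Writes the total element count as a single int64 value; e.g. a (2, 4)
  // input yields {8}. As in ShapeComputeCPU, this assumes index_t matches
  // the 8-byte output element size.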
  memcpy(out_data.dptr_, &size_var, type_size);
}

NNVM_REGISTER_OP(size_array)
    .describe(R"code(Returns a 1D int64 array containing the size of data.

Example::

  size_array([[1,2,3,4], [5,6,7,8]]) = [8]

)code" ADD_FILELINE)
    .set_num_inputs(1)
    .set_num_outputs(1)
    .set_attr<FCompute>("FCompute<cpu>", SizeComputeCPU)
    .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
    .set_attr<mxnet::FInferShape>("FInferShape",
                                  [](const nnvm::NodeAttrs& attrs,
                                     mxnet::ShapeVector* in_attrs,
                                     mxnet::ShapeVector* out_attrs) {
                                    CHECK_EQ(in_attrs->size(), 1U);
                                    CHECK_EQ(out_attrs->size(), 1U);
                                    SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape(1, 1));
                                    return !shape_is_none(out_attrs->at(0));
                                  })
    .set_attr<nnvm::FInferType>("FInferType",
                                [](const nnvm::NodeAttrs& attrs,
                                   std::vector<int>* in_attrs,
                                   std::vector<int>* out_attrs) {
                                  CHECK_EQ(in_attrs->size(), 1U);
                                  CHECK_EQ(out_attrs->size(), 1U);
                                  TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64);
                                  return out_attrs->at(0) != -1;
                                })
    .add_argument("data", "NDArray-or-Symbol", "Input Array.");

DMLC_REGISTER_PARAMETER(CastParam);
NNVM_REGISTER_OP(Cast)
    .add_alias("cast")
    .add_alias("_npi_cast")
    .add_alias("_npx_cast")
    .describe(R"code(Casts all elements of the input to a new type.

.. note:: ``Cast`` is deprecated. Use ``cast`` instead.

Example::

  cast([0.9, 1.3], dtype='int32') = [0, 1]
  cast([1e20, 11.1], dtype='float16') = [inf, 11.09375]
  cast([300, 11.1, 10.9, -1, -3], dtype='uint8') = [44, 11, 10, 255, 253]

)code" ADD_FILELINE)
    .set_attr_parser(ParamParser<CastParam>)
    .set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
    .set_attr<nnvm::FInferType>("FInferType", CastType)
    .set_attr<nnvm::FInplaceOption>("FInplaceOption",
                                    [](const NodeAttrs& attrs) {
                                      return std::vector<std::pair<int, int>>{{0, 0}};
                                    })
    .set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity",
                                      [](const NodeAttrs& attrs) {
                                        return std::vector<bool>{true};
                                      })
    .set_attr<FCompute>("FCompute<cpu>", CastCompute<cpu>)
    .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_cast"})
    .add_argument("data", "NDArray-or-Symbol", "The input.")
    .add_arguments(CastParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_cast)
    .set_attr<nnvm::TIsBackward>("TIsBackward", true)
    .set_attr<nnvm::FInplaceOption>("FInplaceOption",
                                    [](const NodeAttrs& attrs) {
                                      return std::vector<std::pair<int, int>>{{0, 0}};
                                    })
    .set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity",
                                      [](const NodeAttrs& attrs) {
                                        return std::vector<bool>{true};
                                      })
    .set_attr<FCompute>("FCompute<cpu>", CastCompute<cpu>);

// negative
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(negative, cpu, mshadow_op::negation)
    .describe(R"code(Numerical negative of the argument, element-wise.

The storage type of ``negative`` output depends upon the input storage type:

- negative(default) = default
- negative(row_sparse) = row_sparse
- negative(csr) = csr

)code")
    .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"negative"});

// abs
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(abs, cpu, mshadow_op::abs)
    .describe(R"code(Returns element-wise absolute value of the input.

Example::

  abs([-2, 0, 3]) = [2, 0, 3]

The storage type of ``abs`` output depends upon the input storage type:

- abs(default) = default
- abs(row_sparse) = row_sparse
- abs(csr) = csr

)code" ADD_FILELINE)
    .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_abs"});

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_abs, unary_bwd<mshadow_op::sign>)
    .set_attr<nnvm::FGradient>(
        "FGradient",
        [](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
          // ograds[0]: dL/dxgrad
          // inputs[0]: dL/dy
          // inputs[1]: x
          // f(x) -> abs(x)
          // f'(x) = 1 if x > 0 else -1
          // f''(x) = 0
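          // Since the first backward produced NodeEntry{n} = dL/dy * sign(x),
          // dividing it by dL/dy below recovers dydx = sign(x).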
          auto dydx = MakeNode("elemwise_div",
                               n->attrs.name + "_dydx",
                               {nnvm::NodeEntry{n}, n->inputs[0]},
                               nullptr,
                               &n);

          std::vector<nnvm::NodeEntry> ret;
          ret.emplace_back(MakeNode("elemwise_mul",
                                    n->attrs.name + "_backward_grad_grad",
                                    {ograds[0], nnvm::NodeEntry(dydx)},
                                    nullptr,
                                    &n));
          ret.emplace_back(MakeNode(
              "zeros_like", n->attrs.name + "_backward_grad_grad_in", {n->inputs[1]}, nullptr, &n));
          return ret;
        });

// sign
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(sign, cpu, mshadow_op::sign)
    .describe(R"code(Returns element-wise sign of the input.

Example::

  sign([-2, 0, 3]) = [-1, 0, 1]

The storage type of ``sign`` output depends upon the input storage type:

- sign(default) = default
- sign(row_sparse) = row_sparse
- sign(csr) = csr

)code" ADD_FILELINE)
    .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);

// round
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(round, cpu, mshadow_op::round)
    .describe(R"code(Returns element-wise rounded value to the nearest integer of the input.

Example::

  round([-1.5, 1.5, -1.9, 1.9, 2.1]) = [-2., 2., -2., 2., 2.]

The storage type of ``round`` output depends upon the input storage type:

- round(default) = default
- round(row_sparse) = row_sparse
- round(csr) = csr

)code" ADD_FILELINE)
    .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);

// rint
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(rint, cpu, mshadow_op::rint)
    .describe(R"code(Returns element-wise rounded value to the nearest integer of the input.

.. note::
   - For input ``n.5``, ``rint`` returns ``n`` while ``round`` returns ``n+1``.
   - For input ``-n.5``, both ``rint`` and ``round`` return ``-n-1``.

Example::

  rint([-1.5, 1.5, -1.9, 1.9, 2.1]) = [-2., 1., -2., 2., 2.]

The storage type of ``rint`` output depends upon the input storage type:

- rint(default) = default
- rint(row_sparse) = row_sparse
- rint(csr) = csr

)code" ADD_FILELINE)
    .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);

// ceil
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(ceil, cpu, mshadow_op::ceil)
    .describe(R"code(Returns element-wise ceiling of the input.

The ceil of the scalar x is the smallest integer i, such that i >= x.

Example::

  ceil([-2.1, -1.9, 1.5, 1.9, 2.1]) = [-2., -1., 2., 2., 3.]

The storage type of ``ceil`` output depends upon the input storage type:

- ceil(default) = default
- ceil(row_sparse) = row_sparse
- ceil(csr) = csr

)code" ADD_FILELINE)
    .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);

// floor
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(floor, cpu, mshadow_op::floor)
    .describe(R"code(Returns element-wise floor of the input.

The floor of the scalar x is the largest integer i, such that i <= x.

Example::

  floor([-2.1, -1.9, 1.5, 1.9, 2.1]) = [-3., -2., 1., 1., 2.]

The storage type of ``floor`` output depends upon the input storage type:

- floor(default) = default
- floor(row_sparse) = row_sparse
- floor(csr) = csr

)code" ADD_FILELINE)
    .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);

// trunc
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(trunc, cpu, mshadow_op::trunc)
    .describe(R"code(Returns the element-wise truncated value of the input.

The truncated value of the scalar x is the nearest integer i which is closer to
zero than x is. In short, the fractional part of the signed number x is discarded.

Example::

  trunc([-2.1, -1.9, 1.5, 1.9, 2.1]) = [-2., -1., 1., 1., 2.]

The storage type of ``trunc`` output depends upon the input storage type:

- trunc(default) = default
- trunc(row_sparse) = row_sparse
- trunc(csr) = csr

)code" ADD_FILELINE)
    .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);

// fix
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(fix, cpu, mshadow_op::fix)
    .describe(R"code(Returns element-wise rounded value to the nearest \
integer towards zero of the input.

Example::

  fix([-2.1, -1.9, 1.9, 2.1]) = [-2., -1., 1., 2.]

The storage type of ``fix`` output depends upon the input storage type:

- fix(default) = default
- fix(row_sparse) = row_sparse
- fix(csr) = csr

)code" ADD_FILELINE)
    .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);

// erf
MXNET_OPERATOR_REGISTER_UNARY(erf)
    .add_alias("_npx_erf")
    .describe(R"code(Returns element-wise Gauss error function of the input.

Example::

  erf([0, -1., 10.]) = [0., -0.8427, 1.]

)code" ADD_FILELINE)
#if MSHADOW_USE_MKL == 1
    .set_attr<FCompute>("FCompute<cpu>", UnaryOp::MKL_Compute<mshadow_op::erf, mkl_func::erf>)
#else
    .set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::erf>)
#endif  // MSHADOW_USE_MKL == 1
    .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_erf"});

MXNET_OPERATOR_REGISTER_BINARY(_backward_erf)
    .set_attr<FCompute>("FCompute<cpu>",
                        ElemwiseBinaryOp::Compute<cpu, unary_bwd<mshadow_op::erf_grad>>);

// erfinv
MXNET_OPERATOR_REGISTER_UNARY(erfinv)
    .add_alias("_npx_erfinv")
    .describe(R"code(Returns element-wise inverse Gauss error function of the input.

Example::

  erfinv([0, 0.5, -1.]) = [0., 0.4769, -inf]

)code" ADD_FILELINE)
    .set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::erfinv>)
    .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_erfinv"});

MXNET_OPERATOR_REGISTER_BINARY(_backward_erfinv)
    .set_attr<FCompute>("FCompute<cpu>",
                        ElemwiseBinaryOp::Compute<cpu, unary_bwd<mshadow_op::erfinv_grad>>);

// gamma
MXNET_OPERATOR_REGISTER_UNARY_WITH_SPARSE_DR(gamma, cpu, mshadow_op::gamma)
MXNET_ADD_SPARSE_OP_ALIAS(gamma)
    .add_alias("_npx_gamma")
    .describe(R"code(Returns the gamma function (extension of the factorial function \
to the reals), computed element-wise on the input array.

The storage type of ``gamma`` output is always dense

)code")
    .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_gamma"});

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_gamma,
                                                  unary_bwd<mshadow_op::gamma_grad>);

// gammaln
MXNET_OPERATOR_REGISTER_UNARY_WITH_SPARSE_DR(gammaln, cpu, mshadow_op::gammaln)
MXNET_ADD_SPARSE_OP_ALIAS(gammaln)
    .add_alias("_npx_gammaln")
    .describe(R"code(Returns element-wise log of the absolute value of the gamma function \
of the input.

The storage type of ``gammaln`` output is always dense

)code")
    .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_gammaln"});

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_gammaln,
                                                  unary_bwd<mshadow_op::gammaln_grad>);

// digamma
MXNET_OPERATOR_REGISTER_UNARY_WITH_SPARSE_DR(digamma, cpu, mshadow_op::digamma)
MXNET_ADD_SPARSE_OP_ALIAS(digamma)
    .add_alias("_npx_digamma")
    .describe(R"code(Returns element-wise log derivative of the gamma function \
of the input.

The storage type of ``digamma`` output is always dense

)code")
    .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_digamma"});

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_digamma,
                                                  unary_bwd<mshadow_op::trigamma>);

MXNET_OPERATOR_REGISTER_UNARY(logical_not)
    .describe(R"code(Returns the result of the logical NOT (!) function, element-wise.

Example::

  logical_not([-2., 0., 1.]) = [0., 1., 0.]

)code")
    .set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::nt>)
    .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);

}  // namespace op
}  // namespace mxnet