/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file mshadow_op.h
* \brief Elementwise math operators and reduction functors used by MXNet operators
* \author Bing Xu
*/
#ifndef MXNET_OPERATOR_MSHADOW_OP_H_
#define MXNET_OPERATOR_MSHADOW_OP_H_
#include <mxnet/base.h>
#include <mshadow/base.h>
#include <climits>      // CHAR_BIT, used by the bit-shift guards below
#include <type_traits>  // std::enable_if / std::is_integral, used throughout
#include "math.h"
#include "math_functions-inl.h"
#include "special_functions-inl.h"
#include "./operator_tune.h"
#include "./contrib/erfinv-inl.h"
#ifdef __CUDACC__
#include <cuda_fp16.h>
#endif
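// Feature-test for std::gcd / std::lcm from <numeric> (C++17). When they are
// unavailable, the gcd/lcm functors at the bottom of this file fall back to a
// hand-rolled Euclidean algorithm.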
#define MXNET_HAS_GCD_LCM() 0
#if __cplusplus >= 201703L
#ifdef __has_include
#if __has_include(<numeric>)
#include <numeric>
#undef MXNET_HAS_GCD_LCM
#define MXNET_HAS_GCD_LCM() 1
#endif
#endif
#endif
namespace mxnet {
namespace op {
namespace mshadow_op {
using mshadow::isinf_typed::IsInf;
using mshadow::isnan_typed::IsNan;
#ifdef __CUDA_ARCH__
__constant__ const float PI = 3.14159265358979323846;
__constant__ const float SELU_ALPHA = 1.6732632423543772848170429916717;
__constant__ const float SELU_LAMBDA = 1.0507009873554804934193349852946;
__constant__ const float SQRT_2 = 1.4142135623730950488016887242096;
__constant__ const float GELU_TANH_CONST = 0.044715;
#else
const float PI = 3.14159265358979323846;
const float SELU_ALPHA = 1.6732632423543772848170429916717;
const float SELU_LAMBDA = 1.0507009873554804934193349852946;
const float SQRT_2 = 1.4142135623730950488016887242096;
const float GELU_TANH_CONST = 0.044715;
#endif
using std::enable_if;
using std::is_integral;
using std::is_unsigned;
#define MXNET_UNARY_MATH_OP(name, expr) \
struct name : public mxnet_op::tunable { \
template <typename DType> \
MSHADOW_XINLINE static DType Map(DType a) { \
return DType(expr); \
} \
}
#define MXNET_UNARY_MATH_OP_NC(name, expr) \
struct name : public mxnet_op::tunable { \
template <typename DType> \
MSHADOW_XINLINE static DType Map(DType a) { \
return (expr); \
} \
}
#define MXNET_UNARY_LOGIC_OP_NC(name, expr) \
struct name : public mxnet_op::tunable { \
template <typename DType> \
MSHADOW_XINLINE static bool Map(DType a) { \
return (expr); \
} \
}
#define MXNET_BINARY_MATH_OP(name, expr) \
struct name : public mxnet_op::tunable { \
template <typename DType> \
MSHADOW_XINLINE static DType Map(DType a, DType b) { \
return DType(expr); \
} \
}
#define MXNET_BINARY_MATH_OP_NC(name, expr) \
struct name : public mxnet_op::tunable { \
template <typename DType> \
MSHADOW_XINLINE static DType Map(DType a, DType b) { \
return (expr); \
} \
}
#define MXNET_BINARY_MATH_OP_NC_WITH_BOOL(name, expr) \
struct name : public mxnet_op::tunable { \
template <typename DType, \
typename std::enable_if<!std::is_same<DType, bool>::value, int>::type = 0> \
MSHADOW_XINLINE static DType Map(DType a, DType b) { \
return (expr); \
} \
MSHADOW_XINLINE static bool Map(bool a, bool b) { \
return (expr); \
} \
}
#define MXNET_BINARY_LOGIC_OP_NC(name, expr) \
struct name : public mxnet_op::tunable { \
template <typename DType, typename EType> \
MSHADOW_XINLINE static bool Map(DType lhs, EType rhs) { \
double a = static_cast<double>(lhs); \
double b = static_cast<double>(rhs); \
return (expr); \
} \
}
#define MXNET_SIMPLE_UNARY_MATH_OP(name) MXNET_UNARY_MATH_OP(name, math::name(a))
#define MXNET_SIMPLE_BINARY_MATH_OP(name) MXNET_BINARY_MATH_OP(name, math::name(a, b))
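// Illustration: the functor macros above expand to tunable structs with a
// static Map(). For example,
//   MXNET_UNARY_MATH_OP(square, math::sqr(a));
// expands to
//   struct square : public mxnet_op::tunable {
//     template <typename DType>
//     MSHADOW_XINLINE static DType Map(DType a) { return DType(math::sqr(a)); }
//   };
// The _NC ("no cast") variants return (expr) without the DType(...) conversion.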
MXNET_UNARY_MATH_OP_NC(identity, a);
template <typename IType, typename DType>
struct IndexedNum {
IType idx;
DType num;
MSHADOW_XINLINE IndexedNum() : idx(0), num(0) {}
MSHADOW_XINLINE IndexedNum(DType n) : idx(0), num(n) {}
MSHADOW_XINLINE IndexedNum& operator+=(const IndexedNum& rhs) {
return *this;
}
};
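// IndexedNum pairs a value with its position so that argmax/argmin can be
// expressed as ordinary reductions; the no-op operator+= only satisfies the
// accumulator interface expected by the generic reduce kernels.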
template <typename AType, typename IType>
struct set_index_no_op : public mxnet_op::tunable {
static const bool do_op = false;
MSHADOW_XINLINE static void Op(AType* const a, IType i) {}
};
template <typename AType, typename IType>
struct arg_min_max_set_index : public mxnet_op::tunable {
static const bool do_op = true;
MSHADOW_XINLINE static void Op(AType* const a, IType i) {
a->idx = i;
}
};
MXNET_UNARY_MATH_OP(identity_grad, 1);
struct identity_with_cast {
template <typename DTypeIn, typename DTypeOut>
MSHADOW_XINLINE static void Map(index_t i, DTypeOut* out, DTypeIn* in) {
out[i] = DTypeOut(in[i]);
}
};
struct true_divide : public mxnet_op::tunable {
template <typename DType, typename std::enable_if<!std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return a / b;
}
template <typename DType, typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static float Map(DType a, DType b) {
return static_cast<float>(a) / static_cast<float>(b);
}
template <typename DType, typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
return static_cast<mshadow::half::half_t>(a) / b;
}
template <typename DType, typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static float Map(DType a, float b) {
return static_cast<float>(a) / b;
}
template <typename DType, typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static double Map(DType a, double b) {
return static_cast<double>(a) / b;
}
};
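// NumPy-style true division: integral inputs always produce a floating-point
// result, e.g. true_divide::Map(3, 2) == 1.5f, while floating inputs divide at
// their own precision. rtrue_divide below is the same with operands swapped.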
struct rtrue_divide : public mxnet_op::tunable {
template <typename DType, typename std::enable_if<!std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return b / a;
}
template <typename DType, typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static float Map(DType a, DType b) {
return static_cast<float>(b) / static_cast<float>(a);
}
template <typename DType, typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
return b / static_cast<mshadow::half::half_t>(a);
}
template <typename DType, typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static float Map(DType a, float b) {
return b / static_cast<float>(a);
}
template <typename DType, typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static double Map(DType a, double b) {
return b / static_cast<double>(a);
}
};
/***** floor_divide ******/
struct floor_divide : public mxnet_op::tunable {
template <
typename DType,
typename std::enable_if<!std::is_same<DType, bool>::value && std::is_integral<DType>::value,
int>::type = 0>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return static_cast<DType>(::floor(static_cast<double>(a) / static_cast<double>(b)));
}
MSHADOW_XINLINE static bool Map(bool a, bool b) {
return static_cast<bool>(::floor(a / b));
}
MSHADOW_XINLINE static mshadow::half::half_t Map(mshadow::half::half_t a,
mshadow::half::half_t b) {
return static_cast<mshadow::half::half_t>(
::floor(static_cast<float>(a) / static_cast<float>(b)));
}
template <typename DType,
typename std::enable_if<!std::is_integral<DType>::value &&
!std::is_same<DType, float>::value &&
!std::is_same<DType, mshadow::half::half_t>::value,
int>::type = 0>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return ::floor(a / b);
}
MSHADOW_XINLINE static float Map(float a, float b) {
return ::floorf(a / b);
}
};
struct rfloor_divide : public mxnet_op::tunable {
template <
typename DType,
typename std::enable_if<!std::is_same<DType, bool>::value && std::is_integral<DType>::value,
int>::type = 0>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return static_cast<DType>(::floor(static_cast<double>(b) / static_cast<double>(a)));
}
MSHADOW_XINLINE static bool Map(bool a, bool b) {
return static_cast<bool>(::floor(b / a));
}
template <
typename DType,
typename std::enable_if<!std::is_integral<DType>::value && !std::is_same<DType, float>::value,
int>::type = 0>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return ::floor(b / a);
}
MSHADOW_XINLINE static float Map(float a, float b) {
return ::floorf(b / a);
}
};
struct mixed_floor_divide {
template <typename DType, typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
return ::floor(static_cast<mshadow::half::half_t>(a) / b);
}
template <typename DType,
typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
std::is_same<DType, mshadow::bfloat::bf16_t>::value ||
std::is_integral<DType>::value,
int>::type = 0>
MSHADOW_XINLINE static float Map(DType a, float b) {
return ::floorf(static_cast<float>(a) / b);
}
template <
typename DType,
typename std::enable_if<
std::is_same<DType, mshadow::half::half_t>::value || std::is_same<DType, float>::value ||
std::is_same<DType, mshadow::bfloat::bf16_t>::value || std::is_integral<DType>::value,
int>::type = 0>
MSHADOW_XINLINE static double Map(DType a, double b) {
return ::floor(static_cast<double>(a) / b);
}
};
struct mixed_rfloor_divide {
template <typename DType, typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
return ::floor(b / static_cast<mshadow::half::half_t>(a));
}
template <typename DType,
typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
std::is_same<DType, mshadow::bfloat::bf16_t>::value ||
std::is_integral<DType>::value,
int>::type = 0>
MSHADOW_XINLINE static float Map(DType a, float b) {
return ::floorf(b / static_cast<float>(a));
}
template <
typename DType,
typename std::enable_if<
std::is_same<DType, mshadow::half::half_t>::value || std::is_same<DType, float>::value ||
std::is_same<DType, mshadow::bfloat::bf16_t>::value || std::is_integral<DType>::value,
int>::type = 0>
MSHADOW_XINLINE static double Map(DType a, double b) {
return ::floor(b / static_cast<double>(a));
}
};
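// The mixed_* functors below handle operands of different dtypes: the narrower
// operand (integral, half, or bfloat16) is promoted to the type of the wider
// floating operand before the arithmetic is applied.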
MXNET_BINARY_MATH_OP_NC(left, a);
MXNET_BINARY_MATH_OP_NC(right, b);
struct mixed_plus {
template <typename DType, typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
return static_cast<mshadow::half::half_t>(a) + b;
}
template <typename DType,
typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
std::is_same<DType, mshadow::bfloat::bf16_t>::value ||
std::is_integral<DType>::value,
int>::type = 0>
MSHADOW_XINLINE static float Map(DType a, float b) {
return static_cast<float>(a) + b;
}
template <
typename DType,
typename std::enable_if<
std::is_same<DType, mshadow::half::half_t>::value || std::is_same<DType, float>::value ||
std::is_same<DType, mshadow::bfloat::bf16_t>::value || std::is_integral<DType>::value,
int>::type = 0>
MSHADOW_XINLINE static double Map(DType a, double b) {
return static_cast<double>(a) + b;
}
};
struct mixed_minus {
template <typename DType, typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
return static_cast<mshadow::half::half_t>(a) - b;
}
template <typename DType,
typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
std::is_same<DType, mshadow::bfloat::bf16_t>::value ||
std::is_integral<DType>::value,
int>::type = 0>
MSHADOW_XINLINE static float Map(DType a, float b) {
return static_cast<float>(a) - b;
}
template <
typename DType,
typename std::enable_if<
std::is_same<DType, mshadow::half::half_t>::value || std::is_same<DType, float>::value ||
std::is_same<DType, mshadow::bfloat::bf16_t>::value || std::is_integral<DType>::value,
int>::type = 0>
MSHADOW_XINLINE static double Map(DType a, double b) {
return static_cast<double>(a) - b;
}
};
struct mixed_rminus {
template <typename DType, typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
return b - static_cast<mshadow::half::half_t>(a);
}
template <typename DType,
typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
std::is_same<DType, mshadow::bfloat::bf16_t>::value ||
std::is_integral<DType>::value,
int>::type = 0>
MSHADOW_XINLINE static float Map(DType a, float b) {
return b - static_cast<float>(a);
}
template <
typename DType,
typename std::enable_if<
std::is_same<DType, mshadow::half::half_t>::value || std::is_same<DType, float>::value ||
std::is_same<DType, mshadow::bfloat::bf16_t>::value || std::is_integral<DType>::value,
int>::type = 0>
MSHADOW_XINLINE static double Map(DType a, double b) {
return b - static_cast<double>(a);
}
};
struct mixed_mul {
template <typename DType, typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
return static_cast<mshadow::half::half_t>(a) * b;
}
template <typename DType,
typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
std::is_same<DType, mshadow::bfloat::bf16_t>::value ||
std::is_integral<DType>::value,
int>::type = 0>
MSHADOW_XINLINE static float Map(DType a, float b) {
return static_cast<float>(a) * b;
}
template <
typename DType,
typename std::enable_if<
std::is_same<DType, mshadow::half::half_t>::value || std::is_same<DType, float>::value ||
std::is_same<DType, mshadow::bfloat::bf16_t>::value || std::is_integral<DType>::value,
int>::type = 0>
MSHADOW_XINLINE static double Map(DType a, double b) {
return static_cast<double>(a) * b;
}
};
struct mixed_power {
template <typename DType, typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
return static_cast<mshadow::half::half_t>(math::pow(a, b));
}
template <typename DType,
typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
std::is_same<DType, mshadow::bfloat::bf16_t>::value ||
std::is_integral<DType>::value,
int>::type = 0>
MSHADOW_XINLINE static float Map(DType a, float b) {
return static_cast<float>(math::pow(a, b));
}
template <
typename DType,
typename std::enable_if<
std::is_same<DType, mshadow::half::half_t>::value || std::is_same<DType, float>::value ||
std::is_same<DType, mshadow::bfloat::bf16_t>::value || std::is_integral<DType>::value,
int>::type = 0>
MSHADOW_XINLINE static double Map(DType a, double b) {
return static_cast<double>(math::pow(a, b));
}
};
struct mixed_rpower {
template <typename DType, typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
return static_cast<mshadow::half::half_t>(math::pow(b, a));
}
template <typename DType,
typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
std::is_same<DType, mshadow::bfloat::bf16_t>::value ||
std::is_integral<DType>::value,
int>::type = 0>
MSHADOW_XINLINE static float Map(DType a, float b) {
return static_cast<float>(math::pow(b, a));
}
template <
typename DType,
typename std::enable_if<
std::is_same<DType, mshadow::half::half_t>::value || std::is_same<DType, float>::value ||
std::is_same<DType, mshadow::bfloat::bf16_t>::value || std::is_integral<DType>::value,
int>::type = 0>
MSHADOW_XINLINE static double Map(DType a, double b) {
return static_cast<double>(math::pow(b, a));
}
};
#pragma GCC diagnostic push
#if __GNUC__ >= 7
#pragma GCC diagnostic ignored "-Wint-in-bool-context"
#pragma GCC diagnostic ignored "-Wbool-compare"
#endif
MXNET_BINARY_MATH_OP_NC_WITH_BOOL(mul, a * b);
MXNET_BINARY_MATH_OP_NC_WITH_BOOL(div, a / b);
MXNET_BINARY_MATH_OP_NC_WITH_BOOL(plus, a + b);
MXNET_BINARY_MATH_OP_NC_WITH_BOOL(minus, a - b);
#pragma GCC diagnostic pop
MXNET_UNARY_MATH_OP(negation, -a);
MXNET_UNARY_MATH_OP(reciprocal, 1.0f / math::id(a));
struct bitwise_not : public mxnet_op::tunable {
template <typename DType,
typename std::enable_if<!std::is_same<DType, bool>::value, int>::type = 0>
MSHADOW_XINLINE static DType Map(DType a) {
return ~static_cast<int64_t>(a);
}
MSHADOW_XINLINE static bool Map(bool a) {
return !a;
}
};
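// For bool, bitwise_not degenerates to logical negation so that inverting a
// boolean tensor yields {0, 1} values rather than two's-complement ~0/~1.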
MXNET_UNARY_MATH_OP(reciprocal_grad, -1.0f / math::sqr(a));
MXNET_UNARY_MATH_OP(sigmoid, 1.0f / (1.0f + math::exp(-a)));
MXNET_UNARY_MATH_OP(sigmoid_grad, math::id(a) * (1.0f - math::id(a)));
MXNET_UNARY_MATH_OP(log_sigmoid, math::log(1.0f / (1.0f + math::exp(-a))));
MXNET_UNARY_MATH_OP(log_sigmoid_grad, 1.0f - math::exp(a));
struct mish : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
// softplus, cf. softrelu below; for a > 20, softplus(a) == a in float precision
auto softrelu = math::log1p(math::exp(a));
if (a > DType(20.0f)) {
softrelu = a;
}
return DType(a * math::tanh(softrelu));
}
};
struct mish_grad : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
// Note: the input (a) is x, not y.
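// mish(x) = x * tanh(softplus(x)); since softplus'(x) = sigmoid(x),
//   mish'(x) = tanh(sp) + x * sigmoid(x) * (1 - tanh(sp)^2),
// which is exactly what is assembled below.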
auto softrelu = math::log1p(math::exp(a));
if (a > DType(20.0f)) {
softrelu = a;
}
auto tanh_sr = math::tanh(softrelu);
auto sr_grad = 1.0f / (1.0f + math::exp(-a));
return DType(tanh_sr + a * sr_grad * (1.0f - tanh_sr * tanh_sr));
}
};
MXNET_UNARY_MATH_OP(softsign, a / (1.0f + math::fabs(a)));
MXNET_UNARY_MATH_OP(softsign_grad, 1.0f / math::sqr(1.0f + math::fabs(a)));
MXNET_UNARY_MATH_OP_NC(selu,
DType(SELU_LAMBDA) *
(a > DType(0) ? a : DType(math::id(SELU_ALPHA) * math::expm1(a))));
MXNET_UNARY_MATH_OP_NC(selu_grad,
DType(SELU_LAMBDA) * (a > DType(0) ? DType(1) : DType(SELU_ALPHA + a)));
MXNET_BINARY_MATH_OP_NC(prelu_grad, a > DType(0) ? DType(0) : a);
MXNET_BINARY_MATH_OP_NC(xelu,
a > DType(0) ? a : DType(static_cast<float>(a) * static_cast<float>(b)));
MXNET_BINARY_MATH_OP_NC(xelu_grad, a > DType(0) ? DType(1) : b);
MXNET_BINARY_MATH_OP_NC(elu, a > DType(0) ? a : DType(math::id(b) * math::expm1(a)));
MXNET_BINARY_MATH_OP_NC(elu_grad, a > DType(0) ? DType(1) : DType(b + a));
MXNET_SIMPLE_UNARY_MATH_OP(tanh);
MXNET_UNARY_MATH_OP(tanh_grad, 1.0f - math::sqr(a));
/*! \brief SoftReLU, also known as softplus activation */
struct softrelu : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
// Avoid overflow of exp for large inputs.
// The threshold 20.0 is chosen so that softrelu(a) == a for a > 20 in single
// precision: log1p(exp(-20)) ~= 2e-9, which is below the float32 ulp of 20.
if (a > DType(20.0f)) {
return a;
} else {
return DType(math::log1p(math::exp(a)));
}
}
};
MXNET_UNARY_MATH_OP(softrelu_grad, -math::expm1(-a));
MXNET_UNARY_MATH_OP(erfinv_grad, 0.5 * math::sqrt(PI) * math::exp(math::sqr(a)));
MXNET_UNARY_MATH_OP(erf_grad, 2.0 / math::sqrt(PI) * math::exp(-(a * a)));
MXNET_SIMPLE_UNARY_MATH_OP(erf);
MXNET_UNARY_MATH_OP(gelu_erf,
DType(0.5f * static_cast<float>(a) *
(1.0f + math::erf(static_cast<float>(a) / SQRT_2))));
MXNET_BINARY_MATH_OP_NC(gelu_erf_grad,
DType(static_cast<float>(b) / static_cast<float>(a) +
0.5f * static_cast<float>(a) *
erf_grad::Map(static_cast<float>(a) / SQRT_2) / SQRT_2));
MXNET_UNARY_MATH_OP(gelu_tanh,
DType(0.5f * static_cast<float>(a) *
(1.0f +
math::tanh(math::sqrt(2.0f / PI) *
(static_cast<float>(a) +
GELU_TANH_CONST * math::pow(static_cast<float>(a), 3))))));
MXNET_BINARY_MATH_OP_NC(
gelu_tanh_grad,
DType(static_cast<float>(b) *
(1.0f / static_cast<float>(a) +
(1.0f - math::tanh(math::sqrt(2.0f / PI) *
(static_cast<float>(a) +
GELU_TANH_CONST * math::pow(static_cast<float>(a), 3)))) *
(math::sqrt(2.0f / PI) *
(1.0f + 3.0f * GELU_TANH_CONST * math::pow(static_cast<float>(a), 2))))));
MXNET_SIMPLE_UNARY_MATH_OP(exp);
MXNET_SIMPLE_UNARY_MATH_OP(expm1);
MXNET_SIMPLE_UNARY_MATH_OP(log);
MXNET_UNARY_MATH_OP(log_grad, 1.0f / math::id(a));
MXNET_SIMPLE_UNARY_MATH_OP(log10);
// Constant is 1 / log(10)
struct log10_grad : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(0.4342944819f / static_cast<float>(a));
}
};
template <>
MSHADOW_XINLINE double log10_grad::Map<double>(double a) {
return 0.43429448190325182765 / a;
}
MXNET_SIMPLE_UNARY_MATH_OP(log2);
// Constant is 1 / log(2)
struct log2_grad : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(1.442695041f / static_cast<float>(a));
}
};
template <>
MSHADOW_XINLINE double log2_grad::Map<double>(double a) {
return 1.44269504088896340737 / a;
}
MXNET_SIMPLE_UNARY_MATH_OP(sin);
MXNET_UNARY_MATH_OP(sin_grad, math::cos(a));
MXNET_SIMPLE_UNARY_MATH_OP(log1p);
MXNET_UNARY_MATH_OP(log1p_grad, 1.0f / (1.0f + math::id(a)));
MXNET_SIMPLE_UNARY_MATH_OP(cos);
MXNET_UNARY_MATH_OP(cos_grad, -math::sin(a));
MXNET_SIMPLE_UNARY_MATH_OP(tan);
MXNET_UNARY_MATH_OP(tan_grad, math::sqr(a) + 1.0f);
MXNET_UNARY_MATH_OP(arcsin, math::asin(a));
MXNET_UNARY_MATH_OP(arcsin_grad, 1.0f / math::sqrt(1.0f - math::sqr(a)));
MXNET_UNARY_MATH_OP(arccos, math::acos(a));
MXNET_UNARY_MATH_OP(arccos_grad, -1.0f / math::sqrt(1.0f - math::sqr(a)));
MXNET_UNARY_MATH_OP(arctan, math::atan(a));
MXNET_UNARY_MATH_OP(arctan_grad, 1.0f / (math::sqr(a) + 1.0f));
MXNET_SIMPLE_BINARY_MATH_OP(hypot);
MXNET_BINARY_MATH_OP(hypot_grad_left, math::id(a) / math::hypot(a, b));
MXNET_BINARY_MATH_OP(hypot_grad_right, math::id(b) / math::hypot(a, b));
MXNET_UNARY_MATH_OP(degrees, 180.0f / PI * math::id(a));
MXNET_UNARY_MATH_OP(degrees_grad, 180.0f / PI);
MXNET_UNARY_MATH_OP(radians, PI / 180.0f * math::id(a));
MXNET_UNARY_MATH_OP(radians_grad, PI / 180.0f);
MXNET_SIMPLE_UNARY_MATH_OP(sinh);
MXNET_UNARY_MATH_OP(sinh_grad, math::cosh(a));
MXNET_SIMPLE_UNARY_MATH_OP(cosh);
MXNET_UNARY_MATH_OP(cosh_grad, math::sinh(a));
MXNET_UNARY_MATH_OP(arcsinh, math::asinh(a));
MXNET_UNARY_MATH_OP(arcsinh_grad, 1.0f / math::hypot(a, DType(1)));
MXNET_UNARY_MATH_OP(arccosh, math::acosh(a));
MXNET_UNARY_MATH_OP(arccosh_grad, 1.0f / math::sqrt(math::sqr(a) - 1.0f));
MXNET_UNARY_MATH_OP(arctanh, math::atanh(a));
MXNET_UNARY_MATH_OP(arctanh_grad, 1.0f / (1.0f - math::sqr(a)));
MXNET_UNARY_MATH_OP(square, math::sqr(a));
MXNET_UNARY_MATH_OP(square_grad, 2.0f * math::id(a));
/*! \brief used for generate Bernoulli mask */
MXNET_BINARY_MATH_OP_NC(threshold, a < b ? DType(1) : DType(0));
MXNET_BINARY_MATH_OP_NC(threshold_eq, a <= b ? DType(1) : DType(0));
/*! \brief used for generate element of abs */
MXNET_UNARY_MATH_OP(abs, math::fabs(a)); // NOLINT(*)
/*! \brief used for generate element of sign */
struct sign : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static typename enable_if<!is_unsigned<DType>::value, DType>::type Map(DType a) {
if (a < DType(0))
return DType(-DType(1));
if (a > DType(0))
return DType(1);
return DType(0);
}
template <typename DType>
MSHADOW_XINLINE static typename enable_if<is_unsigned<DType>::value, DType>::type Map(DType a) {
if (a > DType(0))
return DType(1);
return DType(0);
}
};
MXNET_UNARY_MATH_OP_NC(sign_grad, DType(0));
/*! \brief used for generate element of power */
MXNET_BINARY_MATH_OP(power, math::pow(a, b));
MXNET_BINARY_MATH_OP(power_grad, math::pow(a, b - DType(1)) * math::id(b));
MXNET_BINARY_MATH_OP(power_rgrad, math::pow(a, b) * math::log(a));
MXNET_BINARY_MATH_OP(rpower, math::pow(b, a));
MXNET_BINARY_MATH_OP(rpower_grad, math::id(a) * math::log(b));
MXNET_BINARY_MATH_OP(arctan2, math::atan2(a, b));
MXNET_BINARY_MATH_OP(arctan2_grad, math::id(b) / (math::id(a * a + b * b)));
MXNET_BINARY_MATH_OP(arctan2_rgrad, -math::id(a) / (math::id(a * a + b * b)));
MXNET_BINARY_MATH_OP(rarctan2, math::atan2(b, a));
MXNET_BINARY_MATH_OP(rarctan2_grad, math::id(a) / (math::id(a * a + b * b)));
MXNET_UNARY_MATH_OP_NC(nt, a != DType(0) ? DType(0) : DType(1));
MXNET_UNARY_LOGIC_OP_NC(np_logical_not, !static_cast<bool>(a));
MXNET_BINARY_MATH_OP_NC(ge, a >= b ? DType(1) : DType(0));
MXNET_BINARY_MATH_OP_NC(gt, a > b ? DType(1) : DType(0));
MXNET_BINARY_MATH_OP_NC(lt, a < b ? DType(1) : DType(0));
MXNET_BINARY_MATH_OP_NC(le, a <= b ? DType(1) : DType(0));
MXNET_BINARY_MATH_OP_NC(eq, a == b ? DType(1) : DType(0));
MXNET_BINARY_MATH_OP_NC(ne, a != b ? DType(1) : DType(0));
MXNET_BINARY_LOGIC_OP_NC(np_greater_equal, a >= b ? true : false);
MXNET_BINARY_LOGIC_OP_NC(np_greater, a > b ? true : false);
MXNET_BINARY_LOGIC_OP_NC(np_less, a < b ? true : false);
MXNET_BINARY_LOGIC_OP_NC(np_less_equal, a <= b ? true : false);
MXNET_BINARY_LOGIC_OP_NC(np_equal, a == b ? true : false);
MXNET_BINARY_LOGIC_OP_NC(np_not_equal, a != b ? true : false);
MXNET_BINARY_LOGIC_OP_NC(np_logical_and, a && b ? true : false);
MXNET_BINARY_LOGIC_OP_NC(np_logical_or, a || b ? true : false);
MXNET_BINARY_LOGIC_OP_NC(np_logical_xor, (a || b) && !(a && b) ? true : false);
MXNET_BINARY_MATH_OP(logical_and, a && b ? DType(1) : DType(0));
MXNET_BINARY_MATH_OP(logical_or, a || b ? DType(1) : DType(0));
MXNET_BINARY_MATH_OP(logical_xor, (a || b) && !(a && b) ? DType(1) : DType(0));
MXNET_BINARY_MATH_OP(bitwise_and, static_cast<int64_t>(a) & static_cast<int64_t>(b));
MXNET_BINARY_MATH_OP(bitwise_xor, static_cast<int64_t>(a) ^ static_cast<int64_t>(b));
MXNET_BINARY_MATH_OP(bitwise_or, static_cast<int64_t>(a) | static_cast<int64_t>(b));
#pragma GCC diagnostic push
#if __GNUC__ >= 7
#pragma GCC diagnostic ignored "-Wint-in-bool-context"
#pragma GCC diagnostic ignored "-Wbool-compare"
#endif
/*! \brief used for generate element of bitwise_left_shift */
struct bitwise_left_shift : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
if (static_cast<uint64_t>(b) >= (sizeof(DType) * CHAR_BIT)) {
return DType(0);
}
return static_cast<int64_t>(a) << static_cast<int64_t>(b);
}
};
MXNET_BINARY_MATH_OP(bitwise_left_shift_grad, math::pow(2.0f, static_cast<int64_t>(b)));
MXNET_BINARY_MATH_OP(bitwise_left_shift_rgrad,
static_cast<int64_t>(a) * math::pow(2.0f, static_cast<int64_t>(b)) *
math::log(2.0f));
MXNET_BINARY_MATH_OP(rbitwise_left_shift, static_cast<int64_t>(b) << static_cast<int64_t>(a));
MXNET_BINARY_MATH_OP(rbitwise_left_shift_grad,
static_cast<int64_t>(b) * math::pow(2.0f, static_cast<int64_t>(a)) *
math::log(2.0f));
/*! \brief used for generate element of bitwise_right_shift */
struct bitwise_right_shift : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
if (static_cast<uint64_t>(b) >= (sizeof(DType) * CHAR_BIT)) {
if (a < 0) {
return DType(-1);
} else {
return DType(0);
}
}
if constexpr (std::is_integral<DType>::value)
return a >> b;
else
return static_cast<int64_t>(a) >> static_cast<int64_t>(b);
}
};
MXNET_BINARY_MATH_OP(bitwise_right_shift_grad, math::pow(0.5f, static_cast<int64_t>(b)));
MXNET_BINARY_MATH_OP(bitwise_right_shift_rgrad,
static_cast<int64_t>(a) * math::pow(0.5f, static_cast<int64_t>(b)) *
math::log(0.5f));
MXNET_BINARY_MATH_OP(rbitwise_right_shift, static_cast<int64_t>(b) >> static_cast<int64_t>(a));
MXNET_BINARY_MATH_OP(rbitwise_right_shift_grad,
static_cast<int64_t>(b) * math::pow(0.5f, static_cast<int64_t>(a)) *
math::log(0.5f));
#pragma GCC diagnostic pop
MXNET_UNARY_MATH_OP(square_root, math::sqrt(a));
MXNET_UNARY_MATH_OP(square_root_grad, 0.5f / math::id(a));
MXNET_UNARY_MATH_OP(reciprocal_square_root, 1.0f / math::sqrt(a));
MXNET_UNARY_MATH_OP(reciprocal_square_root_grad, -0.5f / (math::sqrt(a) * math::id(a)));
MXNET_UNARY_MATH_OP(cube_root, math::cbrt(a));
MXNET_UNARY_MATH_OP(cube_root_grad, 1.0f / (3.0f * math::sqr(a)));
MXNET_UNARY_MATH_OP(reciprocal_cube_root, 1.0f / math::cbrt(a));
MXNET_UNARY_MATH_OP(reciprocal_cube_root_grad, -1.0f / (3.0f * math::cbrt(a) * math::id(a)));
/*! \brief used for generate element of ldexp */
MXNET_BINARY_MATH_OP(ldexp, math::id(a) * math::pow(2.0f, b));
MXNET_BINARY_MATH_OP(ldexp_grad, math::pow(2.0f, b));
MXNET_BINARY_MATH_OP(ldexp_rgrad, math::id(a) * math::pow(2.0f, b) * math::log(2.0f));
MXNET_BINARY_MATH_OP(rldexp, math::id(b) * math::pow(2.0f, a)); // swap a and b if a is scalar.
MXNET_BINARY_MATH_OP(rldexp_grad, math::id(b) * math::pow(2.0f, a) * math::log(2.0f));
/*! \brief used for generate element of logaddexp */
MXNET_BINARY_MATH_OP(logaddexp, math::log(math::exp(a) + math::exp(b)));
MXNET_BINARY_MATH_OP(logaddexp_grad, math::exp(a) / (math::exp(a) + math::exp(b)));
MXNET_BINARY_MATH_OP(logaddexp_rgrad, math::exp(b) / (math::exp(a) + math::exp(b)));
/*! \brief used for generate element of round */
MXNET_SIMPLE_UNARY_MATH_OP(round);
/*! \brief used for generate element of ceil */
MXNET_SIMPLE_UNARY_MATH_OP(ceil);
/*! \brief used for generate element of floor */
MXNET_SIMPLE_UNARY_MATH_OP(floor);
/*! \brief used to round towards zero */
MXNET_SIMPLE_UNARY_MATH_OP(trunc);
/*! \brief used to round number to nearest integer */
struct rint : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
auto floor = math::floor(a);
auto ceil = math::ceil(a);
auto af = math::id(a);
return DType((af - floor) <= (ceil - af) ? floor : ceil);
}
};
/*! \brief used to round number to integer nearest to 0 */
struct fix : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
auto floor = math::floor(a);
auto ceil = math::ceil(a);
return DType((floor > 0 ? floor : -floor) < (ceil > 0 ? ceil : -ceil) ? floor : ceil);
}
};
#pragma GCC diagnostic push
#if __GNUC__ >= 7
#pragma GCC diagnostic ignored "-Wbool-compare"
#endif
/*! \brief used to determine whether a number is Not A Number*/
struct isnan : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static bool Map(DType a) {
return IsNan(a);
}
};
/*! \brief used to determine whether a number is infinite*/
struct isinf : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static bool Map(DType a) {
return IsInf(a);
}
};
/*! \brief used to determine whether a number is finite*/
struct isfinite : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static bool Map(DType a) {
return !IsNan(a) && !IsInf(a);
}
};
/*! \brief used to determine whether a number is positive infinity*/
struct isposinf : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static bool Map(DType a) {
return IsInf(a) && a > 0;
}
};
/*! \brief used to determine whether a number is negative infinity*/
struct isneginf : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static bool Map(DType a) {
return IsInf(a) && a < 0;
}
};
#pragma GCC diagnostic pop
/*! \brief used for generate gradient of MAE loss*/
MXNET_BINARY_MATH_OP_NC(minus_sign, a - b > DType(0) ? DType(1) : -DType(1));
MXNET_BINARY_MATH_OP(rminus, b - a);
MXNET_BINARY_MATH_OP_NC(posone, 1);
MXNET_BINARY_MATH_OP_NC(negone, -1);
MXNET_BINARY_MATH_OP(div_grad, 1.0f / math::id(b));
MXNET_BINARY_MATH_OP(div_rgrad, -math::id(a) / math::sqr(b));
MXNET_BINARY_MATH_OP(rdiv, math::id(b) / math::id(a));
MXNET_BINARY_MATH_OP(rdiv_grad, -math::id(b) / math::sqr(a));
MXNET_BINARY_MATH_OP(copysign, (a >= 0 && b >= 0) || (a < 0 && b < 0) ? a : -a);
MXNET_BINARY_MATH_OP(copysign_grad, (a >= 0 && b >= 0) || (a < 0 && b < 0) ? 1 : -1);
MXNET_BINARY_MATH_OP(copysign_rgrad, 0);
MXNET_BINARY_MATH_OP(rcopysign, (b >= 0 && a >= 0) || (b < 0 && a < 0) ? b : -b);
struct mod : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static typename enable_if<!is_unsigned<DType>::value, DType>::type Map(DType a,
DType b) {
if (b == DType(0)) {
return DType(0);
} else if (b < DType(0)) {
if (a < DType(0)) {
return DType(-::fmod(-static_cast<double>(a), -static_cast<double>(b)));
} else if (a == DType(0)) {
return -DType(0);
} else {
DType ret = DType(
::fmod(static_cast<double>(a), -static_cast<double>(b)) +
(::fmod(static_cast<double>(a), -static_cast<double>(b)) != DType(0) ? b : DType(0)));
if (ret == 0) {
return -ret;
}
return ret;
}
} else {
if (a < DType(0)) {
return DType(
-::fmod(-static_cast<double>(a), static_cast<double>(b)) +
(::fmod(-static_cast<double>(a), static_cast<double>(b)) != DType(0) ? b : DType(0)));
} else {
return DType(::fmod(static_cast<double>(a), static_cast<double>(b)));
}
}
}
template <typename DType>
MSHADOW_XINLINE static typename enable_if<is_unsigned<DType>::value, DType>::type Map(DType a,
DType b) {
if (b == DType(0)) {
return DType(0);
} else {
return DType(::fmod(static_cast<double>(a), static_cast<double>(b)));
}
}
};
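// mod follows Python/NumPy semantics: the result takes the sign of the
// divisor b. It is built from C's ::fmod (whose result takes the sign of the
// dividend) by adding b back when the signs disagree, e.g.
//   mod::Map(-5.0, 3.0) == 1.0   while   ::fmod(-5.0, 3.0) == -2.0.
// mod(a, 0) is defined as 0 rather than NaN.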
struct mixed_mod {
template <typename DType, typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
return mod::Map(static_cast<mshadow::half::half_t>(a), b);
}
template <typename DType,
typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
std::is_same<DType, mshadow::bfloat::bf16_t>::value ||
std::is_integral<DType>::value,
int>::type = 0>
MSHADOW_XINLINE static float Map(DType a, float b) {
return mod::Map(static_cast<float>(a), b);
}
template <
typename DType,
typename std::enable_if<
std::is_same<DType, mshadow::half::half_t>::value || std::is_same<DType, float>::value ||
std::is_same<DType, mshadow::bfloat::bf16_t>::value || std::is_integral<DType>::value,
int>::type = 0>
MSHADOW_XINLINE static double Map(DType a, double b) {
return mod::Map(static_cast<double>(a), b);
}
};
struct mixed_rmod {
template <typename DType, typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
return mod::Map(b, static_cast<mshadow::half::half_t>(a));
}
template <typename DType,
typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
std::is_same<DType, mshadow::bfloat::bf16_t>::value ||
std::is_integral<DType>::value,
int>::type = 0>
MSHADOW_XINLINE static float Map(DType a, float b) {
return mod::Map(b, static_cast<float>(a));
}
template <
typename DType,
typename std::enable_if<
std::is_same<DType, mshadow::half::half_t>::value || std::is_same<DType, float>::value ||
std::is_same<DType, mshadow::bfloat::bf16_t>::value || std::is_integral<DType>::value,
int>::type = 0>
MSHADOW_XINLINE static double Map(DType a, double b) {
return mod::Map(b, static_cast<double>(a));
}
};
struct fmod : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
if (b == DType(0)) {
return DType(0);
} else {
return DType(::fmod(static_cast<double>(a), static_cast<double>(b)));
}
}
};
struct rfmod : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
if (a == DType(0)) {
return DType(0);
} else {
return DType(::fmod(static_cast<double>(b), static_cast<double>(a)));
}
}
};
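// fmod/rfmod keep C semantics (the result takes the sign of the dividend)
// and, like mod above, define a zero divisor to yield 0 instead of NaN.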
struct mod_grad : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return DType(0);
}
};
template <>
MSHADOW_XINLINE double mod_grad::Map<double>(double a, double b) {
return 1.0;
}
template <>
MSHADOW_XINLINE float mod_grad::Map<float>(float a, float b) {
return 1.0f;
}
template <>
MSHADOW_XINLINE mshadow::half::half_t mod_grad::Map<mshadow::half::half_t>(
mshadow::half::half_t a,
mshadow::half::half_t b) {
return mshadow::half::half_t(1.0f);
}
struct mod_rgrad : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return DType(0);
}
};
template <>
MSHADOW_XINLINE double mod_rgrad::Map<double>(double a, double b) {
return -::floor(a / b);
}
template <>
MSHADOW_XINLINE float mod_rgrad::Map<float>(float a, float b) {
return -::floorf(a / b);
}
template <>
MSHADOW_XINLINE mshadow::half::half_t mod_rgrad::Map<mshadow::half::half_t>(
mshadow::half::half_t a,
mshadow::half::half_t b) {
return mshadow::half::half_t(-::floorf(static_cast<float>(a) / static_cast<float>(b)));
}
struct rmod : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static typename enable_if<!is_unsigned<DType>::value, DType>::type Map(DType a,
DType b) {
if (a == DType(0)) {
return DType(0);
} else if (a < DType(0)) {
if (b < DType(0)) {
return DType(-::fmod(-static_cast<double>(b), -static_cast<double>(a)));
} else {
return DType(
::fmod(static_cast<double>(b), -static_cast<double>(a)) +
(::fmod(static_cast<double>(b), -static_cast<double>(a)) != DType(0) ? a : DType(0)));
}
} else {
if (b < DType(0)) {
return DType(
-::fmod(-static_cast<double>(b), static_cast<double>(a)) +
(::fmod(-static_cast<double>(b), static_cast<double>(a)) != DType(0) ? a : DType(0)));
} else {
return DType(::fmod(static_cast<double>(b), static_cast<double>(a)));
}
}
}
template <typename DType>
MSHADOW_XINLINE static typename enable_if<is_unsigned<DType>::value, DType>::type Map(DType a,
DType b) {
if (a == DType(0)) {
return DType(0);
} else {
return DType(::fmod(static_cast<double>(b), static_cast<double>(a)));
}
}
};
struct rmod_grad {
template <typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return DType(0);
}
};
template <>
MSHADOW_XINLINE double rmod_grad::Map<double>(double a, double b) {
return -::floor(b / a);
}
template <>
MSHADOW_XINLINE float rmod_grad::Map<float>(float a, float b) {
return -::floorf(b / a);
}
template <>
MSHADOW_XINLINE mshadow::half::half_t rmod_grad::Map<mshadow::half::half_t>(
mshadow::half::half_t a,
mshadow::half::half_t b) {
return mshadow::half::half_t(-::floorf(static_cast<float>(b) / static_cast<float>(a)));
}
struct clip : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static DType Map(DType x, DType bound) {
if (x > bound) {
return bound;
} else if (x < -bound) {
return -bound;
} else {
return x;
}
}
template <typename DType>
MSHADOW_XINLINE static DType Map(DType x, DType lower_bound, DType upper_bound) {
if (x > upper_bound) {
return upper_bound;
} else if (x < lower_bound) {
return lower_bound;
}
return x;
}
};
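// The two-argument form clips symmetrically to [-bound, bound]; the
// three-argument form clips to [lower_bound, upper_bound].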
/***** gamma ******/
MXNET_UNARY_MATH_OP(gamma, math::tgamma(a));
struct gamma_grad : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
// default implementation using floating precision
float af(static_cast<float>(a));
return DType(math::tgamma(af) * special_functions::cephes::psi<float>(af));
}
};
template <>
MSHADOW_XINLINE double gamma_grad::Map<double>(double a) {
return math::tgamma(a) * special_functions::cephes::psi<double>(a);
}
/***** gammaln ******/
MXNET_UNARY_MATH_OP(gammaln, math::lgamma(a));
struct gammaln_grad : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
// default implementation using floating precision
return DType(special_functions::cephes::psi<float>(a));
}
};
template <>
MSHADOW_XINLINE double gammaln_grad::Map<double>(double a) {
return special_functions::cephes::psi<double>(a);
}
/***** digamma ******/
struct digamma : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
// default implementation using floating precision
return DType(special_functions::cephes::psi<float>(a));
}
};
template <>
MSHADOW_XINLINE double digamma::Map<double>(double a) {
return special_functions::cephes::psi<double>(a);
}
/***** trigamma ******/
struct trigamma : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
// default implementation using floating precision
return DType(special_functions::trigamma<float>(a));
}
};
template <>
MSHADOW_XINLINE double trigamma::Map<double>(double a) {
return special_functions::trigamma<double>(a);
}
/* Smooth L1 Loss is a loss specific for R-CNN franchise training
* Smooth L1 Loss function:
* f(x) = 0.5 * (sigma * x) ^ 2, |x| < 1 / sigma^2
* = |x| - 0.5 / sigma / sigma, otherwise
* When sigma = 1, it is equivalent to the Huber loss, evaluated at
* delta = 1.
* smooth_l1_loss = w_out * f(w_in * x)
* with w_in, w_out provided by input_data.
*/
struct smooth_l1_loss : public mxnet_op::tunable {
// a is x, b is sigma
template <typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
auto bsq = math::sqr(b);
auto ibsq = 1.0f / bsq;
auto af = math::id(a);
if (af > ibsq) {
return DType(af - 0.5f * ibsq);
} else if (af < -ibsq) {
return DType(-af - 0.5f * ibsq);
} else {
return DType(0.5f * af * af * bsq);
}
}
}; // struct smooth_l1_loss
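// Example with sigma = 1 (the Huber case): Map(0.5f, 1.0f) == 0.125f falls in
// the quadratic zone, while Map(2.0f, 1.0f) == 1.5f falls in the linear zone.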
/* The derivative of smooth l1 loss is
* f'(x) = sigma^2 * x, |x| < 1 / sigma^2
* = sign(x), otherwise
*/
struct smooth_l1_gradient : public mxnet_op::tunable {
// a is x, b is sigma
template <typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
auto bsq = math::sqr(b);
auto ibsq = 1.0f / bsq;
auto af = math::id(a);
if (af > ibsq) {
return DType(1);
} else if (af < -ibsq) {
return DType(-1);
} else {
return DType(bsq * af);
}
}
}; // struct smooth_l1_gradient
/* Implicit reparameterization gradient for standard x ~ Gamma(\alpha, 1),
 * according to dx/da = -cdf(x; alpha) / pdf(x; alpha)
 */
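/* The three branches below follow the usual piecewise approximation for the
 * standard-gamma gradient: a Taylor expansion of cdf/pdf for small x (< 0.8),
 * an asymptotic expansion for large shape (a > 8), and a rational
 * approximation in u = log(x/a), v = log(a) everywhere else.
 */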
struct gamma_implicit_grad {
template <typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType x) {
if (x < 0.8f) {
DType numer = 1;
DType denom = a;
DType series1 = numer / denom;
DType series2 = numer / (denom * denom);
for (int i = 1; i <= 5; i++) {
numer *= -x / static_cast<DType>(i);
denom += 1;
series1 += numer / denom;
series2 += numer / (denom * denom);
}
DType pow_x_alpha = math::pow(x, a);
DType gamma_pdf = math::pow(x, a - 1) * math::exp(-x);
DType gamma_cdf = pow_x_alpha * series1;
DType gamma_cdf_alpha =
(math::log(x) - DType(special_functions::cephes::psi<float>(a))) * gamma_cdf -
pow_x_alpha * series2;
DType result = -gamma_cdf_alpha / gamma_pdf;
return IsNan(result) ? static_cast<DType>(0.f) : static_cast<DType>(result);
}
if (a > 8.0f) {
if (0.9f * a <= x && x <= 1.1f * a) {
DType numer_1 = 1 + 24 * a * (1 + 12 * a);
DType numer_2 =
1440 * (a * a) + 6 * x * (53 - 120 * x) - 65 * x * x / a + a * (107 + 3600 * x);
DType denom = 1244160 * (a * a) * (a * a);
return static_cast<DType>(numer_1 * numer_2 / denom);
}
DType denom = math::sqrt(8 * a);
DType term2 = denom / (a - x);
DType term3 = math::pow(x - a - a * math::log(x / a), static_cast<DType>(-1.5));
DType term23 = (x < a) ? term2 - term3 : term2 + term3;
DType term1 = math::log(x / a) * term23 - math::sqrt(2 / a) * (a + x) / ((a - x) * (a - x));
DType stirling = 1 + 1 / (12 * a) * (1 + 1 / (24 * a));
DType numer = x * term1;
return static_cast<DType>(-stirling * numer / denom);
}
DType u = math::log(x / a);
DType v = math::log(a);
DType coef_uv[3][8] = {
{0.16009398,
-0.094634809,
0.025146376,
-0.0030648343,
1,
0.32668115,
0.10406089,
0.0014179084},
{0.53487893,
0.1298071,
0.065735949,
-0.0015649758,
0.16639465,
0.020070113,
-0.0035938915,
-0.00058392623},
{0.040121004,
-0.0065914022,
-0.0026286047,
-0.0013441777,
0.017050642,
-0.0021309326,
0.00085092367,
-1.5247877e-07},
};
DType coef_v[8];
for (int i = 0; i < 8; i++) {
coef_v[i] = coef_uv[0][i] + u * (coef_uv[1][i] + u * coef_uv[2][i]);
}
DType p = coef_v[0] + v * (coef_v[1] + v * (coef_v[2] + v * coef_v[3]));
DType q = coef_v[4] + v * (coef_v[5] + v * (coef_v[6] + v * coef_v[7]));
return static_cast<DType>(math::exp(p / q));
}
}; // gamma_implicit_grad
/*! \brief product reducer */
struct product {
/*! \brief do reduction into dst */
template <typename DType>
MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
dst *= src;
}
/*! \brief do reduction into dst */
template <typename DType>
MSHADOW_XINLINE static void Reduce(volatile DType& dst, // NOLINT(*)
volatile DType src,
volatile DType& none) { // NOLINT(*)
Reduce(dst, src);
}
/*! \brief combine the results of two reducers */
template <typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_val, // NOLINT(*)
volatile DType& src_val) { // NOLINT(*)
Reduce(dst_val, src_val);
}
/*! \brief combine the results of two reducers */
template <typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_val, // NOLINT(*)
volatile DType& dst_residual, // NOLINT(*)
volatile DType& src_val, // NOLINT(*)
volatile DType& src_residual) { // NOLINT(*)
Reduce(dst_val, src_val);
}
/*! \brief finalize reduction */
template <typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
/*! \brief finalize reduction */
template <typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& none) {} // NOLINT(*)
/*!
*\brief calculate gradient of redres with respect to redsrc,
* redres: reduced result, redsrc: one of reduction element
*/
template <typename DType>
MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
return redres / redsrc;
}
/*!
*\brief set the initial value during reduction
*/
template <typename DType>
MSHADOW_XINLINE static void SetInitValue(DType& initv) { // NOLINT(*)
initv = 1;
}
/*!
*\brief set the initial value during reduction
*/
template <typename DType>
MSHADOW_XINLINE static void SetInitValue(DType& initv, DType& none) { // NOLINT(*)
SetInitValue(initv);
}
};
MXNET_UNARY_MATH_OP_NC(relu, IsNan(a) || (a > DType(0)) ? a : DType(0));
/*! \brief used for computing gradient of relu operator */
struct relu_grad : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
if (IsNan(a)) {
return a;
} else {
return a > DType(0) ? DType(1) : DType(0);
}
}
};
/*! \brief used for computing binary operator maximum */
struct maximum : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
if (IsNan(a)) {
return a;
} else {
return (a > b ? a : b);
}
}
};
/*! \brief used for computing binary operator fmax */
struct fmax : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
if (IsNan(b)) {
return a;
} else if (IsNan(a)) {
return b;
} else {
return (a > b ? a : b);
}
}
};
/*! \brief used for computing binary operator minimum */
struct minimum : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
if (IsNan(a)) {
return a;
} else {
return DType(a < b ? a : b);
}
}
};
/*! \brief used for computing binary operator fmin */
struct fmin : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
if (IsNan(b)) {
return a;
} else if (IsNan(a)) {
return b;
} else {
return (a < b ? a : b);
}
}
};
/*! \brief boolean any/all kernel that determines whether elem is NonZero */
struct NonZero {
template <typename DType>
MSHADOW_XINLINE static bool Map(DType a) {
return (a != DType(0));
}
};
/*! \brief sum reducer that ignores NaN values in the input */
struct nansum {
/*! \brief do reduction into dst */
template <typename DType>
MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
if (IsNan(src))
return;
dst += src;
}
/*! \brief do reduction into dst */
template <typename DType>
MSHADOW_XINLINE static void Reduce(volatile DType& dst, // NOLINT(*)
volatile DType src, // NOLINT(*)
volatile DType& residual) { // NOLINT(*)
if (IsNan(src))
return;
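// Kahan compensated summation: 'residual' accumulates the rounding error
// so that it can be folded back into subsequent additions.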
DType y = src - residual;
DType t = dst + y;
residual = (t - dst) - y;
dst = t;
}
/*! \brief combine the results of two reducers */
template <typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_val, // NOLINT(*)
volatile DType& src_val) { // NOLINT(*)
Reduce(dst_val, src_val);
}
/*! \brief combine the results of two reducers */
template <typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_val, // NOLINT(*)
volatile DType& dst_residual, // NOLINT(*)
volatile DType& src_val, // NOLINT(*)
volatile DType& src_residual) { // NOLINT(*)
DType t1 = dst_val + src_val;
DType e = t1 - dst_val;
DType t2 = ((src_val - e) + (dst_val - (t1 - e))) + dst_residual + src_residual;
dst_val = t1 + t2;
dst_residual = t2 - (dst_val - t1);
}
/*! \brief finalize reduction */
template <typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
/*! \brief finalize reduction */
template <typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& dst, // NOLINT(*)
volatile DType& residual) { // NOLINT(*)
}
/*!
*\brief set the initial value during reduction
*/
template <typename DType>
MSHADOW_XINLINE static void SetInitValue(DType& initv) { // NOLINT(*)
initv = 0;
}
/*!
*\brief set the initial value during reduction
*/
template <typename DType>
MSHADOW_XINLINE static void SetInitValue(DType& initv, DType& residual) { // NOLINT(*)
SetInitValue(initv);
residual = 0;
}
};
struct nansum_grad : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return IsNan(a) ? DType(0) : DType(1);
}
};
/*! \brief product reducer that ignores NaN values in the input */
struct nanprod {
/*! \brief do reduction into dst */
template <typename DType>
MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
if (IsNan(src))
return;
dst *= src;
}
/*! \brief do reduction into dst */
template <typename DType>
MSHADOW_XINLINE static void Reduce(volatile DType& dst, // NOLINT(*)
volatile DType src, // NOLINT(*)
volatile DType& none) { // NOLINT(*)
Reduce(dst, src);
}
/*! \brief combine the results of two reducers */
template <typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_val, // NOLINT(*)
volatile DType& src_val) { // NOLINT(*)
Reduce(dst_val, src_val);
}
/*! \brief combine the results of two reducers */
template <typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_val, // NOLINT(*)
volatile DType& dst_residual, // NOLINT(*)
volatile DType& src_val, // NOLINT(*)
volatile DType& src_residual) { // NOLINT(*)
Reduce(dst_val, src_val);
}
/*! \brief finalize reduction */
template <typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
/*! \brief finalize reduction */
template <typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& none) {} // NOLINT(*)
/*!
*\brief set the initial value during reduction
*/
template <typename DType>
MSHADOW_XINLINE static void SetInitValue(DType& initv) { // NOLINT(*)
initv = 1;
}
/*!
*\brief set the initial value during reduction
*/
template <typename DType>
MSHADOW_XINLINE static void SetInitValue(DType& initv, DType& none) { // NOLINT(*)
SetInitValue(initv);
}
};
/*! \brief compute l2 norm */
struct nrm2 {
/*! \brief do reduction into dst */
template <typename AType, typename DType>
MSHADOW_XINLINE static void Reduce(volatile AType& sum_of_squares, // NOLINT(*)
volatile DType src) { // NOLINT(*)
sum_of_squares += src * src;
}
/*! \brief do stable reduction into dst */
template <typename AType, typename DType>
MSHADOW_XINLINE static void Reduce(volatile AType& sum_of_squares, // NOLINT(*)
volatile DType src, // NOLINT(*)
volatile DType& scale) { // NOLINT(*)
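// LAPACK-style scaled update: maintain scale = max |x| seen so far and
// sum_of_squares = sum((x_i / scale)^2) so intermediates cannot overflow;
// Finalize recovers the norm as scale * sqrt(sum_of_squares).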
if (src != 0) {
DType abs = mshadow_op::abs::Map(src);
if (scale < abs) {
sum_of_squares = 1 + sum_of_squares * (scale / abs) * (scale / abs);
scale = abs;
} else {
sum_of_squares = sum_of_squares + (abs / scale) * (abs / scale);
}
}
}
/*! \brief combine the results of two reducers */
template <typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_val, // NOLINT(*)
volatile DType& src_val) { // NOLINT(*)
dst_val += src_val;
}
/*! \brief combine the results of two reducers */
template <typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_ssq, // NOLINT(*)
volatile DType& dst_scale, // NOLINT(*)
volatile DType& src_ssq, // NOLINT(*)
volatile DType& src_scale) { // NOLINT(*)
if (dst_scale != 0 && dst_scale >= src_scale) {
dst_ssq = dst_ssq + src_ssq * (src_scale / dst_scale) * (src_scale / dst_scale);
} else if (src_scale != 0 && dst_scale < src_scale) {
dst_ssq = src_ssq + dst_ssq * (dst_scale / src_scale) * (dst_scale / src_scale);
dst_scale = src_scale;
}
}
/*! \brief finalize reduction result */
template <typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& sum_of_squares) { // NOLINT(*)
sum_of_squares = math::sqrt(sum_of_squares);
}
/*! \brief finalize reduction result */
template <typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& sum_of_squares, // NOLINT(*)
volatile DType& scale) { // NOLINT(*)
#pragma GCC diagnostic push
#if __GNUC__ >= 7
#pragma GCC diagnostic ignored "-Wint-in-bool-context"
#endif
sum_of_squares = scale * math::sqrt(sum_of_squares);
#pragma GCC diagnostic pop
}
/*!
*\brief calculate gradient of redres with respect to redsrc,
* redres: reduced result, redsrc: one of reduction element
*/
template <typename DType>
MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
return redsrc / redres;
}
/*!
*\brief set the initial value during reduction
*/
template <typename DType>
MSHADOW_XINLINE static void SetInitValue(DType& sum_of_squares) { // NOLINT(*)
sum_of_squares = 0;
}
/*!
*\brief set the initial value during reduction
*/
template <typename DType>
MSHADOW_XINLINE static void SetInitValue(DType& sum_of_squares, DType& scale) { // NOLINT(*)
SetInitValue(sum_of_squares);
scale = 0;
}
};
/*! \brief sum reducer */
struct sum {
/*! \brief do reduction into dst */
template <typename AType, typename DType>
MSHADOW_XINLINE static void Reduce(volatile AType& dst, volatile DType src) { // NOLINT(*)
dst += src;
}
/*! \brief do stable reduction into dst */
template <typename AType, typename DType>
MSHADOW_XINLINE static void Reduce(volatile AType& dst, // NOLINT(*)
volatile DType src, // NOLINT(*)
volatile DType& residual) { // NOLINT(*)
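// Kahan compensated summation, as in nansum::Reduce above.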
DType y = src - residual;
DType t = dst + y;
residual = (t - dst) - y;
dst = t;
}
/*! \brief combine the results of two reducers */
template <typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_val, // NOLINT(*)
volatile DType& src_val) { // NOLINT(*)
Reduce(dst_val, src_val);
}
/*! \brief combine the results of two reducers */
template <typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_val, // NOLINT(*)
volatile DType& dst_residual, // NOLINT(*)
volatile DType& src_val, // NOLINT(*)
volatile DType& src_residual) { // NOLINT(*)
DType t1 = dst_val + src_val;
DType e = t1 - dst_val;
DType t2 = ((src_val - e) + (dst_val - (t1 - e))) + dst_residual + src_residual;
dst_val = t1 + t2;
dst_residual = t2 - (dst_val - t1);
}
/*! \brief finalize reduction */
template <typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
/*! \brief finalize reduction */
template <typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& dst, // NOLINT(*)
volatile DType& residual) { // NOLINT(*)
}
/*!
*\brief calculate gradient of redres with respect to redsrc,
* redres: reduced result, redsrc: one of reduction element
*/
template <typename DType>
MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
return 1;
}
/*!
*\brief set the initial value during reduction
*/
template <typename DType>
MSHADOW_XINLINE static void SetInitValue(DType& initv) { // NOLINT(*)
initv = 0;
}
/*!
*\brief set the initial value during reduction
*/
template <typename DType>
MSHADOW_XINLINE static void SetInitValue(DType& initv, DType& residual) { // NOLINT(*)
SetInitValue(initv);
residual = 0;
}
};
/*! \brief arg max reducer */
struct argmax {
/*! \brief do reduction into dst */
template <typename AType, typename DType>
MSHADOW_XINLINE static void Reduce(volatile AType& dst, volatile DType src) { // NOLINT(*)
if (dst.num < src.num || (dst.num == src.num && dst.idx > src.idx)) {
dst.num = src.num;
dst.idx = src.idx;
}
}
/*! \brief do stable reduction into dst */
template <typename AType, typename DType>
MSHADOW_XINLINE static void Reduce(volatile AType& dst, // NOLINT(*)
volatile DType src, // NOLINT(*)
volatile DType& residual) { // NOLINT(*)
if (dst.num < src.num || (dst.num == src.num && dst.idx > src.idx)) {
dst.num = src.num;
dst.idx = src.idx;
}
}
/*! \brief combine the results of two reducers */
template <typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_val, // NOLINT(*)
volatile DType& src_val) { // NOLINT(*)
if (dst_val.num < src_val.num || (dst_val.num == src_val.num && dst_val.idx > src_val.idx)) {
dst_val.num = src_val.num;
dst_val.idx = src_val.idx;
}
}
/*! \brief combine the results of two reducers */
template <typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_val, // NOLINT(*)
volatile DType& dst_residual, // NOLINT(*)
volatile DType& src_val, // NOLINT(*)
volatile DType& src_residual) { // NOLINT(*)
if (dst_val.num < src_val.num || (dst_val.num == src_val.num && dst_val.idx > src_val.idx)) {
dst_val.num = src_val.num;
dst_val.idx = src_val.idx;
}
}
/*! \brief finalize reduction */
template <typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
/*! \brief finalize reduction */
template <typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& dst, // NOLINT(*)
volatile DType& residual) { // NOLINT(*)
}
/*!
*\brief calculate gradient of redres with respect to redsrc,
* redres: reduced result, redsrc: one of reduction element
*/
template <typename DType>
MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
return 1;
}
/*!
*\brief set the initial value during reduction
*/
template <typename DType>
MSHADOW_XINLINE static void SetInitValue(DType& initv) { // NOLINT(*)
initv.num = mshadow::red::limits::NegInfValue<decltype(initv.num)>();
}
/*!
*\brief set the initial value during reduction
*/
template <typename DType>
MSHADOW_XINLINE static void SetInitValue(DType& initv, DType& residual) { // NOLINT(*)
initv.num = mshadow::red::limits::NegInfValue<decltype(initv.num)>();
}
};
/*! \brief arg min reducer */
struct argmin {
/*! \brief do reduction into dst */
template <typename AType, typename DType>
MSHADOW_XINLINE static void Reduce(volatile AType& dst, volatile DType src) { // NOLINT(*)
if (dst.num > src.num || (dst.num == src.num && dst.idx > src.idx)) {
dst.num = src.num;
dst.idx = src.idx;
}
}
/*! \brief do stable reduction into dst */
template <typename AType, typename DType>
MSHADOW_XINLINE static void Reduce(volatile AType& dst, // NOLINT(*)
volatile DType src, // NOLINT(*)
volatile DType& residual) { // NOLINT(*)
if (dst.num > src.num || (dst.num == src.num && dst.idx > src.idx)) {
dst.num = src.num;
dst.idx = src.idx;
}
}
/*! \brief combine the results of two reducers */
template <typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_val, // NOLINT(*)
volatile DType& src_val) { // NOLINT(*)
if (dst_val.num > src_val.num || (dst_val.num == src_val.num && dst_val.idx > src_val.idx)) {
dst_val.num = src_val.num;
dst_val.idx = src_val.idx;
}
}
/*! \brief combine the results of two reducers */
template <typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_val, // NOLINT(*)
volatile DType& dst_residual, // NOLINT(*)
volatile DType& src_val, // NOLINT(*)
volatile DType& src_residual) { // NOLINT(*)
if (dst_val.num > src_val.num || (dst_val.num == src_val.num && dst_val.idx > src_val.idx)) {
dst_val.num = src_val.num;
dst_val.idx = src_val.idx;
}
}
/*! \brief finalize reduction */
template <typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
/*! \brief finalize reduction */
template <typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& dst, // NOLINT(*)
volatile DType& residual) { // NOLINT(*)
}
/*!
*\brief calculate gradient of redres with respect to redsrc,
* redres: reduced result, redsrc: one of reduction element
*/
template <typename DType>
MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
return 1;
}
/*!
*\brief set the initial value during reduction
*/
template <typename DType>
MSHADOW_XINLINE static void SetInitValue(DType& initv) { // NOLINT(*)
initv.num = mshadow::red::limits::PosInfValue<decltype(initv.num)>();
}
/*!
*\brief set the initial value during reduction
*/
template <typename DType>
MSHADOW_XINLINE static void SetInitValue(DType& initv, DType& residual) { // NOLINT(*)
initv.num = mshadow::red::limits::PosInfValue<decltype(initv.num)>();
}
};
struct nanprod_grad : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return IsNan(a) ? DType(0) : b / a;
}
};
#pragma GCC diagnostic push
#if __GNUC__ >= 7
#pragma GCC diagnostic ignored "-Wint-in-bool-context"
#pragma GCC diagnostic ignored "-Wbool-compare"
#endif
/*! \brief used for computing binary greatest common divisor */
struct gcd : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static typename enable_if<is_integral<DType>::value, DType>::type Map(DType a,
DType b) {
#if MXNET_HAS_GCD_LCM()
return std::gcd(a, b);
#else
// handle negative inputs.
if (a < 0) {
a = -a;
}
if (b < 0) {
b = -b;
}
// handle zero-valued cases.
DType c;
if (a == 0 && b != 0) {
c = b;
} else if (b == 0 && a != 0) {
c = a;
} else if (a == 0 && b == 0) {
c = 0;
} else {
DType tmp;
if (a < b) {
tmp = a;
a = b;
b = tmp;
}
while (a % b != 0) {
a = a % b;
tmp = a;
a = b;
b = tmp;
}
c = b;
}
return c;
#endif
}
template <typename DType>
MSHADOW_XINLINE static typename enable_if<!is_integral<DType>::value, DType>::type Map(DType a,
DType b) {
return DType(0.0f);
}
};
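// Example: gcd::Map(12, -18) == 6 (signs are stripped first); non-integral
// DTypes fall through to the overload returning 0.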
/*! \brief used for computing binary lowest common multiple */
struct lcm : public mxnet_op::tunable {
template <typename DType>
MSHADOW_XINLINE static typename enable_if<is_integral<DType>::value, DType>::type Map(DType a,
DType b) {
#if MXNET_HAS_GCD_LCM()
return std::lcm(a, b);
#else
// handle negative inputs.
if (a < 0) {
a = -a;
}
if (b < 0) {
b = -b;
}
// handle zero-valued cases.
DType c;
if (a == 0 || b == 0) {
c = 0;
} else {
DType tmp;
DType tmp_a = a;
DType tmp_b = b;
if (a < b) {
tmp = a;
a = b;
b = tmp;
}
while (a % b != 0) {
a = a % b;
tmp = a;
a = b;
b = tmp;
}
c = tmp_a / b * tmp_b;
}
return c;
#endif
}
template <typename DType>
MSHADOW_XINLINE static typename enable_if<!is_integral<DType>::value, DType>::type Map(DType a,
DType b) {
return DType(0.0f);
}
};
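// Example: lcm::Map(4, 6) == 12 via lcm(a, b) = a / gcd(a, b) * b; any zero
// operand yields 0.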
#pragma GCC diagnostic pop
} // namespace mshadow_op
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_MSHADOW_OP_H_