blob: e14c8dbb0b78378cb869c600cf4b1318e9452949 [file]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2015 by Contributors
* \file mshadow_op.h
* \brief Element-wise math/logic functors and reducers in namespace mxnet::op::mshadow_op
* \author Bing Xu
*/
#ifndef MXNET_OPERATOR_MSHADOW_OP_H_
#define MXNET_OPERATOR_MSHADOW_OP_H_
#include <mxnet/base.h>
#include "math.h"
#include "math_functions-inl.h"
#include "special_functions-inl.h"
#include "./operator_tune.h"
#include "./contrib/erfinv-inl.h"
#ifdef __CUDACC__
#include <cuda_fp16.h>
#endif
namespace mxnet {
namespace op {
namespace mshadow_op {
// Numeric constants shared by the operators below. In device compilation
// (__CUDA_ARCH__ defined) they are declared __constant__; otherwise they are
// plain namespace-scope constants. Each is stored as float, so the extra
// digits in the literals are dropped on initialization.
#ifdef __CUDA_ARCH__
__constant__ const float PI = 3.14159265358979323846;
__constant__ const float SELU_ALPHA = 1.6732632423543772848170429916717;
__constant__ const float SELU_LAMBDA = 1.0507009873554804934193349852946;
__constant__ const float SQRT_2 = 1.4142135623730950488016887242096;
#else
const float PI = 3.14159265358979323846;
const float SELU_ALPHA = 1.6732632423543772848170429916717;
const float SELU_LAMBDA = 1.0507009873554804934193349852946;
const float SQRT_2 = 1.4142135623730950488016887242096;
// Only the non-device branch imports std::isnan; isnan_typed::IsNan further
// down calls isnan() unqualified.
using std::isnan;
#endif
// Type traits used for enable_if-based overload selection (sign, mod, rmod, lcm).
using std::enable_if;
using std::is_unsigned;
using std::is_integral;
// Each macro below defines a struct `name` deriving from mxnet_op::tunable
// with one templated static Map(). Inside `expr` the operands are called `a`
// (unary) or `a` and `b` (binary), and `DType` names the element type.

// Unary functor: Map(a) returns `expr` explicitly converted to DType.
#define MXNET_UNARY_MATH_OP(name, expr) \
struct name : public mxnet_op::tunable { \
template<typename DType> \
MSHADOW_XINLINE static DType Map(DType a) { \
return DType(expr); \
} \
}
// Unary functor: Map(a) returns `expr` without the explicit DType(...) wrapper
// (presumably "NC" stands for "no cast").
#define MXNET_UNARY_MATH_OP_NC(name, expr) \
struct name : public mxnet_op::tunable { \
template<typename DType> \
MSHADOW_XINLINE static DType Map(DType a) { \
return (expr); \
} \
}
// Unary predicate: Map(a) returns `expr` as bool.
#define MXNET_UNARY_LOGIC_OP_NC(name, expr) \
struct name : public mxnet_op::tunable { \
template<typename DType> \
MSHADOW_XINLINE static bool Map(DType a) { \
return (expr); \
} \
}
// Binary functor: Map(a, b) returns `expr` explicitly converted to DType.
#define MXNET_BINARY_MATH_OP(name, expr) \
struct name : public mxnet_op::tunable { \
template<typename DType> \
MSHADOW_XINLINE static DType Map(DType a, DType b) { \
return DType(expr); \
} \
}
// Binary functor: Map(a, b) returns `expr` without the explicit DType(...) wrapper.
#define MXNET_BINARY_MATH_OP_NC(name, expr) \
struct name : public mxnet_op::tunable { \
template<typename DType> \
MSHADOW_XINLINE static DType Map(DType a, DType b) { \
return (expr); \
} \
}
// Binary predicate: Map(a, b) returns `expr` as bool.
#define MXNET_BINARY_LOGIC_OP_NC(name, expr) \
struct name : public mxnet_op::tunable { \
template<typename DType> \
MSHADOW_XINLINE static bool Map(DType a, DType b) { \
return (expr); \
} \
}
// Shorthands for functors that forward directly to the math:: function of the same name.
#define MXNET_SIMPLE_UNARY_MATH_OP(name) MXNET_UNARY_MATH_OP(name, math::name(a))
#define MXNET_SIMPLE_BINARY_MATH_OP(name) MXNET_BINARY_MATH_OP(name, math::name(a, b))
// identity returns its argument unchanged; identity_grad is the constant 1.
MXNET_UNARY_MATH_OP_NC(identity, a);
MXNET_UNARY_MATH_OP(identity_grad, 1);
/*! \brief copies element i of `in` to element i of `out`, converting DTypeIn to DTypeOut */
struct identity_with_cast {
  template<typename DTypeIn, typename DTypeOut>
  MSHADOW_XINLINE static void Map(index_t i, DTypeOut *out, DTypeIn *in) {
    DTypeIn *source = in + i;
    DTypeOut *target = out + i;
    *target = DTypeOut(*source);
  }
};
// Operand selectors and the four basic arithmetic operators.
MXNET_BINARY_MATH_OP_NC(left, a);
MXNET_BINARY_MATH_OP_NC(right, b);
MXNET_BINARY_MATH_OP_NC(mul, a * b);
MXNET_BINARY_MATH_OP_NC(div, a / b);
MXNET_BINARY_MATH_OP_NC(plus, a + b);
MXNET_BINARY_MATH_OP_NC(minus, a - b);
MXNET_UNARY_MATH_OP(negation, -a);
// reciprocal: 1 / a; reciprocal_grad: -1 / a^2.
MXNET_UNARY_MATH_OP(reciprocal, 1.0f / math::id(a));
MXNET_UNARY_MATH_OP(reciprocal_grad, -1.0f / math::sqr(a));
// sigmoid: 1 / (1 + exp(-a)). sigmoid_grad evaluates a * (1 - a);
// NOTE(review): presumably `a` is the forward sigmoid output here - confirm with callers.
MXNET_UNARY_MATH_OP(sigmoid, 1.0f / (1.0f + math::exp(-a)));
MXNET_UNARY_MATH_OP(sigmoid_grad, math::id(a) * (1.0f - math::id(a)));
// softsign: a / (1 + |a|); softsign_grad: 1 / (1 + |a|)^2.
MXNET_UNARY_MATH_OP(softsign, a / (1.0f + math::fabs(a)));
MXNET_UNARY_MATH_OP(softsign_grad, 1.0f / math::sqr(1.0f + math::fabs(a)));
// selu: SELU_LAMBDA * (a > 0 ? a : SELU_ALPHA * expm1(a)).
MXNET_UNARY_MATH_OP_NC(selu, DType(SELU_LAMBDA) *
(a > DType(0) ? a : DType(math::id(SELU_ALPHA) * math::expm1(a))));
// selu_grad: SELU_LAMBDA * (a > 0 ? 1 : SELU_ALPHA + a).
MXNET_UNARY_MATH_OP_NC(selu_grad,
DType(SELU_LAMBDA) * (a > DType(0) ? DType(1) : DType(SELU_ALPHA + a)));
// prelu_grad: 0 for positive a, otherwise a; `b` is not used.
MXNET_BINARY_MATH_OP_NC(prelu_grad, a > DType(0) ? DType(0) : a);
// xelu: a for positive a, otherwise a * b with the product taken in float.
MXNET_BINARY_MATH_OP_NC(xelu, a > DType(0) ? a :
DType(static_cast<float>(a) * static_cast<float>(b)));
MXNET_BINARY_MATH_OP_NC(xelu_grad, a > DType(0) ? DType(1) : b);
// elu: a for positive a, otherwise b * expm1(a); elu_grad: 1 or (b + a).
MXNET_BINARY_MATH_OP_NC(elu, a > DType(0) ? a :
DType(math::id(b) * math::expm1(a)));
MXNET_BINARY_MATH_OP_NC(elu_grad, a > DType(0) ? DType(1) : DType(b + a));
// tanh_grad evaluates 1 - a^2;
// NOTE(review): presumably `a` is the forward tanh output - confirm with callers.
MXNET_SIMPLE_UNARY_MATH_OP(tanh);
MXNET_UNARY_MATH_OP(tanh_grad, 1.0f - math::sqr(a));
/*! \brief SoftReLU (softplus) activation: log(1 + exp(a)) */
struct softrelu : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a) {
    // Above the threshold of 20 the input is returned as-is: at floating
    // precision softrelu(a) == a there, and skipping exp() avoids overflow
    // for large inputs.
    const bool above_threshold = a > DType(20.0f);
    return above_threshold ? a : DType(math::log1p(math::exp(a)));
  }
};
// softrelu_grad: -expm1(-a), i.e. 1 - exp(-a).
// NOTE(review): presumably `a` is the forward softrelu output - confirm with callers.
MXNET_UNARY_MATH_OP(softrelu_grad, -math::expm1(-a));
// erfinv_grad: 0.5 * sqrt(PI) * exp(erfinv(a)^2).
MXNET_UNARY_MATH_OP(erfinv_grad, 0.5 * math::sqrt(PI) * math::exp(math::sqr(erfinv::Map(a))));
// erf_grad: 2 / sqrt(PI) * exp(-a^2).
MXNET_UNARY_MATH_OP(erf_grad, 2.0 / math::sqrt(PI) * math::exp(-(a * a)));
MXNET_SIMPLE_UNARY_MATH_OP(erf);
// gelu: 0.5 * a * (1 + erf(a / SQRT_2)), evaluated on a float copy of a.
MXNET_UNARY_MATH_OP(gelu,
DType(0.5f * static_cast<float>(a) * (1.0f + math::erf(static_cast<float>(a) / SQRT_2))));
// gelu_grad: 0.5 * (1 + erf(a / SQRT_2) + a * erf_grad(a / SQRT_2) / SQRT_2).
// Declared as a binary op, but only `a` appears in the expression.
MXNET_BINARY_MATH_OP_NC(gelu_grad,
DType(0.5f * (1.0f + math::erf(static_cast<float>(a) / SQRT_2) +
static_cast<float>(a) * erf_grad::Map(static_cast<float>(a) / SQRT_2) / SQRT_2)));
MXNET_SIMPLE_UNARY_MATH_OP(exp);
MXNET_SIMPLE_UNARY_MATH_OP(expm1);
MXNET_SIMPLE_UNARY_MATH_OP(log);
// log_grad: 1 / a.
MXNET_UNARY_MATH_OP(log_grad, 1.0f / math::id(a));
MXNET_SIMPLE_UNARY_MATH_OP(log10);
/*! \brief gradient of log10: (1 / ln(10)) / a, evaluated in float by default */
struct log10_grad : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a) {
    // 1 / ln(10) at single precision
    const float inv_ln10 = 0.4342944819f;
    const float input = static_cast<float>(a);
    return DType(inv_ln10 / input);
  }
};
/*! \brief double specialization using a double-precision 1 / ln(10) */
template<>
MSHADOW_XINLINE double log10_grad::Map<double>(double a) {
  const double inv_ln10 = 0.43429448190325182765;
  return inv_ln10 / a;
}
MXNET_SIMPLE_UNARY_MATH_OP(log2);
/*! \brief gradient of log2: (1 / ln(2)) / a, evaluated in float by default */
struct log2_grad : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a) {
    // 1 / ln(2) at single precision
    const float inv_ln2 = 1.442695041f;
    const float input = static_cast<float>(a);
    return DType(inv_ln2 / input);
  }
};
/*! \brief double specialization using a double-precision 1 / ln(2) */
template<>
MSHADOW_XINLINE double log2_grad::Map<double>(double a) {
  const double inv_ln2 = 1.44269504088896340737;
  return inv_ln2 / a;
}
// Trigonometric functions and their gradients.
MXNET_SIMPLE_UNARY_MATH_OP(sin);
MXNET_UNARY_MATH_OP(sin_grad, math::cos(a));
MXNET_SIMPLE_UNARY_MATH_OP(log1p);
// log1p_grad: 1 / (1 + a).
MXNET_UNARY_MATH_OP(log1p_grad, 1.0f / (1.0f + math::id(a)));
MXNET_SIMPLE_UNARY_MATH_OP(cos);
MXNET_UNARY_MATH_OP(cos_grad, -math::sin(a));
MXNET_SIMPLE_UNARY_MATH_OP(tan);
// tan_grad evaluates a^2 + 1;
// NOTE(review): presumably `a` is the forward tan output - confirm with callers.
MXNET_UNARY_MATH_OP(tan_grad, math::sqr(a) + 1.0f);
// Inverse trigonometric functions: arcsin/arccos/arctan map to math::asin/acos/atan.
MXNET_UNARY_MATH_OP(arcsin, math::asin(a));
MXNET_UNARY_MATH_OP(arcsin_grad, 1.0f / math::sqrt(1.0f - math::sqr(a)));
MXNET_UNARY_MATH_OP(arccos, math::acos(a));
MXNET_UNARY_MATH_OP(arccos_grad, -1.0f / math::sqrt(1.0f - math::sqr(a)));
MXNET_UNARY_MATH_OP(arctan, math::atan(a));
MXNET_UNARY_MATH_OP(arctan_grad, 1.0f / (math::sqr(a) + 1.0f));
// hypot(a, b) with partial derivatives a / hypot(a, b) and b / hypot(a, b).
MXNET_SIMPLE_BINARY_MATH_OP(hypot);
MXNET_BINARY_MATH_OP(hypot_grad_left, math::id(a) / math::hypot(a, b));
MXNET_BINARY_MATH_OP(hypot_grad_right, math::id(b) / math::hypot(a, b));
// Angle unit conversion; the gradients are the constant scale factors.
MXNET_UNARY_MATH_OP(degrees, 180.0f / PI * math::id(a));
MXNET_UNARY_MATH_OP(degrees_grad, 180.0f / PI);
MXNET_UNARY_MATH_OP(radians, PI / 180.0f * math::id(a));
MXNET_UNARY_MATH_OP(radians_grad, PI / 180.0f);
// Hyperbolic functions and their inverses.
MXNET_SIMPLE_UNARY_MATH_OP(sinh);
MXNET_UNARY_MATH_OP(sinh_grad, math::cosh(a));
MXNET_SIMPLE_UNARY_MATH_OP(cosh);
MXNET_UNARY_MATH_OP(cosh_grad, math::sinh(a));
MXNET_UNARY_MATH_OP(arcsinh, math::asinh(a));
// arcsinh_grad: 1 / hypot(a, 1), i.e. 1 / sqrt(a^2 + 1).
MXNET_UNARY_MATH_OP(arcsinh_grad, 1.0f / math::hypot(a, DType(1)));
MXNET_UNARY_MATH_OP(arccosh, math::acosh(a));
MXNET_UNARY_MATH_OP(arccosh_grad, 1.0f / math::sqrt(math::sqr(a) - 1.0f));
MXNET_UNARY_MATH_OP(arctanh, math::atanh(a));
MXNET_UNARY_MATH_OP(arctanh_grad, 1.0f / (1.0f - math::sqr(a)));
// square: a^2; square_grad: 2 * a.
MXNET_UNARY_MATH_OP(square, math::sqr(a));
MXNET_UNARY_MATH_OP(square_grad, 2.0f * math::id(a));
/*! \brief used to generate a Bernoulli mask: 1 when a < b (threshold) or a <= b (threshold_eq), else 0 */
MXNET_BINARY_MATH_OP_NC(threshold, a < b ? DType(1) : DType(0));
MXNET_BINARY_MATH_OP_NC(threshold_eq, a <= b ? DType(1) : DType(0));
/*! \brief element-wise absolute value via math::fabs */
MXNET_UNARY_MATH_OP(abs, math::fabs(a)); // NOLINT(*)
/*! \brief element-wise sign: -1, 0 or 1 (0 or 1 for unsigned types) */
struct sign : public mxnet_op::tunable {
  /*! \brief overload selected for non-unsigned types; values that compare neither below nor above zero give 0 */
  template<typename DType>
  MSHADOW_XINLINE static typename enable_if<!is_unsigned<DType>::value, DType>::type
  Map(DType a) {
    return a < DType(0) ? DType(-DType(1))
                        : (a > DType(0) ? DType(1) : DType(0));
  }
  /*! \brief overload selected for unsigned types, which have no negative branch */
  template<typename DType>
  MSHADOW_XINLINE static typename enable_if<is_unsigned<DType>::value, DType>::type
  Map(DType a) {
    return a > DType(0) ? DType(1) : DType(0);
  }
};
// sign_grad is identically zero.
MXNET_UNARY_MATH_OP_NC(sign_grad, DType(0));
/*! \brief power functions: power = a^b, rpower = b^a, plus their gradients */
MXNET_BINARY_MATH_OP(power, math::pow(a, b));
// power_grad: b * a^(b - 1); power_rgrad: a^b * log(a).
MXNET_BINARY_MATH_OP(power_grad, math::pow(a, b - DType(1)) * math::id(b));
MXNET_BINARY_MATH_OP(power_rgrad, math::pow(a, b) * math::log(a));
MXNET_BINARY_MATH_OP(rpower, math::pow(b, a));
// rpower_grad evaluates a * log(b);
// NOTE(review): presumably `a` is the forward rpower output here - confirm with callers.
MXNET_BINARY_MATH_OP(rpower_grad, math::id(a) * math::log(b));
// arctan2 = atan2(a, b); gradients are b / (a^2 + b^2) and -a / (a^2 + b^2).
MXNET_BINARY_MATH_OP(arctan2, math::atan2(a, b));
MXNET_BINARY_MATH_OP(arctan2_grad, math::id(b) / (math::id(a * a + b * b)));
MXNET_BINARY_MATH_OP(arctan2_rgrad, -math::id(a) / (math::id(a * a + b * b)));
// rarctan2 = atan2(b, a), i.e. arctan2 with the operands swapped.
MXNET_BINARY_MATH_OP(rarctan2, math::atan2(b, a));
MXNET_BINARY_MATH_OP(rarctan2_grad, math::id(a) / (math::id(a * a + b * b)));
// Logical negation: nt returns DType 1/0, np_logical_not returns bool.
MXNET_UNARY_MATH_OP_NC(nt, a != DType(0) ? DType(0) : DType(1));
MXNET_UNARY_LOGIC_OP_NC(np_logical_not, !static_cast<bool>(a));
// Comparisons returning DType(1) or DType(0).
MXNET_BINARY_MATH_OP_NC(ge, a >= b ? DType(1) : DType(0));
MXNET_BINARY_MATH_OP_NC(gt, a > b ? DType(1) : DType(0));
MXNET_BINARY_MATH_OP_NC(lt, a < b ? DType(1) : DType(0));
MXNET_BINARY_MATH_OP_NC(le, a <= b ? DType(1) : DType(0));
MXNET_BINARY_MATH_OP_NC(eq, a == b ? DType(1) : DType(0));
MXNET_BINARY_MATH_OP_NC(ne, a != b ? DType(1) : DType(0));
// The same comparisons returning bool.
MXNET_BINARY_LOGIC_OP_NC(np_greater_equal, a >= b ? true : false);
MXNET_BINARY_LOGIC_OP_NC(np_greater, a > b ? true : false);
MXNET_BINARY_LOGIC_OP_NC(np_less, a < b ? true : false);
MXNET_BINARY_LOGIC_OP_NC(np_less_equal, a <= b ? true : false);
MXNET_BINARY_LOGIC_OP_NC(np_equal, a == b ? true : false);
MXNET_BINARY_LOGIC_OP_NC(np_not_equal, a != b ? true : false);
// Logical connectives on the truthiness of a and b, returning DType 1/0.
MXNET_BINARY_MATH_OP(logical_and, a && b ? DType(1) : DType(0));
MXNET_BINARY_MATH_OP(logical_or, a || b ? DType(1) : DType(0));
MXNET_BINARY_MATH_OP(logical_xor, (a || b) && !(a && b) ? DType(1) : DType(0));
// bitwise_xor casts both operands to int64_t before applying ^.
MXNET_BINARY_MATH_OP(bitwise_xor, static_cast<int64_t>(a) ^ static_cast<int64_t>(b));
// Square and cube roots with their reciprocals. square_root_grad evaluates 0.5 / a and
// cube_root_grad evaluates 1 / (3 * a^2);
// NOTE(review): presumably `a` is the forward output for those two - confirm with callers.
MXNET_UNARY_MATH_OP(square_root, math::sqrt(a));
MXNET_UNARY_MATH_OP(square_root_grad, 0.5f / math::id(a));
MXNET_UNARY_MATH_OP(reciprocal_square_root, 1.0f / math::sqrt(a));
MXNET_UNARY_MATH_OP(reciprocal_square_root_grad, -0.5f / (math::sqrt(a) * math::id(a)));
MXNET_UNARY_MATH_OP(cube_root, math::cbrt(a));
MXNET_UNARY_MATH_OP(cube_root_grad, 1.0f / (3.0f * math::sqr(a)));
MXNET_UNARY_MATH_OP(reciprocal_cube_root, 1.0f / math::cbrt(a));
MXNET_UNARY_MATH_OP(reciprocal_cube_root_grad, -1.0f / (3.0f * math::cbrt(a) * math::id(a)));
/*! \brief ldexp: a * 2^b, computed with math::pow */
MXNET_BINARY_MATH_OP(ldexp, math::id(a) * math::pow(2.0f, b));
// ldexp_grad: 2^b; ldexp_rgrad: a * 2^b * log(2).
MXNET_BINARY_MATH_OP(ldexp_grad, math::pow(2.0f, b));
MXNET_BINARY_MATH_OP(ldexp_rgrad, math::id(a) * math::pow(2.0f, b) * math::log(2.0f));
// rldexp: b * 2^a, i.e. ldexp with the operand roles exchanged.
MXNET_BINARY_MATH_OP(rldexp, math::id(b) * math::pow(2.0f, a)); // swap a and b if a is scalar.
MXNET_BINARY_MATH_OP(rldexp_grad, math::id(b) * math::pow(2.0f, a) * math::log(2.0f));
/*! \brief element-wise math::round */
MXNET_SIMPLE_UNARY_MATH_OP(round);
/*! \brief element-wise math::ceil */
MXNET_SIMPLE_UNARY_MATH_OP(ceil);
/*! \brief element-wise math::floor */
MXNET_SIMPLE_UNARY_MATH_OP(floor);
/*! \brief used to round towards zero (math::trunc) */
MXNET_SIMPLE_UNARY_MATH_OP(trunc);
/*! \brief rounds to the nearest integer; when floor and ceil are equally close, floor is returned */
struct rint : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a) {
    const auto lower = math::floor(a);
    const auto upper = math::ceil(a);
    const auto value = math::id(a);
    const auto gap_below = value - lower;
    const auto gap_above = upper - value;
    if (gap_below <= gap_above) {
      return DType(lower);
    }
    return DType(upper);
  }
};
/*! \brief rounds towards zero by picking whichever of floor(a) / ceil(a) has the smaller magnitude */
struct fix : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a) {
    const auto lower = math::floor(a);
    const auto upper = math::ceil(a);
    const auto lower_magnitude = lower > 0 ? lower : -lower;
    const auto upper_magnitude = upper > 0 ? upper : -upper;
    // equal magnitudes fall through to ceil
    return DType(lower_magnitude < upper_magnitude ? lower : upper);
  }
};
/*! \brief used to generate the gradient of MAE loss: 1 when a - b > 0, otherwise -1 */
MXNET_BINARY_MATH_OP_NC(minus_sign, a - b > DType(0) ? DType(1) : -DType(1));
// rminus: b - a (subtraction with the operands swapped).
MXNET_BINARY_MATH_OP(rminus, b - a);
// div_grad: derivative of a / b with respect to a, i.e. 1 / b.
MXNET_BINARY_MATH_OP(div_grad, 1.0f / math::id(b));
// half2_t specialization: expressed with half2_t's own operator/ instead of math::id.
template<>
MSHADOW_XINLINE mshadow::half::half2_t div_grad::Map<mshadow::half::half2_t>
(mshadow::half::half2_t a,
mshadow::half::half2_t b) {
return mshadow::half::half2_t(1) / b;
}
// div_rgrad: derivative of a / b with respect to b, i.e. -a / b^2.
MXNET_BINARY_MATH_OP(div_rgrad, -math::id(a) / math::sqr(b));
// half2_t specialization: expressed with half2_t's own operators.
template<>
MSHADOW_XINLINE mshadow::half::half2_t div_rgrad::Map<mshadow::half::half2_t>
(mshadow::half::half2_t a,
mshadow::half::half2_t b) {
return -a / (b * b);
}
// rdiv: b / a; rdiv_grad: -b / a^2.
MXNET_BINARY_MATH_OP(rdiv, math::id(b) / math::id(a));
MXNET_BINARY_MATH_OP(rdiv_grad, -math::id(b) / math::sqr(a));
// copysign: a when a and b fall on the same side of the `>= 0` / `< 0` split, otherwise -a.
MXNET_BINARY_MATH_OP(copysign, (a >= 0 && b >= 0) || (a < 0 && b < 0) ? a : -a);
// copysign_grad: +1 where copysign keeps a, -1 where it negates a; the gradient w.r.t. b is 0.
MXNET_BINARY_MATH_OP(copysign_grad, (a >= 0 && b >= 0) || (a < 0 && b < 0) ? 1: -1);
MXNET_BINARY_MATH_OP(copysign_rgrad, 0);
// rcopysign: same rule with the operand roles exchanged (returns b or -b).
MXNET_BINARY_MATH_OP(rcopysign, (b >= 0 && a >= 0) || (b < 0 && a < 0) ? b : -b);
MXNET_BINARY_MATH_OP(rcopysign_grad, 0);
/*!
 * \brief element-wise modulo a mod b, computed with ::fmod in double precision.
 * A zero divisor yields 0. When the operand signs differ and the remainder is
 * non-zero, b is added so that the result takes the sign of the divisor.
 */
struct mod : public mxnet_op::tunable {
  /*! \brief overload selected for non-unsigned types */
  template<typename DType>
  MSHADOW_XINLINE static typename enable_if<!is_unsigned<DType>::value, DType>::type
  Map(DType a, DType b) {
    if (b == DType(0)) {
      return DType(0);
    }
    const double dividend = static_cast<double>(a);
    const double divisor = static_cast<double>(b);
    const bool dividend_negative = a < DType(0);
    if (b < DType(0)) {
      if (dividend_negative) {
        // both negative: remainder of the magnitudes, negated
        return DType(-::fmod(-dividend, -divisor));
      }
      // a >= 0, b < 0: shift a non-zero remainder by b
      const double remainder = ::fmod(dividend, -divisor);
      return DType(remainder + (remainder != DType(0) ? b : DType(0)));
    }
    if (dividend_negative) {
      // a < 0, b > 0: negate the remainder and shift it by b when non-zero
      const double remainder = ::fmod(-dividend, divisor);
      return DType(-remainder + (remainder != DType(0) ? b : DType(0)));
    }
    return DType(::fmod(dividend, divisor));
  }
  /*! \brief overload selected for unsigned types: plain fmod, with a zero divisor yielding 0 */
  template<typename DType>
  MSHADOW_XINLINE static typename enable_if<is_unsigned<DType>::value, DType>::type
  Map(DType a, DType b) {
    if (b == DType(0)) {
      return DType(0);
    }
    return DType(::fmod(static_cast<double>(a), static_cast<double>(b)));
  }
};
/*! \brief half2_t specialization: delegates to half2_t's operator% */
template<>
MSHADOW_XINLINE mshadow::half::half2_t mod::Map<mshadow::half::half2_t>
                                               (mshadow::half::half2_t a,
                                                mshadow::half::half2_t b) {
  return a % b;
}
/*!
 * \brief gradient of mod(a, b) with respect to a.
 * The generic template returns 0; the floating-point specializations below
 * (double, float, half_t, half2_t) return 1.
 */
struct mod_grad : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    return DType(0);
  }
};
template<>
MSHADOW_XINLINE double mod_grad::Map<double>(double a, double b) {
  return 1.0;
}
template<>
MSHADOW_XINLINE float mod_grad::Map<float>(float a, float b) {
  return 1.0f;
}
template<>
MSHADOW_XINLINE mshadow::half::half_t mod_grad::Map<mshadow::half::half_t>
                                                    (mshadow::half::half_t a,
                                                     mshadow::half::half_t b) {
  return mshadow::half::half_t(1.0f);
}
/*! \brief half2_t specialization: both packed lanes carry a gradient of 1 */
template<>
MSHADOW_XINLINE mshadow::half::half2_t mod_grad::Map<mshadow::half::half2_t>
                                                     (mshadow::half::half2_t a,
                                                      mshadow::half::half2_t b) {
  mshadow::half::half2_t result = mshadow::half::half2_t();
#if (defined(__CUDACC__) && MSHADOW_CUDA_HALF2)
  result.half2_ = ::__float2half2_rn(1.0f);
#else
  // The fallback must agree with the CUDA branch, which sets both lanes to 1.
  // Lane 0 was previously set to 0.0f, giving a zero gradient for every other element.
  result.half_t2[0] = mshadow::half::half_t(1.0f);
  result.half_t2[1] = mshadow::half::half_t(1.0f);
#endif
  return result;
}
/*!
 * \brief gradient of mod(a, b) with respect to b.
 * The generic template returns 0; the floating-point specializations below
 * return -floor(a / b).
 */
struct mod_rgrad : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    return DType(0);
  }
};
template<>
MSHADOW_XINLINE double mod_rgrad::Map<double>(double a, double b) {
  const double quotient = a / b;
  return -::floor(quotient);
}
template<>
MSHADOW_XINLINE float mod_rgrad::Map<float>(float a, float b) {
  const float quotient = a / b;
  return -::floorf(quotient);
}
template<>
MSHADOW_XINLINE mshadow::half::half_t mod_rgrad::Map<mshadow::half::half_t>
                                                     (mshadow::half::half_t a,
                                                      mshadow::half::half_t b) {
  // the division happens in half_t; only the quotient is widened to float
  const float quotient = static_cast<float>(a / b);
  return mshadow::half::half_t(-::floorf(quotient));
}
template<>
MSHADOW_XINLINE mshadow::half::half2_t mod_rgrad::Map<mshadow::half::half2_t>
                                                      (mshadow::half::half2_t a,
                                                       mshadow::half::half2_t b) {
#if (defined(__CUDACC__) && MSHADOW_CUDA_HALF2)
  return mshadow::half::half2_t(__hneg2(::h2floor((a/b).half2_)));
#else
  // without native half2 support, handle the two lanes one at a time
  const float quotient0 = static_cast<float>(a.half_t2[0] / b.half_t2[0]);
  const float quotient1 = static_cast<float>(a.half_t2[1] / b.half_t2[1]);
  return mshadow::half::half2_t(mshadow::half::half_t(-::floorf(quotient0)),
                                mshadow::half::half_t(-::floorf(quotient1)));
#endif
}
/*!
 * \brief reversed modulo: b mod a, computed with ::fmod in double precision.
 * A zero divisor (a == 0) yields 0. When the operand signs differ and the
 * remainder is non-zero, a is added so that the result takes the sign of the divisor.
 */
struct rmod : public mxnet_op::tunable {
  /*! \brief overload selected for non-unsigned types */
  template<typename DType>
  MSHADOW_XINLINE static typename enable_if<!is_unsigned<DType>::value, DType>::type
  Map(DType a, DType b) {
    if (a == DType(0)) {
      return DType(0);
    }
    const double divisor = static_cast<double>(a);
    const double dividend = static_cast<double>(b);
    const bool dividend_negative = b < DType(0);
    if (a < DType(0)) {
      if (dividend_negative) {
        // both negative: remainder of the magnitudes, negated
        return DType(-::fmod(-dividend, -divisor));
      }
      // b >= 0, a < 0: shift a non-zero remainder by a
      const double remainder = ::fmod(dividend, -divisor);
      return DType(remainder + (remainder != DType(0) ? a : DType(0)));
    }
    if (dividend_negative) {
      // b < 0, a > 0: negate the remainder and shift it by a when non-zero
      const double remainder = ::fmod(-dividend, divisor);
      return DType(-remainder + (remainder != DType(0) ? a : DType(0)));
    }
    return DType(::fmod(dividend, divisor));
  }
  /*! \brief overload selected for unsigned types: plain fmod, with a zero divisor yielding 0 */
  template<typename DType>
  MSHADOW_XINLINE static typename enable_if<is_unsigned<DType>::value, DType>::type
  Map(DType a, DType b) {
    if (a == DType(0)) {
      return DType(0);
    }
    return DType(::fmod(static_cast<double>(b), static_cast<double>(a)));
  }
};
/*! \brief half2_t specialization: delegates to half2_t's operator% */
template<>
MSHADOW_XINLINE mshadow::half::half2_t rmod::Map<mshadow::half::half2_t>
                                                (mshadow::half::half2_t a,
                                                 mshadow::half::half2_t b) {
  return b % a;
}
/*!
 * \brief gradient of rmod(a, b) = b mod a with respect to a.
 * The generic template returns 0; the floating-point specializations below
 * return -floor(b / a).
 */
struct rmod_grad {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    return DType(0);
  }
};
template<>
MSHADOW_XINLINE double rmod_grad::Map<double>(double a, double b) {
  const double quotient = b / a;
  return -::floor(quotient);
}
template<>
MSHADOW_XINLINE float rmod_grad::Map<float>(float a, float b) {
  const float quotient = b / a;
  return -::floorf(quotient);
}
template<>
MSHADOW_XINLINE mshadow::half::half_t rmod_grad::Map<mshadow::half::half_t>
                                                     (mshadow::half::half_t a,
                                                      mshadow::half::half_t b) {
  // the division happens in half_t; only the quotient is widened to float
  const float quotient = static_cast<float>(b / a);
  return mshadow::half::half_t(-::floorf(quotient));
}
template<>
MSHADOW_XINLINE mshadow::half::half2_t rmod_grad::Map<mshadow::half::half2_t>
                                                      (mshadow::half::half2_t a,
                                                       mshadow::half::half2_t b) {
#if (defined(__CUDACC__) && MSHADOW_CUDA_HALF2)
  return mshadow::half::half2_t(::__hneg2(::h2floor((b/a).half2_)));
#else
  // without native half2 support, handle the two lanes one at a time
  const float quotient0 = static_cast<float>(b.half_t2[0] / a.half_t2[0]);
  const float quotient1 = static_cast<float>(b.half_t2[1] / a.half_t2[1]);
  return mshadow::half::half2_t(mshadow::half::half_t(-::floorf(quotient0)),
                                mshadow::half::half_t(-::floorf(quotient1)));
#endif
}
/*! \brief clamps a value to a symmetric range [-bound, bound] or to [lower_bound, upper_bound] */
struct clip : public mxnet_op::tunable {
  /*! \brief symmetric clip; the upper bound is tested first */
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType x, DType bound) {
    if (x > bound) {
      return bound;
    }
    if (x < -bound) {
      return -bound;
    }
    return x;
  }
  /*! \brief explicit-range clip; the upper bound is tested first */
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType x, DType lower_bound, DType upper_bound) {
    return x > upper_bound ? upper_bound
                           : (x < lower_bound ? lower_bound : x);
  }
};
/***** gamma ******/
MXNET_UNARY_MATH_OP(gamma, math::tgamma(a));
/*! \brief gradient of gamma: tgamma(a) * psi(a) */
struct gamma_grad : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a) {
    // every type except double is evaluated at float precision
    float x = static_cast<float>(a);
    const auto gamma_value = math::tgamma(x);
    const auto psi_value = special_functions::cephes::psi<float>(x);
    return DType(gamma_value * psi_value);
  }
};
/*! \brief double specialization evaluated at double precision */
template<>
MSHADOW_XINLINE double gamma_grad::Map<double>(double a) {
  const auto gamma_value = math::tgamma(a);
  const auto psi_value = special_functions::cephes::psi<double>(a);
  return gamma_value * psi_value;
}
/***** gammaln ******/
MXNET_UNARY_MATH_OP(gammaln, math::lgamma(a));
/*! \brief gradient of gammaln: psi(a) */
struct gammaln_grad : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a) {
    // every type except double is evaluated through the float instantiation of psi
    const auto psi_value = special_functions::cephes::psi<float>(a);
    return DType(psi_value);
  }
};
/*! \brief double specialization using the double instantiation of psi */
template<>
MSHADOW_XINLINE double gammaln_grad::Map<double>(double a) {
  const auto psi_value = special_functions::cephes::psi<double>(a);
  return psi_value;
}
/* Smooth L1 loss, a loss used for R-CNN style training.
 *   f(x) = 0.5 * (sigma * x)^2,          |x| <= 1 / sigma^2
 *        = |x| - 0.5 / sigma^2,          otherwise
 * With sigma = 1 this is the Huber loss evaluated at delta = 1.
 * smooth_l1_loss = w_out * f(w_in * x)
 * with w_in, w_out provided by input_data.
 */
struct smooth_l1_loss : public mxnet_op::tunable {
  // a is x, b is sigma (it is squared inside)
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    const auto sigma_sq = math::sqr(b);
    const auto cutoff = 1.0f / sigma_sq;
    const auto x = math::id(a);
    if (x > cutoff) {
      // linear region, positive side
      return DType(x - 0.5f * cutoff);
    }
    if (x < -cutoff) {
      // linear region, negative side
      return DType(-x - 0.5f * cutoff);
    }
    // quadratic region around zero
    return DType(0.5f * x * x * sigma_sq);
  }
};  // struct smooth_l1_loss
/* Derivative of the smooth L1 loss:
 *   f'(x) = sigma^2 * x,   |x| <= 1 / sigma^2
 *         = sign(x),       otherwise
 */
struct smooth_l1_gradient : public mxnet_op::tunable {
  // a is x, b is sigma (it is squared inside)
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    const auto sigma_sq = math::sqr(b);
    const auto cutoff = 1.0f / sigma_sq;
    const auto x = math::id(a);
    if (x > cutoff) {
      return DType(1);
    }
    if (x < -cutoff) {
      return DType(-1);
    }
    return DType(sigma_sq * x);
  }
};  // struct smooth_l1_gradient
/*! \brief product reducer: folds values together by multiplication, starting from 1 */
struct product {
/*! \brief do reduction into dst: dst *= src */
template<typename DType>
MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
dst *= src;
}
/*! \brief do reduction into dst; the third argument is ignored and only matches the residual-carrying reducer signature */
template<typename DType>
MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType& none) { // NOLINT(*)
Reduce(dst, src);
}
/*! \brief combine the results of two reducers by multiplying src_val into dst_val */
template<typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
Reduce(dst_val, src_val);
}
/*! \brief combine the results of two reducers; both residual arguments are ignored */
template<typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*)
Reduce(dst_val, src_val);
}
/*! \brief finalize reduction (nothing to do for a product) */
template<typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
/*! \brief finalize reduction (nothing to do for a product) */
template<typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& none) {} // NOLINT(*)
/*!
*\brief calculate gradient of redres with respect to redsrc,
* redres: reduced result, redsrc: one of reduction element.
* Computed as redres / redsrc, with no special handling of redsrc == 0.
*/
template<typename DType>
MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
return redres / redsrc;
}
/*!
*\brief set the initial value during reduction (1, the multiplicative identity)
*/
template<typename DType>
MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*)
initv = 1;
}
/*!
*\brief set the initial value during reduction; the second argument is left untouched
*/
template<typename DType>
MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &none) { // NOLINT(*)
SetInitValue(initv);
}
};
/*! \brief type-dispatched NaN test used by the NaN-aware operators and reducers below */
namespace isnan_typed {
  /*! \brief fallback for types without a specialization: never reported as NaN */
  template<typename DType>
  MSHADOW_XINLINE bool IsNan(volatile DType val) {
    return false;
  }
  /*! \brief float: defer to isnan */
  template<>
  MSHADOW_XINLINE bool IsNan(volatile float val) {
    return isnan(val);
  }
  /*! \brief double: defer to isnan */
  template<>
  MSHADOW_XINLINE bool IsNan(volatile double val) {
    return isnan(val);
  }
  /*! \brief long double: defer to isnan */
  template<>
  MSHADOW_XINLINE bool IsNan(volatile long double val) {
    return isnan(val);
  }
  /*! \brief half_t: inspect the raw bits after masking off the sign bit */
  template<>
  MSHADOW_XINLINE bool IsNan(volatile mshadow::half::half_t val) {
    // anything above 0x7c00 (all exponent bits set, mantissa zero) is a NaN pattern
    const auto magnitude_bits = val.half_ & 0x7fff;
    return magnitude_bits > 0x7c00;
  }
}  // namespace isnan_typed
// relu: returns a when a is NaN or a > 0 (so NaN inputs pass through), otherwise 0.
MXNET_UNARY_MATH_OP_NC(relu, isnan_typed::IsNan(a) || (a > DType(0)) ? a : DType(0));
/*! \brief gradient of relu: 1 for positive input, 0 otherwise; a NaN input is returned unchanged */
struct relu_grad : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a) {
    if (isnan_typed::IsNan(a)) {
      return a;
    }
    return a > DType(0) ? DType(1) : DType(0);
  }
};
/*! \brief binary maximum; if the first operand is NaN it is returned as-is */
struct maximum : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    // only `a` is checked for NaN explicitly
    return isnan_typed::IsNan(a) ? a : (a > b ? a : b);
  }
};
/*! \brief binary minimum; if the first operand is NaN it is returned as-is */
struct minimum : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    // only `a` is checked for NaN explicitly
    if (isnan_typed::IsNan(a)) {
      return a;
    }
    return DType(a < b ? a : b);
  }
};
/*! \brief sum reducer that ignores NaN values in the input */
struct nansum {
  /*! \brief do reduction into dst; NaN sources are skipped */
  template<typename DType>
  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
    if (isnan_typed::IsNan(src)) return;
    dst += src;
  }
  /*!
   * \brief do compensated reduction into dst; NaN sources are skipped.
   * `residual` carries the running rounding correction between calls.
   */
  template<typename DType>
  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType& residual) { // NOLINT(*)
    if (isnan_typed::IsNan(src)) return;
    DType y = src - residual;
    DType t = dst + y;
    residual = (t - dst) - y;
    dst = t;
  }
  /*! \brief combine the results of two reducers */
  template<typename DType>
  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
    Reduce(dst_val, src_val);
  }
  /*!
   * \brief combine the results of two compensated reducers.
   * t1 is the rounded sum and t2 its rounding error plus both incoming residuals.
   */
  template<typename DType>
  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*)
    DType t1 = dst_val + src_val;
    // e is the part of t1 contributed by src_val, so (src_val - e) and
    // (dst_val - (t1 - e)) are the rounding errors of the two operands.
    // It must be taken relative to dst_val, matching sum::Merge; using
    // t1 - src_val here pairs each operand with the wrong term.
    DType e = t1 - dst_val;
    DType t2 = ((src_val - e) + (dst_val - (t1 - e))) + dst_residual + src_residual;
    dst_val = t1 + t2;
    dst_residual = t2 - (dst_val - t1);
  }
  /*! \brief finalize reduction (nothing to do) */
  template<typename DType>
  MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
  /*! \brief finalize reduction (nothing to do) */
  template<typename DType>
  MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& residual) {} // NOLINT(*)
  /*!
   *\brief set the initial value during reduction (0, the additive identity)
   */
  template<typename DType>
  MSHADOW_XINLINE static void SetInitValue(DType & initv) { // NOLINT(*)
    initv = 0;
  }
  /*!
   *\brief set the initial value during reduction and clear the residual
   */
  template<typename DType>
  MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &residual) { // NOLINT(*)
    SetInitValue(initv);
    residual = 0;
  }
};
/*! \brief gradient of nansum: 0 where the first operand is NaN, 1 elsewhere; `b` is unused */
struct nansum_grad : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    if (isnan_typed::IsNan(a)) {
      return DType(0);
    }
    return DType(1);
  }
};
/*! \brief product reducer that ignores NaN values in the input */
struct nanprod {
/*! \brief do reduction into dst: dst *= src unless src is NaN */
template<typename DType>
MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
if (isnan_typed::IsNan(src)) return;
dst *= src;
}
/*! \brief do reduction into dst; the third argument is ignored */
template<typename DType>
MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType& none) { // NOLINT(*)
Reduce(dst, src);
}
/*! \brief combine the results of two reducers (a NaN src_val is skipped, as in Reduce) */
template<typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
Reduce(dst_val, src_val);
}
/*! \brief combine the results of two reducers; both residual arguments are ignored */
template<typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*)
Reduce(dst_val, src_val);
}
/*! \brief finalize reduction (nothing to do) */
template<typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
/*! \brief finalize reduction (nothing to do) */
template<typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& none) {} // NOLINT(*)
/*!
*\brief set the initial value during reduction (1, the multiplicative identity)
*/
template<typename DType>
MSHADOW_XINLINE static void SetInitValue(DType & initv) { // NOLINT(*)
initv = 1;
}
/*!
*\brief set the initial value during reduction; the second argument is left untouched
*/
template<typename DType>
MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &none) { // NOLINT(*)
SetInitValue(initv);
}
};
/*! \brief compute l2 norm: accumulates a sum of squares and takes the square root in Finalize */
struct nrm2 {
/*! \brief do reduction into dst: adds src * src to the running sum of squares */
template<typename AType, typename DType>
MSHADOW_XINLINE static void Reduce(volatile AType& sum_of_squares, volatile DType src) { // NOLINT(*)
sum_of_squares += src * src;
}
/*!
* \brief do stable reduction into dst.
* `scale` tracks the largest |src| seen so far and sum_of_squares holds the
* sum of (value / scale)^2, so the true sum of squares is scale^2 * sum_of_squares.
* Zero-valued sources are skipped.
*/
template<typename AType, typename DType>
MSHADOW_XINLINE static void Reduce(volatile AType& sum_of_squares, volatile DType src, volatile DType& scale) { // NOLINT(*)
if (src != 0) {
DType abs = mshadow_op::abs::Map(src);
if (scale < abs) {
// new largest magnitude: rescale the accumulated sum to the new scale;
// the new element itself contributes exactly 1
sum_of_squares = 1 + sum_of_squares * (scale / abs) * (scale / abs);
scale = abs;
} else {
sum_of_squares = sum_of_squares + (abs / scale) * (abs / scale);
}
}
}
/*! \brief combine the results of two (unscaled) reducers by adding their sums of squares */
template<typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
dst_val += src_val;
}
/*!
* \brief combine the results of two scaled reducers.
* The partial sum with the smaller scale is rescaled to the larger one.
* If both scales are zero, neither branch runs and dst is left unchanged.
*/
template<typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_ssq, volatile DType& dst_scale, volatile DType& src_ssq, volatile DType& src_scale) { // NOLINT(*)
if (dst_scale != 0 && dst_scale >= src_scale) {
dst_ssq = dst_ssq + src_ssq * (src_scale / dst_scale) * (src_scale / dst_scale);
} else if (src_scale != 0 && dst_scale < src_scale) {
dst_ssq = src_ssq + dst_ssq * (dst_scale / src_scale) * (dst_scale / src_scale);
dst_scale = src_scale;
}
}
/*! \brief finalize reduction result: replace the sum of squares by its square root */
template<typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& sum_of_squares) { // NOLINT(*)
sum_of_squares = math::sqrt(sum_of_squares);
}
/*! \brief finalize reduction result for the scaled form: scale * sqrt(sum_of_squares) */
template<typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& sum_of_squares, volatile DType& scale) { // NOLINT(*)
sum_of_squares = scale * math::sqrt(sum_of_squares);
}
/*!
*\brief calculate gradient of redres with respect to redsrc,
* redres: reduced result, redsrc: one of reduction element.
* Computed as redsrc / redres, with no special handling of redres == 0.
*/
template<typename DType>
MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
return redsrc / redres;
}
/*!
*\brief set the initial value during reduction (sum of squares starts at 0)
*/
template<typename DType>
MSHADOW_XINLINE static void SetInitValue(DType &sum_of_squares) { // NOLINT(*)
sum_of_squares = 0;
}
/*!
*\brief set the initial value during reduction; the scale also starts at 0
*/
template<typename DType>
MSHADOW_XINLINE static void SetInitValue(DType &sum_of_squares, DType &scale) { // NOLINT(*)
SetInitValue(sum_of_squares);
scale = 0;
}
};
/*! \brief sum reducer, with plain and residual-compensated variants */
struct sum {
/*! \brief do reduction into dst: dst += src */
template<typename AType, typename DType>
MSHADOW_XINLINE static void Reduce(volatile AType& dst, volatile DType src) { // NOLINT(*)
dst += src;
}
/*!
* \brief do stable reduction into dst.
* `residual` carries the running rounding correction: it is subtracted from
* the incoming value and then recomputed from the rounded sum.
*/
template<typename AType, typename DType>
MSHADOW_XINLINE static void Reduce(volatile AType& dst, volatile DType src, volatile DType& residual) { // NOLINT(*)
DType y = src - residual;
DType t = dst + y;
residual = (t - dst) - y;
dst = t;
}
/*! \brief combine the results of two reducers by adding src_val into dst_val */
template<typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
Reduce(dst_val, src_val);
}
/*!
* \brief combine the results of two compensated reducers.
* t1 is the rounded sum; e is the part of t1 contributed by src_val, so
* (src_val - e) and (dst_val - (t1 - e)) are the rounding errors of the two
* operands. t2 adds both incoming residuals to that error.
*/
template<typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*)
DType t1 = dst_val + src_val;
DType e = t1 - dst_val;
DType t2 = ((src_val - e) + (dst_val - (t1 - e))) + dst_residual + src_residual;
dst_val = t1 + t2;
dst_residual = t2 - (dst_val - t1);
}
/*! \brief finalize reduction (nothing to do) */
template<typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
/*! \brief finalize reduction (nothing to do) */
template<typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& residual) {} // NOLINT(*)
/*!
*\brief calculate gradient of redres with respect to redsrc,
* redres: reduced result, redsrc: one of reduction element.
* Always 1 for a sum; both arguments are unused.
*/
template<typename DType>
MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
return 1;
}
/*!
*\brief set the initial value during reduction (0, the additive identity)
*/
template<typename DType>
MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*)
initv = 0;
}
/*!
*\brief set the initial value during reduction and clear the residual
*/
template<typename DType>
MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &residual) { // NOLINT(*)
SetInitValue(initv);
residual = 0;
}
};
/*! \brief gradient of nanprod: 0 where the first operand is NaN, otherwise b / a */
struct nanprod_grad : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    if (isnan_typed::IsNan(a)) {
      return DType(0);
    }
    return b / a;
  }
};
/*! \brief used for computing binary lowest common multiple */
struct lcm : public mxnet_op::tunable {
  /*!
   * \brief integral overload: lcm of |a| and |b|, computed as |a| / gcd * |b|
   * with the gcd found by repeated remainders; 0 if either operand is 0.
   */
  template<typename DType>
  MSHADOW_XINLINE static typename enable_if<is_integral<DType>::value, DType>::type
  Map(DType a, DType b) {
    // work on magnitudes
    if (a < 0) {
      a = -a;
    }
    if (b < 0) {
      b = -b;
    }
    // a zero operand gives a zero result
    if (a == 0 || b == 0) {
      return 0;
    }
    // order the pair so that the first remainder divides the larger by the smaller
    DType larger = a < b ? b : a;
    DType smaller = a < b ? a : b;
    // when the remainder reaches zero, `smaller` holds the gcd
    for (DType rem = larger % smaller; rem != 0; rem = larger % smaller) {
      larger = smaller;
      smaller = rem;
    }
    // divide before multiplying, as the original expression does
    return a / smaller * b;
  }
  /*! \brief non-integral overload: always returns 0 */
  template<typename DType>
  MSHADOW_XINLINE static typename enable_if<!is_integral<DType>::value, DType>::type
  Map(DType a, DType b) {
    return DType(0.0f);
  }
};
} // namespace mshadow_op
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_MSHADOW_OP_H_