| /* | 
 |  * Licensed to the Apache Software Foundation (ASF) under one | 
 |  * or more contributor license agreements.  See the NOTICE file | 
 |  * distributed with this work for additional information | 
 |  * regarding copyright ownership.  The ASF licenses this file | 
 |  * to you under the Apache License, Version 2.0 (the | 
 |  * "License"); you may not use this file except in compliance | 
 |  * with the License.  You may obtain a copy of the License at | 
 |  * | 
 |  *   http://www.apache.org/licenses/LICENSE-2.0 | 
 |  * | 
 |  * Unless required by applicable law or agreed to in writing, | 
 |  * software distributed under the License is distributed on an | 
 |  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | 
 |  * KIND, either express or implied.  See the License for the | 
 |  * specific language governing permissions and limitations | 
 |  * under the License. | 
 |  */ | 
 |  | 
 | #ifndef MXNET_OPERATOR_FUSION_FUSED_OP_INL_H_ | 
 | #define MXNET_OPERATOR_FUSION_FUSED_OP_INL_H_ | 
 |  | 
 | #include <string> | 
 | #include <map> | 
 | #include <vector> | 
 |  | 
 | #if MXNET_USE_CUDA | 
 |  | 
 | namespace mxnet { | 
 |  | 
 | namespace fusion { | 
 |  | 
// CUDA source text (compiled at runtime, presumably through NVRTC -- confirm
// against the kernel-compilation code) that declares a minimal 16-bit
// `__half` storage type, the two conversion intrinsics __float2half (PTX
// `cvt.rn.f16.f32`, round-to-nearest) and __half2float (PTX `cvt.f32.f16`),
// and the `half` alias that the other embedded sources refer to.
// NOTE(review): these declarations look like a stand-in for cuda_fp16.h so the
// runtime compilation does not need that header; verify before including both.
// Everything between R"code( and )code" is device code: edit it as code, not
// as a message string.
const char fp16_support_string[] = R"code(
struct __align__(2) __half {
  __host__ __device__ __half() { }
  unsigned short __x;
};
/* Definitions of intrinsics */
__device__ inline __half __float2half(const float f) {
  __half val;
 asm("{  cvt.rn.f16.f32 %0, %1;}\n" : "=h"(val.__x) : "f"(f));
  return val;
}
__device__ inline float __half2float(const __half h) {
  float val;
 asm("{  cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(h.__x));
  return val;
}

typedef __half half;
)code";
 |  | 
 | const char type_support_string[] = R"code( | 
 | using float32 = float; | 
 | using float64 = double; | 
 | using float16 = half; | 
 | using uint8 = unsigned char; | 
 | using int8 = char; | 
 | using int32 = int; | 
 | using int64 = long long; | 
 | )code"; | 
 |  | 
 | const std::map<std::string, std::vector<std::vector<std::string>>> ops_desc = { | 
 |   {"elemwise_add"                      , {{"op::add(%, %)", "_0", "_1"}}}, | 
 |   {"_plus"                             , {{"op::add(%, %)", "_0", "_1"}}}, | 
 |   {"_Plus"                             , {{"op::add(%, %)", "_0", "_1"}}}, | 
 |   {"_add"                              , {{"op::add(%, %)", "_0", "_1"}}}, | 
 |   {"elemwise_sub"                      , {{"op::sub(%, %)", "_0", "_1"}}}, | 
 |   {"_minus"                            , {{"op::sub(%, %)", "_0", "_1"}}}, | 
 |   {"_Minus"                            , {{"op::sub(%, %)", "_0", "_1"}}}, | 
 |   {"_sub"                              , {{"op::sub(%, %)", "_0", "_1"}}}, | 
 |   {"elemwise_mul"                      , {{"op::mul(%, %)", "_0", "_1"}}}, | 
 |   {"_mul"                              , {{"op::mul(%, %)", "_0", "_1"}}}, | 
 |   {"_Mul"                              , {{"op::mul(%, %)", "_0", "_1"}}}, | 
 |   {"elemwise_div"                      , {{"op::div(%, %)", "_0", "_1"}}}, | 
 |   {"_div"                              , {{"op::div(%, %)", "_0", "_1"}}}, | 
 |   {"_Div"                              , {{"op::div(%, %)", "_0", "_1"}}}, | 
 |   {"_Power"                            , {{"op::power(%, %)", "_0", "_1"}}}, | 
 |   {"_power"                            , {{"op::power(%, %)", "_0", "_1"}}}, | 
 |   {"_Maximum"                          , {{"op::max(%, %)", "_0", "_1"}}}, | 
 |   {"_maximum"                          , {{"op::max(%, %)", "_0", "_1"}}}, | 
 |   {"_Minimum"                          , {{"op::min(%, %)", "_0", "_1"}}}, | 
 |   {"_minimum"                          , {{"op::min(%, %)", "_0", "_1"}}}, | 
 |   {"amp_cast"                          , {{"op::identity(%)", "_0"}}}, | 
 |   {"_backward_amp_cast"                , {{"op::identity(%)", "_0"}}}, | 
 |   {"relu"                              , {{"op::relu(%)", "_0"}}}, | 
 |   {"sigmoid"                           , {{"op::sigmoid(%)", "_0"}}}, | 
 |   {"softsign"                          , {{"op::softsign(%)", "_0"}}}, | 
 |   {"exp"                               , {{"op::exp(%)", "_0"}}}, | 
 |   {"expm1"                             , {{"op::expm1(%)", "_0"}}}, | 
 |   {"log"                               , {{"op::log(%)", "_0"}}}, | 
 |   {"log10"                             , {{"op::log10(%)", "_0"}}}, | 
 |   {"log2"                              , {{"op::log2(%)", "_0"}}}, | 
 |   {"log1p"                             , {{"op::log1p(%)", "_0"}}}, | 
 |   {"degrees"                           , {{"op::degrees(%)", "_0"}}}, | 
 |   {"radians"                           , {{"op::radians(%)", "_0"}}}, | 
 |   {"sin"                               , {{"op::sin(%)", "_0"}}}, | 
 |   {"cos"                               , {{"op::cos(%)", "_0"}}}, | 
 |   {"tan"                               , {{"op::tan(%)", "_0"}}}, | 
 |   {"arcsin"                            , {{"op::arcsin(%)", "_0"}}}, | 
 |   {"arccos"                            , {{"op::arccos(%)", "_0"}}}, | 
 |   {"arctan"                            , {{"op::arctan(%)", "_0"}}}, | 
 |   {"sinh"                              , {{"op::sinh(%)", "_0"}}}, | 
 |   {"cosh"                              , {{"op::cosh(%)", "_0"}}}, | 
 |   {"tanh"                              , {{"op::tanh(%)", "_0"}}}, | 
 |   {"arcsinh"                           , {{"op::arcsinh(%)", "_0"}}}, | 
 |   {"arccosh"                           , {{"op::arccosh(%)", "_0"}}}, | 
 |   {"arctanh"                           , {{"op::arctanh(%)", "_0"}}}, | 
 |   {"sqrt"                              , {{"op::sqrt(%)", "_0"}}}, | 
 |   {"rsqrt"                             , {{"op::rsqrt(%)", "_0"}}}, | 
 |   {"cbrt"                              , {{"op::cbrt(%)", "_0"}}}, | 
 |   {"rcbrt"                             , {{"op::rcbrt(%)", "_0"}}}, | 
 |   {"square"                            , {{"op::square(%)", "_0"}}}, | 
 |   {"squeeze"                           , {{"op::identity(%)", "_0"}}}, | 
 |   {"zeros_like"                        , {{"op::zero(%)", "_0"}}}, | 
 |   {"ones_like"                         , {{"op::one(%)", "_0"}}}, | 
 |   {"flatten"                           , {{"op::identity(%)", "_0"}}}, | 
 |   {"Reshape"                           , {{"op::identity(%)", "_0"}}}, | 
 |   {"reshape"                           , {{"op::identity(%)", "_0"}}}, | 
 |   {"_backward_reshape"                 , {{"op::identity(%)", "_0"}}}, | 
 |   {"expand_dims"                       , {{"op::identity(%)", "_0"}}}, | 
 |   {"round"                             , {{"op::round(%)", "_0"}}}, | 
 |   {"rint"                              , {{"op::rint(%)", "_0"}}}, | 
 |   {"fix"                               , {{"op::fix(%)", "_0"}}}, | 
 |   {"floor"                             , {{"op::floor(%)", "_0"}}}, | 
 |   {"ceil"                              , {{"op::ceil(%)", "_0"}}}, | 
 |   {"trunc"                             , {{"op::trunc(%)", "_0"}}}, | 
 |   {"sign"                              , {{"op::sign(%)", "_0"}}}, | 
 |   {"reciprocal"                        , {{"op::reciprocal(%)", "_0"}}}, | 
 |   {"abs"                               , {{"op::abs(%)", "_0"}}}, | 
 |   {"gamma"                             , {{"op::gamma(%)", "_0"}}}, | 
 |   {"gammaln"                           , {{"op::gammaln(%)", "_0"}}}, | 
 |   {"erf"                               , {{"op::erf(%)", "_0"}}}, | 
 |   {"erfinv"                            , {{"op::erfinv(%)", "_0"}}}, | 
 |   {"_copy"                             , {{"op::identity(%)", "_0"}}}, | 
 |   {"_identity_with_attr_like_rhs"      , {{"op::identity(%)", "_0"}}}, | 
 |   {"_plus_scalar"                      , {{"op::add(%, float(%))", "_0", "scalar"}}}, | 
 |   {"_PlusScalar"                       , {{"op::add(%, float(%))", "_0", "scalar"}}}, | 
 |   {"_minus_scalar"                     , {{"op::sub(%, float(%))", "_0", "scalar"}}}, | 
 |   {"_MinusScalar"                      , {{"op::sub(%, float(%))", "_0", "scalar"}}}, | 
 |   {"_rminus_scalar"                    , {{"(-op::sub(%, float(%)))", "_0", "scalar"}}}, | 
 |   {"_RMinusScalar"                     , {{"(-op::sub(%, float(%)))", "_0", "scalar"}}}, | 
 |   {"_mul_scalar"                       , {{"op::mul(%, float(%))", "_0", "scalar"}}}, | 
 |   {"_MulScalar"                        , {{"op::mul(%, float(%))", "_0", "scalar"}}}, | 
 |   {"_div_scalar"                       , {{"op::mul(%, 1.0f/float(%))", "_0", "scalar"}}}, | 
 |   {"_DivScalar"                        , {{"op::mul(%, 1.0f/float(%))", "_0", "scalar"}}}, | 
 |   {"_rdiv_scalar"                      , {{"op::rdiv(%, float(%))", "_0", "scalar"}}}, | 
 |   {"_power_scalar"                     , {{"op::power(%, float(%))", "_0", "scalar"}}}, | 
 |   {"_PowerScalar"                      , {{"op::power(%, float(%))", "_0", "scalar"}}}, | 
 |   {"_rpower_scalar"                    , {{"op::rpow(%, float(%))", "_0", "scalar"}}}, | 
 |   {"_RPowerScalar"                     , {{"op::rpow(%, float(%))", "_0", "scalar"}}}, | 
 |   {"_RDivScalar"                       , {{"op::rdiv(%, float(%))", "_0", "scalar"}}}, | 
 |   {"Cast"                              , {{"op::cast<%>(%)", "dtype", "_0"}}}, | 
 |   {"cast"                              , {{"op::cast<%>(%)", "dtype", "_0"}}}, | 
 |   {"Activation"                        , {{"op::%(%)", "act_type", "_0"}}}, | 
 |   {"clip"                              , {{"op::clip(%, %, %)", "_0", "a_min", "a_max"}}}, | 
 |   {"_zeros"                            , {{"op::zero<%>()", "dtype"}}}, | 
 |   {"_ones"                             , {{"op::one<%>()", "dtype"}}}, | 
 |   {"negative"                          , {{"(-%)", "_0"}}}, | 
 |   {"_hypot"                            , {{"op::hypot(%, %)", "_0", "_1"}}}, | 
 |   {"_hypot_scalar"                     , {{"op::hypot(%, float(%))", "_0", "scalar"}}}, | 
 |   {"_backward_relu"                    , {{"op::backward_relu(%, %)", "_1", "_0"}}}, | 
 |   {"_backward_sigmoid"                 , {{"op::backward_sigmoid(%, %)", "_1", "_0"}}}, | 
 |   {"_backward_expm1"                   , {{"op::backward_expm1(%, %)", "_1", "_0"}}}, | 
 |   {"_backward_log"                     , {{"op::backward_log(%, %)", "_1", "_0"}}}, | 
 |   {"_backward_log10"                   , {{"op::backward_log10(%, %)", "_1", "_0"}}}, | 
 |   {"_backward_log2"                    , {{"op::backward_log2(%, %)", "_1", "_0"}}}, | 
 |   {"_backward_log1p"                   , {{"op::backward_log1p(%, %)", "_1", "_0"}}}, | 
 |   {"_backward_sin"                     , {{"op::backward_sin(%, %)", "_1", "_0"}}}, | 
 |   {"_backward_cos"                     , {{"op::backward_cos(%, %)", "_1", "_0"}}}, | 
 |   {"_backward_tan"                     , {{"op::backward_tan(%, %)", "_1", "_0"}}}, | 
 |   {"_backward_arcsin"                  , {{"op::backward_arcsin(%, %)", "_1", "_0"}}}, | 
 |   {"_backward_arccos"                  , {{"op::backward_arccos(%, %)", "_1", "_0"}}}, | 
 |   {"_backward_arctan"                  , {{"op::backward_arctan(%, %)", "_1", "_0"}}}, | 
 |   {"_backward_sinh"                    , {{"op::backward_sinh(%, %)", "_1", "_0"}}}, | 
 |   {"_backward_cosh"                    , {{"op::backward_cosh(%, %)", "_1", "_0"}}}, | 
 |   {"_backward_tanh"                    , {{"op::backward_tanh(%, %)", "_1", "_0"}}}, | 
 |   {"_backward_arcsinh"                 , {{"op::backward_arcsinh(%, %)", "_1", "_0"}}}, | 
 |   {"_backward_arccosh"                 , {{"op::backward_arccosh(%, %)", "_1", "_0"}}}, | 
 |   {"_backward_arctanh"                 , {{"op::backward_arctanh(%, %)", "_1", "_0"}}}, | 
 |   {"_backward_sqrt"                    , {{"op::backward_sqrt(%, %)", "_1", "_0"}}}, | 
 |   {"_backward_rsqrt"                   , {{"op::backward_rsqrt(%, %)", "_1", "_0"}}}, | 
 |   {"_backward_cbrt"                    , {{"op::backward_cbrt(%, %)", "_1", "_0"}}}, | 
 |   {"_backward_rcbrt"                   , {{"op::backward_rcbrt(%, %)", "_1", "_0"}}}, | 
 |   {"_backward_square"                  , {{"op::backward_square(%, %)", "_1", "_0"}}}, | 
 |   {"_backward_div_scalar"              , {{"(% * 1.0f/float(%))", "_0", "scalar"}}}, | 
 |   {"_backward_div_scalar"              , {{"(% * 1.0f/float(%))", "_0", "scalar"}}}, | 
 |   {"_backward_rdiv_scalar"             , {{"(-% * float(%) / (% * %))", "_0", | 
 |                                            "scalar", "_1", "_1"}}}, | 
 |   {"_backward_hypot_scalar"            , {{"(% * % / op::hypot(%, float(%)))", | 
 |                                            "_0", "_1", "_1", "scalar"}}}, | 
 |   {"_backward_radians"                 , {{"op::radians(%)", "_0"}}}, | 
 |   {"_backward_erf"                     , {{"op::backward_erf(%, %)", "_1", "_0"}}}, | 
 |   {"_backward_erfinv"                  , {{"op::backward_erfinv(%, %)", "_1", "_0"}}}, | 
 |   {"_backward_reciprocal"              , {{"op::backward_reciprocal(%, %)", "_1", "_0"}}}, | 
 |   {"_backward_abs"                     , {{"(% * op::sign(%))", "_0", "_1"}}}, | 
 |   {"_backward_degrees"                 , {{"op::degrees(%)", "_0"}}}, | 
 |   {"_backward_sign"                    , {{"op::zero(%)", "_0"}}}, | 
 |   {"_backward_clip"                    , {{"op::backward_clip(%, %, %, %)", "_1", "_0", | 
 |                                                                             "a_min", "a_max"}}}, | 
 |   {"smooth_l1"                         , {{"op::smooth_l1(%, float(%))", "_0", "scalar"}}}, | 
 |   {"_backward_smooth_l1"               , {{"op::backward_smooth_l1(%, float(%), %)", | 
 |                                            "_1", "scalar", "_0"}}}, | 
 |   // TODO(ptredak): arange | 
 |   // TODO(ptredak): LeakyRelu | 
 |   // TODO(ptredak): mod and rmod | 
 |   {"_backward_sub"                     , {{"(%)", "_0"}, | 
 |                                           {"(-(%))", "_0"}}}, | 
 |   {"_backward_mul"                     , {{"(% * %)", "_0", "_2"}, | 
 |                                           {"(% * %)", "_0", "_1"}}}, | 
 |   {"_backward_mul_scalar"              , {{"(% * float(%))", "_0", "scalar"}}}, | 
 |   {"_backward_div"                     , {{"(% / %)", "_0", "_2"}, | 
 |                                           {"(-% * % / (% * %))", "_0", "_1", "_2", "_2"}}}, | 
 |   {"_backward_power"                   , {{"(% * % * powf(%, % - 1))", "_0", "_2", "_1", "_2"}, | 
 |                                           {"(% * powf(%, %) * logf(%))", "_0", "_1", "_2", "_1"}}}, | 
 |   {"_backward_power_scalar"            , {{"(% * float(%) * powf(%, float(%) - 1))", | 
 |                                            "_0", "scalar", "_1", "scalar"}}}, | 
 |   {"_backward_rpower_scalar"           , {{"(% * % * logf(float(%)))", "_0", "_1", "scalar"}}}, | 
 |   {"_backward_maximum"                 , {{"((% >= %) ? % : 0)", "_1", "_2", "_0"}, | 
 |                                           {"((% >= %) ? 0 : %)", "_1", "_2", "_0"}}}, | 
 |   {"_backward_minimum"                 , {{"((% <= %) ? % : 0)", "_1", "_2", "_0"}, | 
 |                                           {"((% <= %) ? 0 : %)", "_1", "_2", "_0"}}}, | 
 |   {"_backward_hypot"                   , {{"(% * % / op::hypot(%, %))", "_0", "_1", "_1", "_2"}, | 
 |                                           {"(% * % / op::hypot(%, %))", "_0", "_2", "_1", "_2"}}} | 
 | }; | 
 |  | 
 | const std::map<std::string, std::string> slice_ops = { | 
 |   {"slice_axis"   , ""}, | 
 |   {"slice"   , ""}, | 
 |   {"slice_like"   , ""}, | 
 |   {"broadcast_like"   , ""}, | 
 | }; | 
 |  | 
 | const std::vector<std::string> variable_io_ops = { | 
 |   "add_n", | 
 |   "_backward_Activation", | 
 |   "amp_multicast", | 
 |   "_backward_amp_multicast", | 
 |   "_backward_cast" | 
 | }; | 
 |  | 
 | const char function_definitions[] = R"code( | 
 |  | 
 | #define INT_MAX (2147483647) | 
 |  | 
 | namespace op { | 
 |  | 
 | template <typename DType> | 
 | struct LoadType { | 
 |   using Type = DType; | 
 | }; | 
 |  | 
 | template <> | 
 | struct LoadType<half> { | 
 |   using Type = float; | 
 | }; | 
 |  | 
 | template <typename DType> | 
 | inline typename LoadType<DType>::Type load(const DType input) { | 
 |   return input; | 
 | } | 
 |  | 
 | template <> | 
 | inline float load(const half input) { | 
 |   return __half2float(input); | 
 | } | 
 |  | 
 | template <typename DType1, typename DType2> | 
 | inline DType1 store(const DType2 input, DType1* ref) { | 
 |   return input; | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline half store(const DType input, half* ref) { | 
 |   return __float2half(input); | 
 | } | 
 |  | 
 | template <int size> | 
 | struct VectorConfig { | 
 |     static_assert(size >= 4, "VectorConfig needs to have size of at least 4B"); | 
 |     using IndexType = float; | 
 | }; | 
 |  | 
 | template <> | 
 | struct VectorConfig<8> { | 
 |     using IndexType = double; | 
 | }; | 
 |  | 
 | template <> | 
 | struct VectorConfig<16> { | 
 |     using IndexType = double2; | 
 | }; | 
 |  | 
 | template <> | 
 | struct VectorConfig<32> { | 
 |     using IndexType = double4; | 
 | }; | 
 |  | 
 | template <typename DType> | 
 | inline DType add_elem(const DType& x, const DType& y) { | 
 |   return x + y; | 
 | } | 
 |  | 
 | template <> | 
 | inline half add_elem(const half& x, const half& y) { | 
 |   return __float2half(__half2float(x) + __half2float(y)); | 
 | } | 
 |  | 
 | template <typename DType, int nvec> | 
 | union VectorType { | 
 |     typename VectorConfig<sizeof(DType)*nvec>::IndexType y; | 
 |     DType x[nvec]; | 
 |     VectorType () {}; | 
 |     VectorType (const VectorType<DType, nvec>& y2) { | 
 |         y = y2.y; | 
 |     } | 
 |     VectorType (const decltype(y) &y2) { | 
 |         y = y2; | 
 |     } | 
 |     inline VectorType<DType, nvec>& operator+=(const VectorType<DType, nvec>& rhs) { | 
 |       #pragma unroll | 
 |       for (int i = 0; i < nvec; ++i) { | 
 |         x[i] = add_elem(x[i], rhs.x[i]); | 
 |       } | 
 |       return *this; | 
 |     } | 
 | }; | 
 |  | 
 | template <int ndim> | 
 | struct Shape { | 
 |    int x[ndim]; | 
 |    size_t size; | 
 |    inline const int& operator [](const int i) const { | 
 |        return x[i]; | 
 |    } | 
 |    inline int& operator [](const int i) { | 
 |        return x[i]; | 
 |    } | 
 |    inline void set(const int def) { | 
 |        #pragma unroll | 
 |        for (int i = 0; i < ndim; i++) { | 
 |            x[i] = def; | 
 |        } | 
 |    } | 
 | }; | 
 |  | 
 | template <> | 
 | struct Shape<0> { | 
 |    size_t size; | 
 | }; | 
 |  | 
 | template <int nvec, typename DType, int ndim> | 
 | inline VectorType<DType, nvec> load_index(const DType * input, int i, const Shape<ndim> &shape) { | 
 |   if (i < shape.size) { | 
 |     const auto* vector_input = reinterpret_cast< | 
 |                                 const typename VectorConfig<sizeof(DType)*nvec>::IndexType *>( | 
 |                                     input + i); | 
 |     VectorType<DType, nvec> ret = {*vector_input}; | 
 |     return ret; | 
 |   } else { | 
 |     VectorType<DType, nvec> ret({0}); | 
 |     return ret; | 
 |   } | 
 | } | 
 |  | 
 | template <int nvec, typename DType, int ndim> | 
 | inline VectorType<DType, nvec> global_load_index(const DType * input, int i, const Shape<ndim> &shape) { | 
 |   if (i < shape.size) { | 
 |     const auto* vector_input = reinterpret_cast< | 
 |                                 const typename VectorConfig<sizeof(DType)*nvec>::IndexType *>( | 
 |                                     input + i); | 
 |     VectorType<DType, nvec> ret = {__ldg(vector_input)}; | 
 |     return ret; | 
 |   } else { | 
 |     VectorType<DType, nvec> ret({0}); | 
 |     return ret; | 
 |   } | 
 | } | 
 |  | 
 | template <int nvec, typename DType, int ndim> | 
 | inline VectorType<DType, nvec> load_slice(const DType * input, const Shape<ndim>& shape, Shape<ndim> begin, Shape<ndim> end, int offset) { | 
 |   int idx[nvec]; | 
 |  | 
 |   Shape<ndim> ref_strides; | 
 |   Shape<ndim> strides; | 
 |   ref_strides[ndim-1] = 1; | 
 |   strides[ndim-1] = 1; | 
 |   #pragma unroll | 
 |   for (int dim = ndim-1; dim >=0; dim--) { | 
 |     if (begin[dim] < 0) begin[dim] = shape[dim] - begin[dim]; | 
 |     if (end[dim] < 0) end[dim] = shape[dim] - end[dim]; | 
 |     if (end[dim] == INT_MAX) end[dim] = shape[dim]; | 
 |     if (dim > 0) { | 
 |       ref_strides[dim-1] = ref_strides[dim] * (end[dim] - begin[dim]); | 
 |       strides[dim-1] = strides[dim] * shape[dim]; | 
 |     } | 
 |   } | 
 |   #pragma unroll | 
 |   for (int j = 0; j < nvec; j++) { | 
 |     idx[j] = 0; | 
 |     int ref_idx = offset + j; | 
 |     #pragma unroll | 
 |     for (int dim = 0; dim < ndim; dim++) { | 
 |        int stride = ref_strides[dim]; | 
 |        if (shape[dim] > 1) { | 
 |          idx[j] += (ref_idx / stride + begin[dim]) * strides[dim]; | 
 |        } | 
 |        ref_idx = ref_idx % stride; | 
 |     } | 
 |   } | 
 |   VectorType<DType, nvec> ret; | 
 |   #pragma unroll | 
 |   for (int j = 0; j < nvec; j++) { | 
 |       ret.x[j] = *(input + idx[j]); | 
 |   } | 
 |   return ret; | 
 | } | 
 |  | 
 | template <int nvec, typename DType, int ndim> | 
 | inline VectorType<DType, nvec> fast_load_slice(const DType * input, const Shape<ndim>& shape, Shape<ndim> begin, Shape<ndim> end, int offset) { | 
 |   int idx = 0; | 
 |  | 
 |   Shape<ndim> ref_strides; | 
 |   Shape<ndim> strides; | 
 |   ref_strides[ndim-1] = 1; | 
 |   strides[ndim-1] = 1; | 
 |   #pragma unroll | 
 |   for (int dim = ndim-1; dim >=0; dim--) { | 
 |     if (begin[dim] < 0) begin[dim] = shape[dim] - begin[dim]; | 
 |     if (end[dim] < 0) end[dim] = shape[dim] - end[dim]; | 
 |     if (end[dim] == INT_MAX) end[dim] = shape[dim]; | 
 |     if (dim > 0) { | 
 |       ref_strides[dim-1] = ref_strides[dim] * (end[dim] - begin[dim]); | 
 |       strides[dim-1] = strides[dim] * shape[dim]; | 
 |     } | 
 |   } | 
 |   int ref_idx = offset; | 
 |   #pragma unroll | 
 |   for (int dim = 0; dim < ndim; dim++) { | 
 |      int stride = ref_strides[dim]; | 
 |      if (shape[dim] > 1) { | 
 |        idx += (ref_idx / stride + begin[dim]) * strides[dim]; | 
 |      } | 
 |      ref_idx = ref_idx % stride; | 
 |   } | 
 |   return global_load_index<nvec>(input, idx, shape); | 
 | } | 
 |  | 
 | template <int nvec, typename DType, int ndim> | 
 | inline void store_index(const VectorType<DType, nvec> value, int i, | 
 |                         DType * output, const Shape<ndim>& shape) { | 
 |   if (i < (shape.size + nvec - 1) / nvec) { | 
 |     auto vector_output = reinterpret_cast< | 
 |                           typename VectorConfig<sizeof(DType)*nvec>::IndexType *>(output); | 
 |     vector_output[i] = value.y; | 
 |   } | 
 | } | 
 |  | 
 | template <int nvec, typename DType, int ndim> | 
 | inline void store_add_index(const VectorType<DType, nvec> value, int i, | 
 |                             DType * output, const Shape<ndim>& shape) { | 
 |   if (i < (shape.size + nvec - 1) / nvec) { | 
 |     auto vector_output = reinterpret_cast< | 
 |                           typename VectorConfig<sizeof(DType)*nvec>::IndexType *>(output); | 
 |     VectorType<DType, nvec> ret(vector_output[i]); | 
 |     ret += value; | 
 |     vector_output[i] = ret.y; | 
 |   } | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType identity(const DType val) { | 
 |   return val; | 
 | } | 
 |  | 
 | template <typename DType, typename DType2> | 
 | inline DType add(const DType a, const DType2 b) { | 
 |   return a + b; | 
 | } | 
 |  | 
 | template <typename DType, typename DType2> | 
 | inline DType sub(const DType a, const DType2 b) { | 
 |   return a - b; | 
 | } | 
 |  | 
 | template <typename DType, typename DType2> | 
 | inline DType mul(const DType a, const DType2 b) { | 
 |   return a * b; | 
 | } | 
 |  | 
 | template <typename DType, typename DType2> | 
 | inline DType div(const DType a, const DType2 b) { | 
 |   return a / b; | 
 | } | 
 |  | 
 | template <typename DType, typename DType2> | 
 | inline DType rdiv(const DType a, const DType2 b) { | 
 |   return b / a; | 
 | } | 
 |  | 
 | template <typename DType, typename DType2> | 
 | inline DType power(const DType a, const DType2 b) { | 
 |   return powf(a, b); | 
 | } | 
 |  | 
 | template <typename DType, typename DType2> | 
 | inline DType rpow(const DType a, const DType2 b) { | 
 |   return powf(b, a); | 
 | } | 
 |  | 
 | template <typename DType, typename DType2> | 
 | inline DType max(const DType a, const DType2 b) { | 
 |   return a > b ? a : b; | 
 | } | 
 |  | 
 | template <typename DType, typename DType2> | 
 | inline DType min(const DType a, const DType2 b) { | 
 |   return a < b ? a : b; | 
 | } | 
 |  | 
 | template <typename DType, typename DType2> | 
 | inline DType hypot(const DType a, const DType2 b) { | 
 |   return hypotf(a, b); | 
 | } | 
 |  | 
 | template <typename OutType, typename DType> | 
 | inline typename LoadType<OutType>::Type cast(const DType val) { | 
 |   return static_cast<typename LoadType<OutType>::Type>(val); | 
 | } | 
 |  | 
 | // activations | 
 |  | 
 | template <typename DType> | 
 | inline DType relu(const DType val) { | 
 |   return val > 0 ? val : 0; | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType sigmoid(const DType val) { | 
 |   return 1.f/(1 + expf(-val)); | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType softrelu(const DType val) { | 
 |   return logf(1 + expf(val)); | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType softsign(const DType val) { | 
 |   return val / (1 + fabsf(val)); | 
 | } | 
 |  | 
 | // exp and log | 
 |  | 
 | template <typename DType> | 
 | inline DType exp(const DType val) { | 
 |   return expf(val); | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType expm1(const DType val) { | 
 |   return expm1f(val); | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType log(const DType val) { | 
 |   return logf(val); | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType log10(const DType val) { | 
 |   return log10f(val); | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType log2(const DType val) { | 
 |   return log2f(val); | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType log1p(const DType val) { | 
 |   return log1pf(val); | 
 | } | 
 |  | 
 | // trigonometric | 
 |  | 
 | constexpr double pi = 3.14159265358979323846; | 
 |  | 
 | template <typename DType> | 
 | inline DType degrees(const DType val) { | 
 |   return (val / pi) * 180; | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType radians(const DType val) { | 
 |   return (val / 180.0) * pi; | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType sin(const DType val) { | 
 |   return sinf(val); | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType cos(const DType val) { | 
 |   return cosf(val); | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType tan(const DType val) { | 
 |   return tanf(val); | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType arcsin(const DType val) { | 
 |   return asinf(val); | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType arccos(const DType val) { | 
 |   return acosf(val); | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType arctan(const DType val) { | 
 |   return atanf(val); | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType sinh(const DType val) { | 
 |   return sinhf(val); | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType cosh(const DType val) { | 
 |   return coshf(val); | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType tanh(const DType val) { | 
 |   return tanhf(val); | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType arcsinh(const DType val) { | 
 |   return asinhf(val); | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType arccosh(const DType val) { | 
 |   return acoshf(val); | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType arctanh(const DType val) { | 
 |   return atanhf(val); | 
 | } | 
 |  | 
 | // sqrt | 
 |  | 
 | template <typename DType> | 
 | inline DType sqrt(const DType val) { | 
 |   return sqrtf(val); | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType rsqrt(const DType val) { | 
 |   return rsqrtf(val); | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType cbrt(const DType val) { | 
 |   return cbrtf(val); | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType rcbrt(const DType val) { | 
 |   return rcbrtf(val); | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType square(const DType val) { | 
 |   return val * val; | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline typename LoadType<DType>::Type zero(const DType val) { | 
 |   return 0; | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline typename LoadType<DType>::Type zero() { | 
 |   return 0; | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline typename LoadType<DType>::Type one(const DType val) { | 
 |   return 1; | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline typename LoadType<DType>::Type one() { | 
 |   return 1; | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType round(const DType val) { | 
 |   return roundf(val); | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType rint(const DType val) { | 
 |   return rintf(val); | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType fix(const DType val) { | 
 |     const auto floor = floorf(val); | 
 |     const auto ceil = ceilf(val); | 
 |     return (floor > 0 ? floor : -floor) < (ceil > 0 ? ceil : -ceil) ? floor : ceil; | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType floor(const DType val) { | 
 |     return floorf(val); | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType ceil(const DType val) { | 
 |     return ceilf(val); | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType trunc(const DType val) { | 
 |     return truncf(val); | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType clip(const DType val, const float a_min, const float a_max) { | 
 |   return max(min(val, a_max), a_min); | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType sign(const DType val) { | 
 |   if (val < 0) return -1; | 
 |   return val > 0 ? 1 : 0; | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType reciprocal(const DType val) { | 
 |   return 1.0f / val; | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType abs(const DType val) { | 
 |   return fabsf(val); | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType gamma(const DType val) { | 
 |   return tgammaf(val); | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType gammaln(const DType val) { | 
 |   return lgammaf(val); | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType erf(const DType val) { | 
 |   return erff(val); | 
 | } | 
 |  | 
 | template <typename DType> | 
 | inline DType erfinv(const DType val) { | 
 |   return erfinvf(val); | 
 | } | 
 |  | 
 | template <typename DType1, typename DType2> | 
 | inline DType1 smooth_l1(const DType1 val, const DType2 scalar) { | 
 |   const auto bsq = scalar * scalar; | 
 |   const auto ibsq = 1.0f / bsq; | 
 |   if (val > ibsq) { | 
 |     return val - 0.5f * ibsq; | 
 |   } else if (val < -ibsq) { | 
 |     return -val - 0.5f * ibsq; | 
 |   } else { | 
 |     return 0.5f * val * val * bsq; | 
 |   } | 
 | } | 
 |  | 
 | }  // namespace op | 
 |  | 
 | )code"; | 
 |  | 
 | const char backward_function_definitions[] = R"code( | 
 |  | 
 | namespace op { | 
 |  | 
 | template <typename DType, typename DTypeGrad> | 
 | inline DTypeGrad backward_relu(const DType val, const DTypeGrad grad) { | 
 |   return val > 0 ? grad : 0; | 
 | } | 
 |  | 
 | template <typename DType, typename DTypeGrad> | 
 | inline DTypeGrad backward_sigmoid(const DType out, const DTypeGrad grad) { | 
 |   return grad * out * (1 - out); | 
 | } | 
 |  | 
 | template <typename DType, typename DTypeGrad> | 
 | inline DTypeGrad backward_softrelu(const DType val, const DTypeGrad grad) { | 
 |   return grad * sigmoid(val); | 
 | } | 
 |  | 
 | template <typename DType, typename DTypeGrad> | 
 | inline DTypeGrad backward_softsign(const DType val, const DTypeGrad grad) { | 
 |   const DType ap1 = 1 + fabsf(val); | 
 |   return grad / (ap1 * ap1); | 
 | } | 
 |  | 
 | template <typename DType, typename DTypeGrad> | 
 | inline DTypeGrad backward_exp(const DType val, const DTypeGrad grad) { | 
 |   return grad * expf(val); | 
 | } | 
 |  | 
 | template <typename DType, typename DTypeGrad> | 
 | inline DTypeGrad backward_expm1(const DType val, const DTypeGrad grad) { | 
 |   return grad * expf(val); | 
 | } | 
 |  | 
 | template <typename DType, typename DTypeGrad> | 
 | inline DTypeGrad backward_log(const DType val, const DTypeGrad grad) { | 
 |   return grad / val; | 
 | } | 
 |  | 
 | template <typename DType, typename DTypeGrad> | 
 | inline DTypeGrad backward_log10(const DType val, const DTypeGrad grad) { | 
 |   return grad / (val * logf(10)); | 
 | } | 
 |  | 
 | template <typename DType, typename DTypeGrad> | 
 | inline DTypeGrad backward_log2(const DType val, const DTypeGrad grad) { | 
 |   return grad / (val * logf(2)); | 
 | } | 
 |  | 
 | template <typename DType, typename DTypeGrad> | 
 | inline DTypeGrad backward_log1p(const DType val, const DTypeGrad grad) { | 
 |   return grad / (1 + val); | 
 | } | 
 |  | 
 | template <typename DType, typename DTypeGrad> | 
 | inline DTypeGrad backward_sin(const DType val, const DTypeGrad grad) { | 
 |   return grad * cosf(val); | 
 | } | 
 |  | 
 | template <typename DType, typename DTypeGrad> | 
 | inline DTypeGrad backward_cos(const DType val, const DTypeGrad grad) { | 
 |   return -grad * sinf(val); | 
 | } | 
 |  | 
 | // Uses output from tan | 
 | template <typename DType, typename DTypeGrad> | 
 | inline DTypeGrad backward_tan(const DType out, const DTypeGrad grad) { | 
 |   return grad * (out * out + 1); | 
 | } | 
 |  | 
 | template <typename DType, typename DTypeGrad> | 
 | inline DTypeGrad backward_arcsin(const DType val, const DTypeGrad grad) { | 
 |   return grad / sqrtf(1 - val*val); | 
 | } | 
 |  | 
 | template <typename DType, typename DTypeGrad> | 
 | inline DTypeGrad backward_arccos(const DType val, const DTypeGrad grad) { | 
 |   return -grad / sqrtf(1 - val*val); | 
 | } | 
 |  | 
 | template <typename DType, typename DTypeGrad> | 
 | inline DTypeGrad backward_arctan(const DType val, const DTypeGrad grad) { | 
 |   return grad / (1 + val*val); | 
 | } | 
 |  | 
 | template <typename DType, typename DTypeGrad> | 
 | inline DTypeGrad backward_sinh(const DType val, const DTypeGrad grad) { | 
 |   return grad * coshf(val); | 
 | } | 
 |  | 
 | template <typename DType, typename DTypeGrad> | 
 | inline DTypeGrad backward_cosh(const DType val, const DTypeGrad grad) { | 
 |   return grad * sinhf(val); | 
 | } | 
 |  | 
 | // Uses tanh output | 
 | template <typename DType, typename DTypeGrad> | 
 | inline DTypeGrad backward_tanh(const DType out, const DTypeGrad grad) { | 
 |   return grad * (1 - out * out); | 
 | } | 
 |  | 
 | template <typename DType, typename DTypeGrad> | 
 | inline DTypeGrad backward_arcsinh(const DType val, const DTypeGrad grad) { | 
 |   return grad / sqrtf(val * val + 1); | 
 | } | 
 |  | 
 | template <typename DType, typename DTypeGrad> | 
 | inline DTypeGrad backward_arccosh(const DType val, const DTypeGrad grad) { | 
 |   return grad / sqrtf(val * val - 1); | 
 | } | 
 |  | 
 | template <typename DType, typename DTypeGrad> | 
 | inline DTypeGrad backward_arctanh(const DType val, const DTypeGrad grad) { | 
 |   return grad / (1 - val * val); | 
 | } | 
 |  | 
 | template <typename DType, typename DTypeGrad> | 
 | inline DTypeGrad backward_sqrt(const DType out, const DTypeGrad grad) { | 
 |   return 0.5 * grad / out; | 
 | } | 
 |  | 
 | template <typename DType, typename DTypeGrad> | 
 | inline DTypeGrad backward_rsqrt(const DType val, const DTypeGrad grad) { | 
 |   const DType inv = 1 / val; | 
 |   return -0.5 * grad * sqrtf(inv) * inv; | 
 | } | 
 |  | 
 | template <typename DType, typename DTypeGrad> | 
 | inline DTypeGrad backward_cbrt(const DType out, const DTypeGrad grad) { | 
 |   return grad / (3.0f * out * out); | 
 | } | 
 |  | 
 | template <typename DType, typename DTypeGrad> | 
 | inline DTypeGrad backward_rcbrt(const DType val, const DTypeGrad grad) { | 
 |   const DType inv = 1 / val; | 
 |   return -1.f/3.f * grad * cbrtf(inv) * inv; | 
 | } | 
 |  | 
 | template <typename DType, typename DTypeGrad> | 
 | inline DTypeGrad backward_square(const DType val, const DTypeGrad grad) { | 
 |   return 2 * val * grad; | 
 | } | 
 |  | 
 | template <typename DType, typename DTypeGrad> | 
 | inline DTypeGrad backward_clip(const DType val, const DTypeGrad grad, const float a_min, const float a_max) { | 
 |   if (val > a_max || val < a_min) { | 
 |     return 0; | 
 |   } else { | 
 |     return grad; | 
 |   } | 
 | } | 
 |  | 
 | template <typename DType, typename DTypeGrad> | 
 | inline DTypeGrad backward_reciprocal(const DType val, const DTypeGrad grad) { | 
 |   return -grad / (val * val); | 
 | } | 
 |  | 
 | template <typename DType, typename DTypeGrad> | 
 | inline DTypeGrad backward_erf(const DType val, const DTypeGrad grad) { | 
 |   return 2.0f / sqrt(pi) * exp(-(val*val)) * grad; | 
 | } | 
 |  | 
 | template <typename DType, typename DTypeGrad> | 
 | inline DTypeGrad backward_erfinv(const DType val, const DTypeGrad grad) { | 
 |   return 0.5f * sqrt(pi) * exp(val * val) * grad; | 
 | } | 
 |  | 
 | template <typename DType, typename DType2, typename DTypeGrad> | 
 | inline DTypeGrad backward_smooth_l1(const DType val, const DType2 scalar, const DTypeGrad grad) { | 
 |   auto bsq = scalar * scalar; | 
 |   auto ibsq = 1.0f / bsq; | 
 |   if (val > ibsq) { | 
 |     return grad; | 
 |   } else if (val < -ibsq) { | 
 |     return -grad; | 
 |   } else { | 
 |     return bsq * val * grad; | 
 |   } | 
 | } | 
 |  | 
 | }  // namespace op | 
 |  | 
 | )code"; | 
 |  | 
// Opening fragment of the generated kernel body. It computes the global
// thread index `tid` and opens a grid-stride loop over i in [tid, N), stepping
// by the total number of launched threads (gridDim.x * blockDim.x).
// `offset` = i * nvec (presumably the first element of the nvec-wide chunk
// handled by iteration i). The loop's opening brace is left unclosed here,
// presumably to be closed by the first `}` of kernel_end.
// `N` and `nvec` are not declared in this fragment -- presumably the code
// generator emits them as part of the surrounding kernel (not visible here).
// NOTE(review): `i` and `offset` are int, so i * nvec could overflow for very
// large inputs -- confirm the types and limits of N and nvec.
const char kernel_begin[] = R"code(
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
for (int i = tid; i < N; i+= gridDim.x * blockDim.x) {
    int offset = i*nvec;
)code";
 |  | 
// Closing fragment of the generated kernel body: two closing braces.
// Presumably the first ends the loop opened by kernel_begin and the second
// ends the enclosing kernel function emitted by the code generator (not
// visible here).
const char kernel_end[] = R"code(}
}
)code";
 |  | 
 | }  // namespace fusion | 
 |  | 
 | }  // namespace mxnet | 
 |  | 
 | #endif  // MXNET_USE_CUDA | 
 |  | 
 | #endif  // MXNET_OPERATOR_FUSION_FUSED_OP_INL_H_ |