blob: 08f57d15fbf8ece4583dd776ce30588fd29a9163 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2015 by Contributors
* \file elemwise_sum.h
* \brief Function definition of elementwise sum
* \author Bing Xu
*/
#ifndef MXNET_OPERATOR_TENSOR_ELEMWISE_SUM_H_
#define MXNET_OPERATOR_TENSOR_ELEMWISE_SUM_H_
#include <dmlc/logging.h>
#include <cstring>
#include <vector>
#include "../operator_common.h"
#include "../elemwise_op_common.h"
#include "../mshadow_op.h"
#include "../mxnet_op.h"
namespace mxnet {
namespace op {
/*!
 * \brief Elementwise kernel functor that adds a variadic list of arrays.
 *
 * Map(i, out, req, in0, ins...) stores in0[i] + ins[0][i] + ... into out[i],
 * using KERNEL_ASSIGN so that req decides between writing and accumulating.
 */
struct Sum {
  /*! \brief Terminal case of the recursion: one array yields its i-th element. */
  template<typename DType>
  MSHADOW_XINLINE static DType sum(int i, const DType* head) {
    return head[i];
  }

  /*!
   * \brief Recursive case: the i-th element of the first array plus the sum
   *        over the remaining arrays. The additions associate to the right,
   *        i.e. head[i] + (tail0[i] + (tail1[i] + ...)).
   */
  template<typename DType, typename... DTypes>
  MSHADOW_XINLINE static DType sum(int i, const DType* head, const DTypes... tail) {
    return head[i] + sum(i, tail...);
  }

  /*!
   * \brief Kernel body for element i: reduce all inputs at position i and
   *        store the result into out[i] according to req.
   */
  template<typename DType, typename... DTypes>
  MSHADOW_XINLINE static void Map(int i, DType* out, const OpReqType req, const DType* in0,
                                  const DTypes... ins) {
    KERNEL_ASSIGN(out[i], req, sum(i, in0, ins...));
  }
};
/*!
 * \brief Elementwise sum of every tensor in in_data, stored into out_data[0].
 *
 * How the result is stored is selected by req[0] (nothing, write, in-place
 * write, or accumulate). Two, three and four inputs are reduced by a single
 * fused kernel launch. Any other input count is reduced by one launch per
 * input, accumulating the running total in out_data[0].
 *
 * \tparam xpu   device the kernels are launched on.
 * \tparam DType element type of the input and output buffers.
 * \param attrs    node attributes (not used here).
 * \param ctx      operator context; supplies the stream for the launches.
 * \param in_data  input tensors; must be non-empty. Each input is read over
 *                 the same number of elements as out_data[0] (presumably
 *                 guaranteed by shape inference -- confirm at the caller).
 * \param req      req[0] selects how the sum is stored into out_data[0].
 * \param out_data holds the output tensor in out_data[0].
 *
 * NOTE(review): under kWriteInplace the multi-launch path assumes
 * out_data[0] aliases at most in_data[0]. If it aliased a later input, the
 * first launch would overwrite that input before it is read.
 */
template<typename xpu, typename DType>
void ElementWiseSumCompute_(const nnvm::NodeAttrs& attrs,
                            const OpContext& ctx,
                            const std::vector<TBlob>& in_data,
                            const std::vector<OpReqType>& req,
                            const std::vector<TBlob>& out_data) {
  using namespace mxnet_op;
  if (req[0] == kNullOp) return;
  size_t size = in_data.size();
  // Every branch below dereferences in_data[0].
  CHECK_GT(size, 0U) << "ElementWiseSum requires at least one input";
  Stream<xpu> *s = ctx.get_stream<xpu>();
  DType* out_dptr = out_data[0].dptr<DType>();
  // Iteration count in units of DataType<DType>::kLanes elements, rounded up,
  // so that packed multi-lane types are processed one packed value per index.
  int out_size = static_cast<int>((out_data[0].Size() + DataType<DType>::kLanes - 1)
                                  /DataType<DType>::kLanes);
  switch (size) {
    case 2: {
      DType* in_0_dptr = in_data[0].dptr<DType>();
      DType* in_1_dptr = in_data[1].dptr<DType>();
      Kernel<Sum, xpu>::Launch(s, out_size, out_dptr, req[0], in_0_dptr, in_1_dptr);
      break;
    }
    case 3: {
      DType* in_0_dptr = in_data[0].dptr<DType>();
      DType* in_1_dptr = in_data[1].dptr<DType>();
      DType* in_2_dptr = in_data[2].dptr<DType>();
      Kernel<Sum, xpu>::Launch(s, out_size, out_dptr, req[0], in_0_dptr, in_1_dptr, in_2_dptr);
      break;
    }
    case 4: {
      DType* in_0_dptr = in_data[0].dptr<DType>();
      DType* in_1_dptr = in_data[1].dptr<DType>();
      DType* in_2_dptr = in_data[2].dptr<DType>();
      DType* in_3_dptr = in_data[3].dptr<DType>();
      Kernel<Sum, xpu>::Launch(s, out_size, out_dptr, req[0], in_0_dptr, in_1_dptr, in_2_dptr,
                               in_3_dptr);
      break;
    }
    default: {
      // First launch applies req[0] exactly once: it writes in_0 into out, or
      // adds in_0 to out's existing contents for kAddTo.
      DType* in_0_dptr = in_data[0].dptr<DType>();
      Kernel<Sum, xpu>::Launch(s, out_size, out_dptr, req[0], in_0_dptr);
      // out now holds the partial sum (including any prior contents under
      // kAddTo), so each follow-up launch must overwrite it with
      // out + in_i. Reusing req[0] == kAddTo here would compute
      // out += out + in_i and double the running total.
      for (size_t i = 1; i < size; ++i) {
        DType* in_dptr = in_data[i].dptr<DType>();
        Kernel<Sum, xpu>::Launch(s, out_size, out_dptr, kWriteTo, out_dptr, in_dptr);
      }
      break;
    }
  }
}
/*!
 * \brief Dtype-dispatching entry point for the elementwise sum.
 *
 * Requires exactly one output. Selects DType from the output tensor's type
 * flag via MSHADOW_TYPE_SWITCH and forwards all arguments unchanged to
 * ElementWiseSumCompute_<xpu, DType>.
 */
template<typename xpu>
void ElementWiseSumCompute(const nnvm::NodeAttrs& attrs,
                           const OpContext& ctx,
                           const std::vector<TBlob>& inputs,
                           const std::vector<OpReqType>& req,
                           const std::vector<TBlob>& outputs) {
  CHECK_EQ(outputs.size(), 1U);
  // The output's element type drives the dispatch; the inputs are read as the same DType.
  const int out_type_flag = outputs[0].type_flag_;
  MSHADOW_TYPE_SWITCH(out_type_flag, DType, {
    ElementWiseSumCompute_<xpu, DType>(attrs, ctx, inputs, req, outputs);
  });
}
/*!
 * \brief Variant of the elementwise-sum entry point whose dtype dispatch uses
 *        MSHADOW_TYPE_SWITCH_WITH_HALF2.
 *
 * Requires exactly one output and forwards all arguments unchanged to
 * ElementWiseSumCompute_<xpu, DType>. NOTE(review): the HALF2 switch
 * presumably maps float16 to a packed two-lane half type (DataType::kLanes == 2)
 * -- confirm in mshadow.
 */
template<typename xpu>
void ElementWiseSumComputeWithHalf2(const nnvm::NodeAttrs& attrs,
                                    const OpContext& ctx,
                                    const std::vector<TBlob>& inputs,
                                    const std::vector<OpReqType>& req,
                                    const std::vector<TBlob>& outputs) {
  CHECK_EQ(outputs.size(), 1U);
  // Dispatch on the output's element type, as in the non-half2 entry point.
  const int out_type_flag = outputs[0].type_flag_;
  MSHADOW_TYPE_SWITCH_WITH_HALF2(out_type_flag, DType, {
    ElementWiseSumCompute_<xpu, DType>(attrs, ctx, inputs, req, outputs);
  });
}
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_TENSOR_ELEMWISE_SUM_H_