// blob: cc14e85cace5caa6be4ec7cf0a20261cdaac2dcb [file] [log] [blame]
/*******************************************************************************
* Copyright 2016 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* \file mkl_concat-inl.h
* \brief Concat operator implemented with the MKL DNN concat/split primitives
* \author lingyan.guo@intel.com
* zhenlin.luo@intel.com
*
*******************************************************************************/
#ifndef MXNET_OPERATOR_MKL_MKL_CONCAT_INL_H_
#define MXNET_OPERATOR_MKL_MKL_CONCAT_INL_H_
#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <cstring>
#include <map>
#include <string>
#include <vector>
#include <utility>
#include "../operator_common.h"
#include "../channel_op_common.h"
#include "./mkl_util-inl.h"
namespace mxnet {
namespace op {
template<typename xpu, typename DType>
class MKLConcatOp : public Operator {
public:
static std::string getName() {
return "MKLConcatOp";
}
explicit MKLConcatOp(ConcatParam param)
: size_(param.num_args), dimension_(param.dim), init_mkldnn_(false) {
concatFwd_ = static_cast<dnnPrimitive_t>(NULL);
concatBwd_ = static_cast<dnnPrimitive_t>(NULL);
fwd_top_data_ = MKLData<DType>::create();
bwd_top_diff_ = MKLData<DType>::create();
num_concats_ = param.num_args;
}
virtual ~MKLConcatOp() {
dnnDelete<DType>(concatFwd_);
dnnDelete<DType>(concatBwd_);
}
private:
void LayerSetUp(const std::vector<mshadow::Tensor<xpu, 4, DType> > &data,
const mshadow::Tensor<xpu, 4, DType> &out,
size_t data_shape_size, size_t *split_channels_) {
size_t dim_src = data_shape_size;
size_t dim_dst = dim_src;
num_concats_ = size_;
channels_ = 0;
for (size_t i = 1; i < num_concats_; ++i) {
for (size_t j = 1; j < data_shape_size; ++j) {
if (j == dimension_) continue;
CHECK_EQ(data[0].shape_[j], data[i].shape_[j]);
}
}
for (size_t i = 0; i < num_concats_; ++i) {
CHECK_EQ((int)dim_src, data[i].shape_.kDimension);
fwd_bottom_data_.push_back(MKLData<DType>::create());
bwd_bottom_diff_.push_back(MKLData<DType>::create());
fwd_bottom_data_[i]->name = "fwd_bottom_data_[i]";
bwd_bottom_diff_[i]->name = "bwd_bottom_data[i]";
size_t *sizes_src = new size_t[dim_src];
size_t *strides_src = new size_t[dim_src];
for (size_t d = 0; d < dim_src; ++d) {
sizes_src[d] = data[i].shape_[dim_src - d - 1];
strides_src[d] = (d == 0) ? 1 : strides_src[d - 1] * sizes_src[d - 1];
}
split_channels_[i] = data[i].shape_[1];
channels_ += split_channels_[i];
fwd_bottom_data_[i]->create_user_layout(dim_src, sizes_src, strides_src);
bwd_bottom_diff_[i]->create_user_layout(dim_src, sizes_src, strides_src);
delete[] sizes_src;
delete[] strides_src;
}
size_t *sizes_dst = new size_t[dim_dst];
size_t *strides_dst = new size_t[dim_dst];
for (size_t d = 0; d < dim_dst; ++d) {
if (d == 2)
sizes_dst[d] = channels_;
else
sizes_dst[d] = data[0].shape_[dim_dst - 1 - d];
strides_dst[d] = (d == 0) ? 1 : strides_dst[d - 1] * sizes_dst[d - 1];
}
bwd_top_diff_->create_user_layout(dim_dst, sizes_dst, strides_dst);
fwd_top_data_->create_user_layout(dim_dst, sizes_dst, strides_dst);
delete[] sizes_dst;
delete[] strides_dst;
concatFwd_ = NULL;
concatBwd_ = NULL;
}
public:
virtual void Forward(const OpContext &ctx,
const std::vector<TBlob> &in_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &out_data,
const std::vector<TBlob> &aux_args) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK_EQ(static_cast<int>(in_data.size()), size_);
CHECK_EQ(out_data.size(), 1);
CHECK_LT(dimension_, (size_t)in_data[concat_enum::kData0].ndim());
Stream<xpu> *s = ctx.get_stream<xpu>();
std::vector<Tensor<xpu, 4, DType> > data(size_);
Tensor<xpu, 4, DType> out;
if (in_data[0].ndim() == 2) {
for (int i = 0; i < size_; ++i) {
Shape<4> dshape = Shape4(in_data[i].shape_[0],
in_data[i].shape_[1], 1, 1);
data[i] = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
in_data[i], dshape, s);
}
Shape<4> dshape = Shape4(out_data[concat_enum::kOut].shape_[0],
out_data[concat_enum::kOut].shape_[1], 1, 1);
out = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
out_data[concat_enum::kOut], dshape, s);
} else if (in_data[0].ndim() == 3) {
for (int i = 0; i < size_; ++i) {
Shape<4> dshape = Shape4(in_data[i].shape_[0],
in_data[i].shape_[1], in_data[i].shape_[2], 1);
data[i] = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
in_data[i], dshape, s);
}
Shape<4> dshape = Shape4(out_data[concat_enum::kOut].shape_[0],
out_data[concat_enum::kOut].shape_[1],
out_data[concat_enum::kOut].shape_[2], 1);
out = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
out_data[concat_enum::kOut], dshape, s);
} else {
for (int i = 0; i < size_; ++i) {
data[i] = mkl_experimental_direct_get<xpu, 4, DType>(in_data[i], s);
}
out = mkl_experimental_direct_get<xpu, 4, DType>(out_data[concat_enum::kOut], s);
}
size_t *split_channels_ = new size_t[num_concats_];
if (!init_mkldnn_) {
init_mkldnn_ = true;
LayerSetUp(data, out, 4, split_channels_);
}
dnnError_t e;
std::vector<void*> bottom_data;
bool isFirstPass = (concatFwd_ == NULL);
dnnLayout_t *layouts = NULL;
if (isFirstPass) {
layouts = new dnnLayout_t[num_concats_];
}
for (size_t i = 0; i < num_concats_; i++) {
void * bottom_i = NULL;
#if MKL_EXPERIMENTAL == 1
bottom_i = mkl_prv_data<DType>(in_data[i]);
if (bottom_i != NULL) {
if (isFirstPass) {
std::shared_ptr<MKLData<DType> > mem_descr =
mkl_get_mem_desc<DType>(in_data[i].Mkl_mem_);
fwd_bottom_data_[i] = mem_descr;
layouts[i] = mem_descr->layout_int;
}
}
#endif
if (bottom_i == NULL) {
bottom_i = data[i].dptr_;
if (isFirstPass) {
layouts[i] = fwd_bottom_data_[i]->layout_usr;
}
}
bottom_data.push_back(reinterpret_cast<void *>(bottom_i));
}
if (isFirstPass) {
e = dnnConcatCreate<DType>(&concatFwd_, NULL, num_concats_, layouts);
CHECK_EQ(e, E_SUCCESS);
fwd_top_data_->create_internal_layout(concatFwd_, dnnResourceDst);
bwd_top_diff_->create_internal_layout(concatFwd_, dnnResourceDst);
e = dnnSplitCreate<DType>(&concatBwd_, NULL, num_concats_,
bwd_top_diff_->layout_int, split_channels_);
CHECK_EQ(e, E_SUCCESS);
for (size_t n = 0; n < num_concats_; ++n) {
fwd_bottom_data_[n]->create_internal_layout(concatFwd_,
(dnnResourceType_t)(dnnResourceMultipleSrc + n));
bwd_bottom_diff_[n]->create_internal_layout(concatBwd_,
(dnnResourceType_t)(dnnResourceMultipleDst + n));
}
}
delete[] layouts;
void *concat_res[dnnResourceNumber];
for (size_t i = 0; i < num_concats_; ++i) {
concat_res[dnnResourceMultipleSrc + i]
= reinterpret_cast<void*>(bottom_data[i]);
}
std::shared_ptr<MKLMemHolder> top_mem = NULL;
#if MKL_EXPERIMENTAL == 1
top_mem = out_data[concat_enum::kOut].Mkl_mem_;
#endif
concat_res[dnnResourceDst] = fwd_top_data_->get_output_ptr(out.dptr_,
fwd_top_data_, top_mem);
e = dnnExecute<DType>(concatFwd_, concat_res);
CHECK_EQ(e, E_SUCCESS);
delete[] split_channels_;
}
virtual void Backward(const OpContext &ctx,
const std::vector<TBlob> &out_grad,
const std::vector<TBlob> &in_data,
const std::vector<TBlob> &out_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &in_grad,
const std::vector<TBlob> &aux_states) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK_EQ(out_grad.size(), 1);
CHECK_EQ(in_grad.size(), static_cast<size_t>(size_));
Stream<xpu> *s = ctx.get_stream<xpu>();
std::vector<Tensor<xpu, 4, DType> > grad_in(size_);
Tensor<xpu, 4, DType> grad;
if (in_grad[0].ndim() == 2) {
Shape<4> dshape = Shape4(out_grad[concat_enum::kOut].shape_[0],
out_grad[concat_enum::kOut].shape_[1], 1, 1);
grad = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
out_grad[concat_enum::kOut], dshape, s);
for (int i = 0; i < size_; ++i) {
dshape = Shape4(in_grad[i].shape_[0],
in_grad[i].shape_[1], 1, 1);
grad_in[i] = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
in_grad[i], dshape, s);
}
} else if (in_grad[0].ndim() == 3) {
Shape<4> dshape = Shape4(out_grad[concat_enum::kOut].shape_[0],
out_grad[concat_enum::kOut].shape_[1],
out_grad[concat_enum::kOut].shape_[2], 1);
grad = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
out_grad[concat_enum::kOut], dshape, s);
for (int i = 0; i < size_; ++i) {
dshape = Shape4(in_grad[i].shape_[0],
in_grad[i].shape_[1], in_grad[i].shape_[2], 1);
grad_in[i] = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
in_grad[i], dshape, s);
}
} else {
grad = mkl_experimental_direct_get<xpu, 4, DType>(out_grad[concat_enum::kOut], s);
for (int i = 0; i < size_; ++i) {
grad_in[i] = mkl_experimental_direct_get<xpu, 4, DType>(in_grad[i], s);
}
}
int need_bwd = 0;
for (size_t n = 0; n < num_concats_; n++) {
need_bwd += req[n];
}
if (!need_bwd) {
return;
}
dnnError_t e;
void *concat_res[dnnResourceNumber];
concat_res[dnnResourceSrc] = bwd_top_diff_->get_converted_prv(grad.dptr_, true,
out_grad[concat_enum::kOut]);
for (size_t i = 0; i < num_concats_; ++i) {
std::shared_ptr<MKLMemHolder> bottom_diff_mem = NULL;
#if MKL_EXPERIMENTAL == 1
bottom_diff_mem = in_grad[i].Mkl_mem_;
#endif
concat_res[dnnResourceMultipleDst + i] = bwd_bottom_diff_[i]->get_output_ptr(
grad_in[i].dptr_, bwd_bottom_diff_[i], bottom_diff_mem);
}
e = dnnExecute<DType>(concatBwd_, concat_res);
CHECK_EQ(e, E_SUCCESS);
}
private:
int size_;
size_t dimension_;
bool init_mkldnn_;
dnnPrimitive_t concatFwd_;
dnnPrimitive_t concatBwd_;
std::shared_ptr<MKLData<DType> > fwd_top_data_;
std::vector< std::shared_ptr<MKLData<DType> > > fwd_bottom_data_;
std::shared_ptr<MKLData<DType> > bwd_top_diff_;
std::vector< std::shared_ptr<MKLData<DType> > > bwd_bottom_diff_;
size_t width_;
size_t height_;
size_t channels_;
size_t num_;
size_t num_concats_;
}; // class MKLConcatOp
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_MKL_MKL_CONCAT_INL_H_