/*******************************************************************************
* Copyright 2016 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* \file mkl_batch_norm-inl.h
* \brief MKL2017 (DNN) implementation of the batch normalization operator
* \author lingyan.guo@intel.com
* zhenlin.luo@intel.com
*
*******************************************************************************/
#ifndef MXNET_OPERATOR_MKL_MKL_BATCH_NORM_INL_H_
#define MXNET_OPERATOR_MKL_MKL_BATCH_NORM_INL_H_
#include <mxnet/storage.h>
#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <map>
#include <vector>
#include <string>
#include <utility>
#include "../operator_common.h"
#include "../mshadow_op.h"
#include "./mkl_util-inl.h"
namespace mxnet {
namespace op {
template<typename xpu, typename DType>
class MKLBatchNormOp : public Operator {
public:
explicit MKLBatchNormOp(BatchNormParam param) {
this->param_ = param;
fwd_top_data = MKLData<DType>::create();
fwd_bottom_data = MKLData<DType>::create();
bwd_top_diff = MKLData<DType>::create();
bwd_bottom_diff = MKLData<DType>::create();
scaleShift_space.dptr = NULL;
scaleShiftDiff_space.dptr = NULL;
}
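// Release the MKL DNN primitives, the user layout, and the scale/shift buffers.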
virtual ~MKLBatchNormOp() {
if (batchNormFwdInference != NULL) dnnDelete<DType>(batchNormFwdInference);
if (batchNormFwdTraining != NULL) dnnDelete<DType>(batchNormFwdTraining);
if (batchNormBwdScaleShift != NULL) dnnDelete<DType>(batchNormBwdScaleShift);
if (layout_usr_ != NULL) dnnLayoutDelete<DType>(layout_usr_);
if (scaleShift_space.dptr)
Storage::Get()->Free(scaleShift_space);
if (scaleShiftDiff_space.dptr)
Storage::Get()->Free(scaleShiftDiff_space);
}
static std::string getName() {
return "MKLBatchNormOp";
}
private:
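// Create the user layouts and the packed scale/shift buffers from the input
// tensor shape. Called lazily from the first Forward() pass.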
void LayerSetUp(const mshadow::Tensor<xpu, 4, DType> &data,
const mshadow::Tensor<xpu, 4, DType> &out) {
eps_ = param_.eps;
size_t dim = 4, sizes[4], strides[4];
channels_ = data.shape_[1];
height_ = data.shape_[2];
width_ = data.shape_[3];
num_ = data.shape_[0];
sizes[0] = width_;
sizes[1] = height_;
sizes[2] = channels_;
sizes[3] = num_;
strides[0] = 1;
strides[1] = sizes[0];
strides[2] = sizes[0] * sizes[1];
strides[3] = sizes[0] * sizes[1] * sizes[2];
// Names are for debugging only
fwd_bottom_data->name = "fwd_bottom_data @ " + getName();
fwd_top_data->name = "fwd_top_data @ " + getName();
bwd_bottom_diff->name = "bwd_bottom_diff @ " + getName();
bwd_top_diff->name = "bwd_top_diff @ " + getName();
dnnError_t e;
e = dnnLayoutCreate<DType>(&layout_usr_, dim, sizes, strides);
CHECK_EQ(e, E_SUCCESS);
fwd_bottom_data->create_user_layout(dim, sizes, strides);
fwd_top_data->create_user_layout(dim, sizes, strides);
bwd_bottom_diff->create_user_layout(dim, sizes, strides);
bwd_top_diff->create_user_layout(dim, sizes, strides);
// Primitives will be allocated during the first fwd pass
batchNormFwdInference = NULL;
batchNormFwdTraining = NULL;
batchNormBwdScaleShift = NULL;
size_t scaleShift_size = channels_ * 2 * sizeof(DType);
scaleShift_space = Storage::Get()->Alloc(scaleShift_size, Context::CPU());
scaleShiftDiff_space = Storage::Get()->Alloc(scaleShift_size, Context::CPU());
DType * scaleShift_buf = reinterpret_cast<DType*>(scaleShift_space.dptr);
// Initialize scale (gamma) to 1 and shift (beta) to 0 (use_weight_bias_).
for (int i = 0; i < channels_; i++) {
scaleShift_buf[i] = 1.0;
scaleShift_buf[channels_ + i] = 0;
}
}
public:
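// Forward pass: run the MKL batch normalization primitive, using batch
// statistics in training mode and the moving statistics in inference mode.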
virtual void Forward(const OpContext &ctx,
const std::vector<TBlob> &in_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &out_data,
const std::vector<TBlob> &aux_states) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK_EQ(in_data.size(), 3);
CHECK_EQ(aux_states.size(), 2);
if (ctx.is_train) {
CHECK_EQ(out_data.size(), 3);
CHECK_EQ(req.size(), 3);
} else {
CHECK_GE(out_data.size(), 1);
CHECK_GE(req.size(), 1);
CHECK_EQ(req[batchnorm::kOut], kWriteTo);
}
Stream<xpu> *s = ctx.get_stream<xpu>();
Tensor<xpu, 4, DType> data;
Tensor<xpu, 4, DType> out;
if (in_data[batchnorm::kData].ndim() == 2) {
Shape<4> dshape = Shape4(in_data[batchnorm::kData].shape_[0],
in_data[batchnorm::kData].shape_[1], 1, 1);
data = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
in_data[batchnorm::kData], dshape, s);
out = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
out_data[batchnorm::kOut], dshape, s);
} else {
data = mkl_experimental_direct_get<xpu, 4, DType>(in_data[batchnorm::kData], s);
out = mkl_experimental_direct_get<xpu, 4, DType>(out_data[batchnorm::kOut], s);
}
// const real_t scale = static_cast<real_t>(in_data[batchnorm::kData].shape_[1]) /
// static_cast<real_t>(in_data[batchnorm::kData].shape_.Size());
Tensor<xpu, 1, DType> slope = in_data[batchnorm::kGamma].get<xpu, 1, DType>(s);
Tensor<xpu, 1, DType> bias = in_data[batchnorm::kBeta].get<xpu, 1, DType>(s);
Tensor<xpu, 1, DType> moving_mean = aux_states[batchnorm::kMovingMean].get<xpu, 1, DType>(s);
Tensor<xpu, 1, DType> moving_var = aux_states[batchnorm::kMovingVar].get<xpu, 1, DType>(s);
if (param_.fix_gamma)
slope = 1.f;
dnnError_t e;
if (!init_mkldnn_) {
LayerSetUp(data, out);
init_mkldnn_ = true;
}
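// With MKL_EXPERIMENTAL, reuse the input blob's private (MKL-layout) buffer
// if one exists, so no conversion to the user layout is needed.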
void* bottom_data = NULL;
#if MKL_EXPERIMENTAL == 1
bottom_data =
reinterpret_cast<void *>(mkl_prv_data<DType>(in_data[batchnorm::kData]));
#endif
int bwd_flags = dnnUseScaleShift;
if (param_.use_global_stats)
bwd_flags = dnnUseScaleShift | dnnUseInputMeanVariance;
#if MKL_EXPERIMENTAL == 1
if (NULL != bottom_data) {
// Is it the first pass? Create a primitive.
if (batchNormFwdInference == NULL) {
std::shared_ptr<MKLMemHolder> bottom_data_mem = in_data[batchnorm::kData].Mkl_mem_;
std::shared_ptr<PrvMemDescr> bottom_prv_desc = bottom_data_mem->get_prv_descriptor();
CHECK(bottom_prv_desc->get_descr_type() == PrvMemDescr::PRV_DESCR_MKL2017);
std::shared_ptr<MKLData<DType> > mem_descr
= std::static_pointer_cast<MKLData<DType>>(bottom_prv_desc);
CHECK(mem_descr != NULL);
fwd_bottom_data = mem_descr;
e = dnnBatchNormalizationCreateForward_v2<DType>(
&batchNormFwdInference, NULL, mem_descr->layout_int, eps_,
dnnUseInputMeanVariance | dnnUseScaleShift);
CHECK_EQ(e, E_SUCCESS);
e = dnnBatchNormalizationCreateForward_v2<DType>(
&batchNormFwdTraining, NULL, mem_descr->layout_int, eps_,
dnnUseScaleShift);
CHECK_EQ(e, E_SUCCESS);
fwd_top_data->create_internal_layout(batchNormFwdInference, dnnResourceDst);
bwd_top_diff->create_internal_layout(batchNormFwdInference, dnnResourceDst);
bwd_bottom_diff->create_internal_layout(batchNormFwdInference, dnnResourceSrc);
e = dnnBatchNormalizationCreateBackward_v2<DType>(
&batchNormBwdScaleShift, NULL, mem_descr->layout_int, eps_, bwd_flags);
CHECK_EQ(e, E_SUCCESS);
}
}
#endif
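// No private buffer available: build the primitives on the plain user layout
// and feed the input tensor directly.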
if (NULL == bottom_data) {
if (batchNormFwdInference == NULL) {
e = dnnBatchNormalizationCreateForward_v2<DType>(
&batchNormFwdInference, NULL, layout_usr_, eps_,
dnnUseInputMeanVariance | dnnUseScaleShift);
CHECK_EQ(e, E_SUCCESS);
e = dnnBatchNormalizationCreateForward_v2<DType>(
&batchNormFwdTraining, NULL, layout_usr_, eps_, dnnUseScaleShift);
CHECK_EQ(e, E_SUCCESS);
e = dnnBatchNormalizationCreateBackward_v2<DType>(
&batchNormBwdScaleShift, NULL, layout_usr_, eps_, bwd_flags);
CHECK_EQ(e, E_SUCCESS);
}
bottom_data = reinterpret_cast<void *>(data.dptr_);
}
DType * scaleShift_buf = reinterpret_cast<DType*>(scaleShift_space.dptr);
// Pack gamma (scale) and beta (shift) contiguously, as the primitive expects.
for (int i = 0; i < channels_; i++) {
scaleShift_buf[i] = (slope.dptr_)[i];
}
for (int i = 0; i < channels_; i++) {
scaleShift_buf[channels_ + i] = (bias.dptr_)[i];
}
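// Bind the resources consumed by the primitive: input, packed scale/shift,
// and the (possibly MKL-layout) output buffer.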
void* BatchNorm_res[dnnResourceNumber];
BatchNorm_res[dnnResourceSrc] = bottom_data;
BatchNorm_res[dnnResourceScaleShift] = scaleShift_space.dptr;
BatchNorm_res[dnnResourceDst] = fwd_top_data->get_output_ptr(out.dptr_,
fwd_top_data, out_data[batchnorm::kOut]);
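// Training computes batch mean/variance into kMean/kVar; inference (or
// use_global_stats) reads the moving statistics instead.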
if (ctx.is_train && !param_.use_global_stats) {
Tensor<xpu, 1, DType> mean = out_data[batchnorm::kMean].get<xpu, 1, DType>(s);
Tensor<xpu, 1, DType> var = out_data[batchnorm::kVar].get<xpu, 1, DType>(s);
CHECK(req[batchnorm::kMean] == kNullOp || req[batchnorm::kMean] == kWriteTo);
CHECK(req[batchnorm::kVar] == kNullOp || req[batchnorm::kVar] == kWriteTo);
BatchNorm_res[dnnResourceMean] = mean.dptr_;
BatchNorm_res[dnnResourceVariance] = var.dptr_;
e = dnnExecute<DType>(batchNormFwdTraining, BatchNorm_res);
CHECK_EQ(e, E_SUCCESS);
} else {
BatchNorm_res[dnnResourceMean] = moving_mean.dptr_;
BatchNorm_res[dnnResourceVariance] = moving_var.dptr_;
e = dnnExecute<DType>(batchNormFwdInference, BatchNorm_res);
CHECK_EQ(e, E_SUCCESS);
}
#if MKL_EXPERIMENTAL == 0
if (fwd_top_data->conversion_needed()) {
fwd_top_data->convert_from_prv(out.dptr_);
}
#endif
}
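// Backward pass: compute input, gamma, and beta gradients with the MKL
// scale-shift backward primitive, and update the moving statistics.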
virtual void Backward(const OpContext &ctx,
const std::vector<TBlob> &out_grad,
const std::vector<TBlob> &in_data,
const std::vector<TBlob> &out_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &in_grad,
const std::vector<TBlob> &aux_states) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK_EQ(out_grad.size(), 1);
CHECK_EQ(in_data.size(), 3);
CHECK_EQ(out_data.size(), 3);
CHECK_EQ(in_grad.size(), 3);
Stream<xpu> *s = ctx.get_stream<xpu>();
Tensor<xpu, 4, DType> data, grad, grad_in;
if (in_data[batchnorm::kData].ndim() == 2) {
Shape<4> dshape = Shape4(out_grad[batchnorm::kOut].shape_[0],
out_grad[batchnorm::kOut].shape_[1], 1, 1);
data = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
in_data[batchnorm::kData], dshape, s);
grad = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
out_grad[batchnorm::kOut], dshape, s);
grad_in = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
in_grad[batchnorm::kData], dshape, s);
} else {
data = mkl_experimental_direct_get<xpu, 4, DType>(in_data[batchnorm::kData], s);
grad = mkl_experimental_direct_get<xpu, 4, DType>(out_grad[batchnorm::kOut], s);
grad_in = mkl_experimental_direct_get<xpu, 4, DType>(in_grad[batchnorm::kData], s);
}
Tensor<xpu, 1, DType> slope = in_data[batchnorm::kGamma].get<xpu, 1, DType>(s);
Tensor<xpu, 1, DType> gslope = in_grad[batchnorm::kGamma].get<xpu, 1, DType>(s);
Tensor<xpu, 1, DType> gbias = in_grad[batchnorm::kBeta].get<xpu, 1, DType>(s);
Tensor<xpu, 1, DType> mean = out_data[batchnorm::kMean].get<xpu, 1, DType>(s);
Tensor<xpu, 1, DType> var = out_data[batchnorm::kVar].get<xpu, 1, DType>(s);
Tensor<xpu, 1, DType> moving_mean = aux_states[batchnorm::kMovingMean].get<xpu, 1, DType>(s);
Tensor<xpu, 1, DType> moving_var = aux_states[batchnorm::kMovingVar].get<xpu, 1, DType>(s);
if (param_.fix_gamma) slope = 1.f;
void* bottom_data = NULL;
#if MKL_EXPERIMENTAL == 1
bottom_data = reinterpret_cast<void *>(mkl_prv_data<DType>(in_data[batchnorm::kData]));
#endif
if (NULL == bottom_data)
bottom_data = reinterpret_cast<void *>(data.dptr_);
dnnError_t e;
void* BatchNorm_res[dnnResourceNumber];
BatchNorm_res[dnnResourceSrc] = bottom_data;
BatchNorm_res[dnnResourceScaleShift] = scaleShift_space.dptr;
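// In training mode, fold the batch statistics computed in Forward into the
// moving mean/variance and pass the batch statistics to the backward primitive.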
if (ctx.is_train && !param_.use_global_stats) {
int size = mean.size(0);
DType * moving_mean_ptr = moving_mean.dptr_;
DType * mean_ptr = mean.dptr_;
DType * moving_var_ptr = moving_var.dptr_;
DType * var_ptr = var.dptr_;
DType minus_mom = (1 - param_.momentum);
for (int i = 0; i < size; i++) {
moving_mean_ptr[i] = moving_mean_ptr[i] * param_.momentum
+ mean_ptr[i] * minus_mom;
moving_var_ptr[i] = moving_var_ptr[i] * param_.momentum
+ var_ptr[i] * minus_mom;
}
BatchNorm_res[dnnResourceMean] = mean.dptr_;
BatchNorm_res[dnnResourceVariance] = var.dptr_;
} else {
BatchNorm_res[dnnResourceMean] = moving_mean.dptr_;
BatchNorm_res[dnnResourceVariance] = moving_var.dptr_;
}
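// Convert the output gradient into the primitive's layout if required and
// obtain the destination buffer for the input gradient.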
BatchNorm_res[dnnResourceDiffSrc] = bwd_bottom_diff->get_output_ptr(grad_in.dptr_,
bwd_bottom_diff, in_grad[batchnorm::kData]);
BatchNorm_res[dnnResourceDiffDst] = bwd_top_diff->get_converted_prv(grad.dptr_,
true, out_grad[batchnorm::kOut]);
BatchNorm_res[dnnResourceDiffScaleShift] = scaleShiftDiff_space.dptr;
e = dnnExecute<DType>(batchNormBwdScaleShift, BatchNorm_res);
CHECK_EQ(e, E_SUCCESS);
#if MKL_EXPERIMENTAL == 0
if (bwd_bottom_diff->conversion_needed()) {
bwd_bottom_diff->convert_from_prv(grad_in.dptr_);
}
#endif
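// Unpack the gamma/beta gradients from the contiguous diff scale/shift buffer.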
DType * scaleShiftDiff_buf = reinterpret_cast<DType*>(scaleShiftDiff_space.dptr);
if (!param_.fix_gamma) {
// Store ScaleShift blobs
DType* diff_scale = gslope.dptr_;
for (int i = 0; i < channels_; i++) {
diff_scale[i] = scaleShiftDiff_buf[i];
}
} else {
int gslope_size = gslope.size(0);
DType * gslope_ptr = gslope.dptr_;
for (int i = 0; i < gslope_size; i++) {
gslope_ptr[i] = 0;
}
}
DType* diff_shift = gbias.dptr_;
for (int i = 0; i < channels_; i++) {
diff_shift[i] = scaleShiftDiff_buf[channels_ + i];
}
}
private:
BatchNormParam param_;
DType eps_;
bool use_weight_bias_;
int num_;
int channels_;
int height_;
int width_;
bool init_mkldnn_ = false;
std::shared_ptr<MKLData<DType> > fwd_top_data;
std::shared_ptr<MKLData<DType> > fwd_bottom_data;
std::shared_ptr<MKLData<DType> > bwd_top_diff;
std::shared_ptr<MKLData<DType> > bwd_bottom_diff;
dnnPrimitive_t batchNormFwdInference = NULL;
dnnPrimitive_t batchNormFwdTraining = NULL;
dnnPrimitive_t batchNormBwdScaleShift = NULL;
Storage::Handle scaleShift_space;
Storage::Handle scaleShiftDiff_space;
dnnLayout_t layout_usr_ = NULL;
}; // class MKLBatchNormOp
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_MKL_MKL_BATCH_NORM_INL_H_