blob: 5e296704b6dda23760031be63b7037c6ac7d886f [file] [log] [blame]
/*******************************************************************************
* Copyright 2016 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* \file mkl_fully_connected-inl.h
* \brief
* \author zhenlin.luo@intel.com
* lingyan.guo@intel.com
*
*
*******************************************************************************/
#ifndef MXNET_OPERATOR_MKL_MKL_FULLY_CONNECTED_INL_H_
#define MXNET_OPERATOR_MKL_MKL_FULLY_CONNECTED_INL_H_
#include <string>
#include <algorithm>
#include <vector>
#include "../activation-inl.h"
#include "./mkl_util-inl.h"
namespace mxnet {
namespace op {
template<typename xpu, typename DType>
// Fully-connected (inner product) operator backed by the Intel MKL 2017
// DNN primitive API.  Creates the forward / backward-data / backward-filter
// (and, unless no_bias, backward-bias) primitives once in LayerSetUp and
// executes them in Forward/Backward.
class MKLFullyConnectedOp : public Operator {
 public:
  // p: operator hyper-parameters (num_hidden, no_bias).
  // in_shapes/out_shapes: shapes inferred at bind time; only
  // in_shapes[fullc::kData] is consulted when building the primitives.
  explicit MKLFullyConnectedOp(const FullyConnectedParam& p,
                               const std::vector<TShape>& in_shapes,
                               const std::vector<TShape>& out_shapes):
      param_(p) {
    LayerSetUp(in_shapes, out_shapes);
  }

  ~MKLFullyConnectedOp() {
    // fullyConnectedBwdBias stays nullptr when no_bias is set;
    // NOTE(review): assumes dnnDelete<DType> tolerates a null handle —
    // verify against the wrapper in mkl_util-inl.h.
    dnnDelete<DType>(fullyConnectedFwd);
    dnnDelete<DType>(fullyConnectedBwdData);
    dnnDelete<DType>(fullyConnectedBwdFilter);
    dnnDelete<DType>(fullyConnectedBwdBias);
  }

  static std::string getName() {
    return "MKLFullyConnectedOp";
  }

 private:
  // Builds all MKL inner-product primitives from the bound input shape.
  // Data is presented to MKL as a 4-D blob {1, 1, feature_dim, batch}.
  void LayerSetUp(const std::vector<TShape>& in_shapes,
                  const std::vector<TShape>& out_shapes) {
    const TShape& ishape = in_shapes[fullc::kData];
    const size_t dim = 4;
    // {W, H, C, N} layout expected by the MKL DNN C API: flatten every
    // non-batch axis into one channel dimension.
    const size_t src_sizes[4] = {1, 1, ishape.ProdShape(1, ishape.ndim()), ishape[0]};
    const size_t dst_sizes[2] = {param_.num_hidden, ishape[0]};
    const size_t output_channels = param_.num_hidden;
    dnnPrimitiveAttributes_t attributes = NULL;
    MKLDNN_CALL(dnnPrimitiveAttributesCreate<DType>(&attributes));
    if (!param_.no_bias) {
      MKLDNN_CALL(dnnInnerProductCreateForwardBias<DType>(
            &fullyConnectedFwd,
            attributes,
            dim,
            src_sizes,
            output_channels));
    } else {
      MKLDNN_CALL(dnnInnerProductCreateForward<DType>(
            &fullyConnectedFwd,
            attributes,
            dim,
            src_sizes,
            output_channels));
    }
    MKLDNN_CALL(dnnInnerProductCreateBackwardData<DType>(
          &fullyConnectedBwdData,
          attributes,
          dim,
          src_sizes,
          output_channels));
    MKLDNN_CALL(dnnInnerProductCreateBackwardFilter<DType>(
          &fullyConnectedBwdFilter,
          attributes,
          dim,
          src_sizes,
          output_channels));
    if (!param_.no_bias) {
      MKLDNN_CALL(dnnInnerProductCreateBackwardBias<DType>(
            &fullyConnectedBwdBias,
            attributes,
            2,
            dst_sizes));
    }
    // The attribute handle is only needed while creating primitives;
    // destroy it here to fix the leak flagged by the old TODO.
    MKLDNN_CALL(dnnPrimitiveAttributesDestroy<DType>(attributes));
  }

  // Runs the forward inner product: out = data * weight^T (+ bias).
  // Only kWriteTo is supported for the output request.
  virtual void Forward(const OpContext &ctx,
                       const std::vector<TBlob> &in_data,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &out_data,
                       const std::vector<TBlob> &aux_args) {
    using namespace mshadow;
    using namespace mshadow::expr;
    void* res_fullyConnected[dnnResourceNumber];
    if (req[fullc::kOut] == kNullOp) return;
    CHECK_EQ(req[fullc::kOut], kWriteTo);
    CHECK_EQ(in_data.size(), param_.no_bias ? 2 : 3);
    CHECK_EQ(out_data.size(), 1);
    Stream<xpu> *s = ctx.get_stream<xpu>();
    const TShape& ishape = in_data[fullc::kData].shape_;
    const TShape& oshape = out_data[fullc::kOut].shape_;
    // The tensors below are not read afterwards; get_with_shape is kept
    // purely for its shape-compatibility CHECKs on the input/output blobs.
    Shape<4> dshape = Shape4(ishape[0], ishape.ProdShape(1, ishape.ndim()), 1, 1);
    Shape<4> odshape = Shape4(oshape[0], oshape.ProdShape(1, oshape.ndim()), 1, 1);
    Tensor<xpu, 4, DType> data =
        in_data[fullc::kData].get_with_shape<xpu, 4, DType>(dshape, s);
    Tensor<xpu, 4, DType> out =
        out_data[fullc::kOut].get_with_shape<xpu, 4, DType>(odshape, s);
    // Execute on raw pointers; MKL consumes the flattened layout set up
    // in LayerSetUp.
    res_fullyConnected[dnnResourceSrc] =
      reinterpret_cast<void *>(in_data[fullc::kData].dptr_);
    res_fullyConnected[dnnResourceDst] =
      reinterpret_cast<void *>(out_data[fullc::kOut].dptr_);
    res_fullyConnected[dnnResourceFilter] =
      reinterpret_cast<void *>(in_data[fullc::kWeight].dptr_);
    if (!param_.no_bias) {
      res_fullyConnected[dnnResourceBias] = reinterpret_cast<void *>(in_data[fullc::kBias].dptr_);
    }
    MKLDNN_CALL(dnnExecute<DType>(fullyConnectedFwd, res_fullyConnected));
  }

  // Computes all requested gradients: d(data), d(weight), and d(bias)
  // unless no_bias.  All three primitives share one resource table.
  virtual void Backward(const OpContext &ctx,
                        const std::vector<TBlob> &out_grad,
                        const std::vector<TBlob> &in_data,
                        const std::vector<TBlob> &out_data,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &in_grad,
                        const std::vector<TBlob> &aux_args) {
    using namespace mshadow;
    using namespace mshadow::expr;
    void* res_fullyConnected[dnnResourceNumber];
    CHECK_EQ(out_grad.size(), 1);
    const size_t expected = param_.no_bias ? 2 : 3;
    CHECK(in_data.size() == expected && in_grad.size() == expected);
    CHECK_EQ(req.size(), expected);
    res_fullyConnected[dnnResourceSrc] =
      reinterpret_cast<void *>(in_data[fullc::kData].dptr_);
    res_fullyConnected[dnnResourceFilter] =
      reinterpret_cast<void *>(in_data[fullc::kWeight].dptr_);
    res_fullyConnected[dnnResourceDiffDst] =
      reinterpret_cast<void *>(out_grad[fullc::kOut].dptr_);
    res_fullyConnected[dnnResourceDiffSrc] =
      reinterpret_cast<void *>(in_grad[fullc::kData].dptr_);
    res_fullyConnected[dnnResourceDiffFilter] =
      reinterpret_cast<void *>(in_grad[fullc::kWeight].dptr_);
    if (!param_.no_bias) {
      res_fullyConnected[dnnResourceDiffBias] =
        reinterpret_cast<void *>(in_grad[fullc::kBias].dptr_);
    }
    MKLDNN_CALL(dnnExecute<DType>(fullyConnectedBwdFilter, res_fullyConnected));
    if (!param_.no_bias) {
      MKLDNN_CALL(dnnExecute<DType>(fullyConnectedBwdBias, res_fullyConnected));
    }
    MKLDNN_CALL(dnnExecute<DType>(fullyConnectedBwdData, res_fullyConnected));
  }

 private:
  // MKL primitive handles; created in LayerSetUp, released in the dtor.
  dnnPrimitive_t fullyConnectedFwd{nullptr};
  dnnPrimitive_t fullyConnectedBwdData{nullptr};
  dnnPrimitive_t fullyConnectedBwdFilter{nullptr};
  dnnPrimitive_t fullyConnectedBwdBias{nullptr};
  const FullyConnectedParam param_;
};  // class MKLFullyConnectedOp
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_MKL_MKL_FULLY_CONNECTED_INL_H_