/*!
* Copyright (c) 2015 by Contributors
* \file fully_connected-inl.h
* \brief fully connected operator and symbol
*/
#ifndef MXNET_OPERATOR_FULLY_CONNECTED_INL_H_
#define MXNET_OPERATOR_FULLY_CONNECTED_INL_H_
#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <map>
#include <vector>
#include <string>
#include <utility>
#include "./operator_common.h"
#include "./elemwise_op_common.h"
namespace mxnet {
namespace op {
// Declare enumerations of the input/output ordering to make the code more intuitive.
// These enums are only visible within this header.
namespace fullc {
enum FullyConnectedOpInputs {kData, kWeight, kBias};
enum FullyConnectedOpOutputs {kOut};
} // namespace fullc
struct FullyConnectedParam : public dmlc::Parameter<FullyConnectedParam> {
int num_hidden;
bool no_bias;
DMLC_DECLARE_PARAMETER(FullyConnectedParam) {
// TODO(bing) add support for boolean
DMLC_DECLARE_FIELD(num_hidden).set_lower_bound(1)
.describe("Number of hidden nodes of the output.");
DMLC_DECLARE_FIELD(no_bias).set_default(false)
.describe("Whether to disable bias parameter.");
}
};
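// How the op computes (a sketch, not normative documentation): the input is
// flattened to 2-D and the layer computes Y = X * W^T + b, where X has shape
// (batch, num_input), W has shape (num_hidden, num_input), and b has shape
// (num_hidden,). From the Python front end this is typically invoked as
//   fc = mx.sym.FullyConnected(data=data, num_hidden=128, no_bias=False)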
/**
 * \brief This is the implementation of the fully connected operator.
 * \tparam xpu The device that the op will be executed on.
 * \tparam DType The data type the op computes with (e.g. float).
 */
template<typename xpu, typename DType>
class FullyConnectedOp : public Operator {
public:
explicit FullyConnectedOp(FullyConnectedParam p) {
this->param_ = p;
}
virtual void Forward(const OpContext &ctx,
const std::vector<TBlob> &in_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &out_data,
const std::vector<TBlob> &aux_args) {
using namespace mshadow;
using namespace mshadow::expr;
if (req[fullc::kOut] == kNullOp) return;
CHECK_EQ(req[fullc::kOut], kWriteTo);
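// Only kWriteTo is supported here: the GEMM below overwrites `out`, so
// kAddTo would need a separate accumulation step.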
size_t expected = param_.no_bias ? 2 : 3;
CHECK_EQ(in_data.size(), expected);
CHECK_EQ(out_data.size(), 1U);
// TODO(bing): check the BLAS Handle, be careful
// maybe need blas handle from context
// TODO(bing): judge shape to remove flatten op
Stream<xpu> *s = ctx.get_stream<xpu>();
#if defined(__CUDACC__)
CHECK_EQ(s->blas_handle_ownership_, Stream<xpu>::OwnHandle)
<< "Must init CuBLAS handle in stream";
#endif // __CUDACC__
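// Flatten the input to 2-D, (batch, product of all remaining dims), so the
// whole layer reduces to a single GEMM regardless of the input's rank.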
const TShape& ishape = in_data[fullc::kData].shape_;
const TShape& oshape = out_data[fullc::kOut].shape_;
Tensor<xpu, 2, DType> data = in_data[fullc::kData].get_with_shape<xpu, 2, DType>(
Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim())), s);
Tensor<xpu, 2, DType> wmat = in_data[fullc::kWeight].get<xpu, 2, DType>(s);
Tensor<xpu, 2, DType> out = out_data[fullc::kOut].get_with_shape<xpu, 2, DType>(
Shape2(oshape[0], oshape.ProdShape(1, oshape.ndim())), s);
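// Y = X * W^T: (batch, num_input) x (num_input, num_hidden) -> (batch, num_hidden).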
out = dot(data, wmat.T());
if (!param_.no_bias) {
Tensor<xpu, 1, DType> bias = in_data[fullc::kBias].get<xpu, 1, DType>(s);
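// repmat tiles the bias into a (batch, num_hidden) matrix so it can be
// added row-wise to every sample in the batch.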
out += repmat(bias, data.size(0));
}
}
virtual void Backward(const OpContext &ctx,
const std::vector<TBlob> &out_grad,
const std::vector<TBlob> &in_data,
const std::vector<TBlob> &out_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &in_grad,
const std::vector<TBlob> &aux_args) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK_EQ(out_grad.size(), 1U);
size_t expected = param_.no_bias ? 2 : 3;
CHECK(in_data.size() == expected && in_grad.size() == expected);
CHECK_EQ(req.size(), expected);
// TODO(bing): check the BLAS Handle, be careful
// maybe need blas handle from context
Stream<xpu> *s = ctx.get_stream<xpu>();
const TShape& ishape = in_data[fullc::kData].shape_;
const TShape& oshape = out_grad[fullc::kOut].shape_;
Tensor<xpu, 2, DType> data = in_data[fullc::kData].get_with_shape<xpu, 2, DType>(
Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim())), s);
Tensor<xpu, 2, DType> wmat = in_data[fullc::kWeight].get<xpu, 2, DType>(s);
Tensor<xpu, 2, DType> grad = out_grad[fullc::kOut].get_with_shape<xpu, 2, DType>(
Shape2(oshape[0], oshape.ProdShape(1, oshape.ndim())), s);
#if defined(__CUDACC__)
CHECK_EQ(s->blas_handle_ownership_, Stream<xpu>::OwnHandle)
<< "Must init CuBLAS handle in stream";
#endif // __CUDACC__
// backprop
CHECK_NE(req[fullc::kWeight], kWriteInplace) << "cannot write weight inplace";
// gradient of weight
Tensor<xpu, 2, DType> gwmat = in_grad[fullc::kWeight].get<xpu, 2, DType>(s);
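// dL/dW = dY^T * X, which matches W's (num_hidden, num_input) layout.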
Assign(gwmat, req[fullc::kWeight], dot(grad.T(), data));
// gradient of bias
if (!param_.no_bias) {
Tensor<xpu, 1, DType> gbias = in_grad[fullc::kBias].get<xpu, 1, DType>(s);
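// dL/db = sum of dY over the batch (row) dimension.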
Assign(gbias, req[fullc::kBias], sum_rows(grad));
}
// gradient of data
Tensor<xpu, 2, DType> gdata = in_grad[fullc::kData].get_with_shape<xpu, 2, DType>(
Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim())), s);
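// dL/dX = dY * W, written back in the flattened (batch, num_input) shape.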
Assign(gdata, req[fullc::kData], dot(grad, wmat));
}
private:
FullyConnectedParam param_;
}; // class FullyConnectedOp
// Declare factory function, used for dispatch specialization
template<typename xpu>
Operator* CreateOp(FullyConnectedParam param, int dtype,
std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape,
Context ctx);
#if DMLC_USE_CXX11
class FullyConnectedProp : public OperatorProperty {
public:
std::vector<std::string> ListArguments() const override {
if (!param_.no_bias) {
return {"data", "weight", "bias"};
} else {
return {"data", "weight"};
}
}
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
param_.Init(kwargs);
}
std::map<std::string, std::string> GetParams() const override {
return param_.__DICT__();
}
bool InferShape(std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape,
std::vector<TShape> *aux_shape) const override {
using namespace mshadow;
if (!param_.no_bias) {
CHECK_EQ(in_shape->size(), 3U) << "Input:[data, weight, bias]";
} else {
CHECK_EQ(in_shape->size(), 2U) << "Input:[data, weight]";
}
CHECK_EQ(out_shape->size(), 1U);
TShape dshape = (*in_shape)[fullc::kData];
TShape oshape = (*out_shape)[0];
// require data to be known
if (dshape.ndim() == 0) return false;
index_t num_input = dshape.ProdShape(1, dshape.ndim());
SHAPE_ASSIGN_CHECK(*in_shape, fullc::kWeight, Shape2(param_.num_hidden, num_input));
if (!param_.no_bias) {
SHAPE_ASSIGN_CHECK(*in_shape, fullc::kBias, Shape1(param_.num_hidden));
}
SHAPE_ASSIGN_CHECK(*out_shape, 0, Shape2(dshape[0], param_.num_hidden));
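// If the output shape is already known, propagate its batch dimension back
// into the data shape: shape inference here runs in both directions.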
if (oshape.ndim() != 0) {
dshape[0] = oshape[0];
SHAPE_ASSIGN_CHECK(*in_shape, fullc::kData, dshape);
}
return true;
}
bool InferType(std::vector<int> *in_type,
std::vector<int> *out_type,
std::vector<int> *aux_type) const override {
CHECK_GE(in_type->size(), 1U);
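// All arguments and the output must share one dtype; ElemwiseAttr propagates
// any known dtype to the unknown ones, with -1 as the "unknown" sentinel.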
nnvm::NodeAttrs attrs;
attrs.name = "FullyConnected";
return ElemwiseAttr<int, type_is_none, type_assign, true, type_string>(
attrs, in_type, out_type, -1);
}
OperatorProperty* Copy() const override {
FullyConnectedProp* fc_sym = new FullyConnectedProp();
fc_sym->param_ = this->param_;
return fc_sym;
}
std::string TypeString() const override {
return "FullyConnected";
}
// declare dependency and inplace optimization options
std::vector<int> DeclareBackwardDependency(
const std::vector<int> &out_grad,
const std::vector<int> &in_data,
const std::vector<int> &out_data) const override {
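// Backward needs dY, the data (for dW), and the weight (for dX); neither
// the forward output nor the bias is required, so the engine need not
// retain them for the backward pass.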
return {out_grad[fullc::kOut], in_data[fullc::kData], in_data[fullc::kWeight]};
}
std::vector<std::pair<int, void*> > BackwardInplaceOption(
const std::vector<int> &out_grad,
const std::vector<int> &in_data,
const std::vector<int> &out_data,
const std::vector<void*> &in_grad) const override {
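// The data gradient may share memory with the forward data blob: Backward
// reads `data` (for dW) before it writes `gdata`, so the reuse is safe.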
return {{in_data[fullc::kData], in_grad[fullc::kData]}};
}
Operator* CreateOperator(Context ctx) const override {
LOG(FATAL) << "Not Implemented.";
return NULL;
}
Operator* CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
std::vector<int> *in_type) const override;
private:
FullyConnectedParam param_;
}; // class FullyConnectedProp
#endif // DMLC_USE_CXX11
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_FULLY_CONNECTED_INL_H_