/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file operator_common.h
* \brief common internal header of most operators
* this header includes utility functions that operators can use
* \author Bing Xu
*/
#ifndef MXNET_OPERATOR_OPERATOR_COMMON_H_
#define MXNET_OPERATOR_OPERATOR_COMMON_H_
#include <dmlc/json.h>
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <mxnet/operator.h>
#include <mxnet/ndarray.h>
#include <mxnet/op_attr_types.h>
#include <mxnet/base.h>
#include <istream>
#include <ostream>
#include <string>
#include <vector>
#include <algorithm>
#include <functional>
#include <sstream>
#include <unordered_map>
#include <utility>
#include "../common/cuda/utils.h"
#include "../common/utils.h"
namespace mxnet {
namespace op {
/*!
* \brief assign the expression to out according to request
* \param out the data to be assigned
* \param req the assignment request
* \param exp the expression
* \tparam OType output type
* \tparam Exp expression type
*/
#define Assign(out, req, exp) \
{ \
switch (req) { \
case kNullOp: \
break; \
case kWriteTo: \
case kWriteInplace: \
(out) = (exp); \
break; \
case kAddTo: \
(out) += (exp); \
break; \
default: \
LOG(FATAL) << "not reached"; \
} \
}
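/*
 * Example (illustrative): inside a hypothetical FCompute kernel body, given a
 * stream `s`, the input/output TBlobs, and `req` from the FCompute signature,
 * Assign writes, accumulates, or skips the result depending on the request:
 *
 *   mshadow::Tensor<cpu, 2> out = outputs[0].FlatTo2D<cpu, float>(s);
 *   mshadow::Tensor<cpu, 2> in  = inputs[0].FlatTo2D<cpu, float>(s);
 *   Assign(out, req[0], in * 2.0f);  // kWriteTo/kWriteInplace: =, kAddTo: +=, kNullOp: skip
 */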
/*! \brief exception thrown on InferShape error */
struct InferShapeError : public dmlc::Error {
/*! \brief error message */
std::string msg;
/*! \brief corresponding input index */
int index;
// constructor
InferShapeError(const std::string& msg_, int index)
: dmlc::Error(msg_), msg(msg_), index(index) {}
};
/*! \brief exception thrown on InferType error */
struct InferTypeError : public dmlc::Error {
/*! \brief error message */
std::string msg;
/*! \brief corresponding input index */
int index;
// constructor
InferTypeError(const std::string& msg_, int index) : dmlc::Error(msg_), msg(msg_), index(index) {}
};
/*! \brief exception thrown on InferStorageType error */
struct InferStorageTypeError : public dmlc::Error {
/*! \brief error message */
std::string msg;
/*! \brief corresponding input index */
int index;
// constructor
InferStorageTypeError(const std::string& msg_, int index)
: dmlc::Error(msg_), msg(msg_), index(index) {}
};
/*! \brief check if shape is empty or contains unknown (0) dim.
* DEPRECATED. */
inline bool shape_is_none(const mxnet::TShape& x) {
return !mxnet::shape_is_known(x);
}
/*! \brief check if type is none (-1) */
inline bool type_is_none(const int& x) {
return x == -1;
}
/*! \brief check if storage type is none (-1) */
inline bool storage_type_is_none(const int& x) {
return x == -1;
}
/*! \brief check if shape is a scalar (ndim == 0). */
inline bool shape_is_scalar(const mxnet::TShape& x) {
return x.ndim() == 0;
}
/*! \brief get string representation of shape */
inline std::string shape_string(const mxnet::TShape& x) {
std::ostringstream os;
os << x;
return os.str();
}
/*! \brief get string representation of data type */
inline std::string type_string(const int& x) {
switch (x) {
case mshadow::kFloat32:
return "float32";
case mshadow::kFloat64:
return "float64";
case mshadow::kFloat16:
return "float16";
case mshadow::kBfloat16:
return "bfloat16";
case mshadow::kInt8:
return "int8";
case mshadow::kUint8:
return "uint8";
case mshadow::kInt32:
return "int32";
case mshadow::kInt64:
return "int64";
}
return "unknown";
}
/*!
* \brief Assign x to y. Checks for compatibility when y is not empty.
*        Allows missing dims in both x and y (treated as unknown).
* \param y target shape.
* \param x source shape.
* \return whether x and y are compatible.
*/
inline bool shape_assign(mxnet::TShape* y, const mxnet::TShape& x) {
if (!mxnet::ndim_is_known(*y)) {
*y = x;
return true;
} else if (y->ndim() != x.ndim()) {
return !mxnet::ndim_is_known(x);
} else {
for (int i = 0; i < y->ndim(); ++i) {
if (!mxnet::dim_size_is_known(*y, i)) {
(*y)[i] = x[i];
} else if ((*y)[i] != x[i] && x[i] >= 0) {
return false;
}
}
return true;
}
}
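/*
 * Example (illustrative): merging a newly inferred shape into an entry of the
 * output-shape vector of a hypothetical InferShape function:
 *
 *   mxnet::TShape inferred({batch_size, num_hidden});  // hypothetical dims
 *   if (!shape_assign(&(*out_shape)[0], inferred)) {
 *     // the existing entry and `inferred` are incompatible; report an error
 *   }
 */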
/*!
* \brief Assign x to y. Checks for compatibility when y is not -1.
* \param y target type.
* \param x source type.
* \return whether x and y are compatible.
*/
inline bool type_assign(int* y, const int& x) {
if (*y == -1) {
*y = x;
return true;
} else if (*y != x && x != -1) {
return false;
}
return true;
}
/*!
* \brief Assign x to y. Checks for compatibility when y is not DispatchMode::kUndefined.
* \param y target mode.
* \param x source mode.
* \return whether x and y are compatible.
*/
inline bool dispatch_mode_assign(DispatchMode* y, const DispatchMode& x) {
if (*y == DispatchMode::kUndefined) {
*y = x;
return true;
} else if (*y != x && x != DispatchMode::kUndefined) {
return false;
}
return true;
}
/*! \brief Register op name as an alias */
#define MXNET_ADD_SPARSE_OP_ALIAS(__name$) .add_alias("_sparse_" #__name$)
/*!
* \brief macro to assign shape to out if out is unknown, otherwise check consistency
* Implemented as a macro so the reported error location is clearer
* \param shape_array the shape array to store the result
* \param index the index of the entry in the array
* \param shape the inferred shape
*/
#define SHAPE_ASSIGN_CHECK(shape_array, index, shape) \
{ \
if (!::mxnet::op::shape_assign(&(shape_array)[index], mxnet::TShape(shape))) { \
std::ostringstream os; \
os << "Shape inconsistent, Provided = " << (shape_array)[index] << ',' \
<< " inferred shape=" << shape; \
throw ::mxnet::op::InferShapeError(os.str(), index); \
} \
}
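/*
 * Example (illustrative): a minimal shape-inference function for a hypothetical
 * elementwise operator, using SHAPE_ASSIGN_CHECK in both directions:
 *
 *   bool ExampleOpShape(const nnvm::NodeAttrs& attrs,
 *                       mxnet::ShapeVector* in_shape,
 *                       mxnet::ShapeVector* out_shape) {
 *     CHECK_EQ(in_shape->size(), 1U);
 *     CHECK_EQ(out_shape->size(), 1U);
 *     SHAPE_ASSIGN_CHECK(*out_shape, 0, (*in_shape)[0]);
 *     SHAPE_ASSIGN_CHECK(*in_shape, 0, (*out_shape)[0]);
 *     return shape_is_known((*out_shape)[0]);
 *   }
 */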
/*!
* \brief macro to assign type to out if out is unknown (-1), otherwise check consistency
* Implemented as a macro so the reported error location is clearer
* \param type_array the type array to store the result
* \param index the index of the entry in the array
* \param type the inferred type
*/
#define TYPE_ASSIGN_CHECK(type_array, index, type) \
{ \
if (!::mxnet::op::type_assign(&(type_array)[index], type)) { \
std::ostringstream os; \
os << "Type inconsistent, Provided = " << ::mxnet::op::type_string((type_array)[index]) \
<< ',' << " inferred type = " << ::mxnet::op::type_string(type); \
throw ::mxnet::op::InferTypeError(os.str(), index); \
} \
}
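/*
 * Example (illustrative): a minimal type-inference function for a hypothetical
 * one-input, one-output operator:
 *
 *   bool ExampleOpType(const nnvm::NodeAttrs& attrs,
 *                      std::vector<int>* in_type,
 *                      std::vector<int>* out_type) {
 *     TYPE_ASSIGN_CHECK(*out_type, 0, (*in_type)[0]);
 *     TYPE_ASSIGN_CHECK(*in_type, 0, (*out_type)[0]);
 *     return !type_is_none((*out_type)[0]);
 *   }
 */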
/*!
* \brief macro to assign storage type to out if out is unknown (-1), otherwise check consistency
* Implemented as a macro so the reported error location is clearer
* \param type_array the storage type array to store the result
* \param index the index of the entry in the array
* \param type the inferred storage type
*/
#define STORAGE_TYPE_ASSIGN_CHECK(type_array, index, type) \
{ \
if (!::mxnet::op::type_assign(&(type_array)[index], type)) { \
std::ostringstream os; \
os << "Storage type inconsistent, Provided = " << common::stype_string((type_array)[index]) \
<< ',' << " inferred storage type = " << common::stype_string(type); \
throw ::mxnet::op::InferStorageTypeError(os.str(), index); \
} \
}
/*!
* \brief macro to assign dispatch mode to out if out is unknown (DispatchMode::kUndefined),
*        otherwise check consistency
* Implemented as a macro so the reported error location is clearer
* \param type_array the dispatch mode array to store the result
* \param index the index of the entry in the array
* \param type the inferred dispatch mode
*/
#define DISPATCH_MODE_ASSIGN_CHECK(type_array, index, type) \
{ \
if (!::mxnet::op::dispatch_mode_assign(&(type_array)[index], type)) { \
std::ostringstream os; \
os << "Dispatch mode inconsistent, Provided = " \
<< common::dispatch_mode_string((type_array)[index]) << ',' \
<< " inferred mode = " << common::dispatch_mode_string(type); \
throw ::mxnet::op::InferStorageTypeError(os.str(), index); \
} \
}
/*!
* \brief macro to check if type is the same as expected.
* \param type the type to be checked
* \param expected the expected type
* \param arg the name of the argument being checked, used in the error message
*/
#define UNIFORM_TYPE_CHECK(type, expected, arg) \
{ \
CHECK_EQ(type, expected) << "This layer requires uniform type. " \
<< "Expected '" << ::mxnet::op::type_string(expected) \
<< "' v.s. given '" << ::mxnet::op::type_string(type) << "' at '" \
<< arg << "'"; \
}
// helper macro to implement bind dispatch
#if MXNET_USE_CUDA
#define DO_BIND_DISPATCH(Method, ...) \
if (ctx.dev_mask() == cpu::kDevMask) { \
return Method<cpu>(__VA_ARGS__); \
} else { \
return Method<gpu>(__VA_ARGS__); \
}
#else
#define DO_BIND_DISPATCH(Method, ...) \
if (ctx.dev_mask() == cpu::kDevMask) { \
return Method<cpu>(__VA_ARGS__); \
} else { \
LOG(FATAL) << "GPU is not enabled"; \
return nullptr; \
}
#endif
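/*
 * Example (illustrative): a legacy OperatorProperty-style factory using
 * DO_BIND_DISPATCH to pick the cpu or gpu instantiation of a hypothetical CreateOp:
 *
 *   Operator* CreateOperatorEx(Context ctx,
 *                              mxnet::ShapeVector* in_shape,
 *                              std::vector<int>* in_type) const override {
 *     DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]);
 *   }
 */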
/*! \brief assign stype to target_stype, if successful,
* assign dispatch_mode to target_dispatch
*/
inline bool storage_type_assign(int* stype,
const NDArrayStorageType target_stype,
DispatchMode* dispatch,
const DispatchMode target_dispatch) {
if (type_assign(stype, target_stype)) {
DISPATCH_MODE_ASSIGN_CHECK(dispatch, 0, target_dispatch);
return true;
}
return false;
}
/*! \brief assign the stype vector to target_stype, if successful,
* assign dispatch_mode to target_dispatch
*/
inline bool storage_type_assign(StorageTypeVector* stypes,
const NDArrayStorageType target_stype,
DispatchMode* dispatch,
const DispatchMode target_dispatch) {
CHECK_GT(stypes->size(), 0);
bool success = true;
for (int& stype : *stypes) {
if (!type_assign(&stype, target_stype)) {
success = false;
}
}
if (success) {
DISPATCH_MODE_ASSIGN_CHECK(dispatch, 0, target_dispatch);
}
return success;
}
/*! \brief update the stype vector to default storage and dispatch_mode to fallback
*/
inline bool dispatch_fallback(StorageTypeVector* stypes, DispatchMode* dispatch) {
for (auto& stype : *stypes) {
type_assign(&stype, kDefaultStorage);
}
DISPATCH_MODE_ASSIGN_CHECK(dispatch, 0, DispatchMode::kFComputeFallback);
return true;
}
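/*
 * Example (illustrative): a storage-type inference function for a hypothetical
 * operator that only has a dense implementation, so sparse inputs fall back:
 *
 *   bool ExampleStorageType(const nnvm::NodeAttrs& attrs,
 *                           const int dev_mask,
 *                           DispatchMode* dispatch_mode,
 *                           std::vector<int>* in_attrs,
 *                           std::vector<int>* out_attrs) {
 *     bool dispatched = false;
 *     if (common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
 *       dispatched = storage_type_assign(out_attrs, kDefaultStorage,
 *                                        dispatch_mode, DispatchMode::kFCompute);
 *     }
 *     if (!dispatched) {
 *       dispatched = dispatch_fallback(out_attrs, dispatch_mode);
 *     }
 *     return dispatched;
 *   }
 */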
inline std::vector<nnvm::NodeEntry> CreateNodeEntries(
nnvm::ObjectPtr pNode,
const std::vector<nnvm::NodeEntry>* pOgrads = nullptr,
const std::vector<nnvm::NodeEntry>* pInputs = nullptr) {
if (pOgrads)
pNode->inputs.insert(pNode->inputs.end(), pOgrads->begin(), pOgrads->end());
if (pInputs)
pNode->inputs.insert(pNode->inputs.end(), pInputs->begin(), pInputs->end());
if (!pNode->is_variable()) {
CHECK_EQ(pNode->num_inputs(), pNode->inputs.size())
<< "Number of inputs to operator " << pNode->op()->name << " (" << pNode->num_inputs()
<< ") does not match the actual number of inputs provided to operator " << pNode->attrs.name
<< " (" << pNode->inputs.size() << ").";
}
std::vector<nnvm::NodeEntry> ret;
for (uint32_t i = 0; i < pNode->num_outputs(); ++i)
ret.emplace_back(nnvm::NodeEntry{pNode, i, 0});
return ret;
}
// make a new node with operator op_name. Inputs are not filled.
inline nnvm::ObjectPtr MakeNode(const char* op_name,
const std::string& name,
std::vector<nnvm::NodeEntry> const* inputs = nullptr,
std::unordered_map<std::string, std::string> const* dict = nullptr,
nnvm::ObjectPtr const* fwd_node = nullptr) {
auto p = nnvm::Node::Create();
p->attrs.op = nnvm::Op::Get(op_name);
p->attrs.name = name;
if (dict != nullptr)
p->attrs.dict = *dict;
if (inputs != nullptr)
p->inputs = *inputs;
if (fwd_node != nullptr) {
p->control_deps.emplace_back(*fwd_node);
}
if (p->op()->attr_parser != nullptr) {
p->op()->attr_parser(&(p->attrs));
}
if (inputs != nullptr) {
CHECK_EQ(p->num_inputs(), p->inputs.size())
<< "Number of inputs to operator " << op_name << " (" << p->num_inputs()
<< ") does not match the actual number of inputs provided to operator " << name << " ("
<< p->inputs.size() << ").";
}
return p;
}
inline nnvm::ObjectPtr MakeNode(const char* op_name,
const std::string& name,
const std::vector<nnvm::NodeEntry>& inputs,
std::unordered_map<std::string, std::string> const* dict,
nnvm::ObjectPtr const* fwd_node) {
return MakeNode(op_name, name, &inputs, dict, fwd_node);
}
// quick helper to make node
inline std::vector<nnvm::NodeEntry> MakeGradNode(
const char* op_name,
const nnvm::ObjectPtr& n,
const std::vector<nnvm::NodeEntry>& inputs,
const std::unordered_map<std::string, std::string>& dict) {
auto p = MakeNode(op_name, n->attrs.name + "_backward", &inputs, &dict, &n);
return CreateNodeEntries(p);
}
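/*
 * Example (illustrative): registering an FGradient that forwards the output
 * gradients to a hypothetical backward operator:
 *
 *   NNVM_REGISTER_OP(example_op)
 *   .set_attr<nnvm::FGradient>("FGradient",
 *     [](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
 *       return MakeGradNode("_backward_example_op", n, ograds, n->attrs.dict);
 *     });
 */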
// quick helper to make gradient nodes that simply pass back zero. could be used in output ops.
inline std::vector<nnvm::NodeEntry> MakeZeroGradNodes(const nnvm::ObjectPtr& n,
const std::vector<nnvm::NodeEntry>& ograds) {
std::vector<nnvm::NodeEntry> ret;
for (uint32_t i = 0; i < n->num_inputs(); ++i) {
std::ostringstream os;
if (1 == n->num_inputs()) {
os << n->attrs.name << "_backward";
} else {
os << n->attrs.name << "_in" << i << "_backward";
}
ret.emplace_back(MakeNode("zeros_like", os.str(), {n->inputs[i]}, nullptr, &n));
}
return ret;
}
// check whether all output grads are zero.
inline bool CheckGradAllZero(const std::vector<nnvm::NodeEntry>& ograds) {
static const auto zero_op = nnvm::Op::Get("_zeros");
static const auto zero_like_op = nnvm::Op::Get("zeros_like");
if (ograds.empty())
return false;
for (const auto& grad : ograds) {
if (!grad.node)
return false;
if (grad.node->op() != zero_op && grad.node->op() != zero_like_op)
return false;
}
return true;
}
// make gradient node that doesn't add to objective.
// i.e. igrads are always zero when ograds are zero.
inline std::vector<nnvm::NodeEntry> MakeNonlossGradNode(
const char* op_name,
const nnvm::ObjectPtr& n,
const std::vector<nnvm::NodeEntry>& ograds,
const std::vector<nnvm::NodeEntry>& inputs,
const std::unordered_map<std::string, std::string>& dict) {
if (CheckGradAllZero(ograds))
return MakeZeroGradNodes(n, ograds);
auto p = MakeNode(op_name, n->attrs.name + "_backward", nullptr, &dict, &n);
return CreateNodeEntries(p, &ograds, &inputs);
}
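/*
 * Example (illustrative): an FGradient body for a hypothetical unary operator whose
 * backward pass needs both the output gradient and the original input:
 *
 *   [](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
 *     return MakeNonlossGradNode("_backward_example_op", n, ograds,
 *                                {n->inputs[0]}, n->attrs.dict);
 *   }
 */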
/*! \brief Parse keyword arguments as PType arguments and save to parsed */
template <typename PType>
inline void ParamParser(nnvm::NodeAttrs* attrs) {
PType param;
try {
param.Init(attrs->dict);
} catch (const dmlc::ParamError& e) {
std::ostringstream os;
os << e.what();
os << ", in operator " << attrs->op->name << "("
<< "name=\"" << attrs->name << "\"";
for (const auto& k : attrs->dict) {
os << ", " << k.first << "=\"" << k.second << "\"";
}
os << ")";
throw dmlc::ParamError(os.str());
}
attrs->parsed = std::move(param);
}
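/*
 * Example (illustrative): the usual registration pattern with a hypothetical
 * parameter struct (DMLC_REGISTER_PARAMETER(ExampleParam) goes in the .cc file):
 *
 *   struct ExampleParam : public dmlc::Parameter<ExampleParam> {
 *     int num_hidden;
 *     DMLC_DECLARE_PARAMETER(ExampleParam) {
 *       DMLC_DECLARE_FIELD(num_hidden).set_lower_bound(1)
 *           .describe("Number of hidden units.");
 *     }
 *   };
 *
 *   NNVM_REGISTER_OP(example_op)
 *   .set_attr_parser(ParamParser<ExampleParam>);
 */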
inline void CheckAllRowsPresent(const NDArray& arr,
const std::string& func,
const std::string& param) {
if (arr.storage_type() == kRowSparseStorage) {
CHECK(arr.storage_shape()[0] == arr.shape()[0])
<< func << " for RowSparse " << param << " is only implemented for "
<< "RowSparse " << param << " with all rows containing non-zeros. "
<< "Expects " << param << ".data.shape[0] (" << arr.storage_shape()[0] << ") == " << param
<< ".shape[0] (" << arr.shape()[0] << ").";
} else {
CHECK(arr.storage_type() == kDefaultStorage);
}
}
inline void LogUnimplementedOp(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
using common::operator_string;
LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs);
}
class OpSignature {
std::vector<int64_t> eles;
uint64_t hash;
public:
OpSignature() {
hash = 0;
}
explicit OpSignature(uint64_t hash) {
this->hash = hash;
}
/*
* This is to reserve space for the vector.
*/
void Reserve(size_t num) {
eles.reserve(num);
}
/*
* We provide different methods to add a signature to an op.
* For operations such as convolution and fully connected, which determine
* the optimal data layout for the op, we only need the shape and data
* type to sign the op. For other operations, such as activation, which use
* whatever layout the input array is in, we have to use the shape, the data
* type and the layout to sign the op.
*/
#if MXNET_USE_ONEDNN == 1
void AddSign(const dnnl::memory::desc& desc) {
hash = hash * 2 + desc.data.format_kind;
eles.push_back(desc.data.format_kind);
hash = hash * 2 + desc.data.data_type;
eles.push_back(desc.data.data_type);
for (int i = 0; i < desc.data.ndims; i++) {
hash = hash * 2 + desc.data.dims[i];
eles.push_back(desc.data.dims[i]);
}
switch (desc.data.format_kind) {
case dnnl_blocked:
hash = hash * 2 + desc.data.ndims;
eles.push_back(desc.data.ndims);
for (int i = 0; i < desc.data.ndims; i++) {
hash = hash * 2 + desc.data.format_desc.blocking.strides[i];
eles.push_back(desc.data.format_desc.blocking.strides[i]);
}
hash = hash * 2 + desc.data.format_desc.blocking.inner_nblks;
eles.push_back(desc.data.format_desc.blocking.inner_nblks);
for (int i = 0; i < desc.data.format_desc.blocking.inner_nblks; i++) {
hash = hash * 2 + desc.data.format_desc.blocking.inner_blks[i];
hash = hash * 2 + desc.data.format_desc.blocking.inner_idxs[i];
eles.push_back(desc.data.format_desc.blocking.inner_blks[i]);
eles.push_back(desc.data.format_desc.blocking.inner_idxs[i]);
}
break;
case dnnl_format_kind_wino:
hash = hash * 2 + desc.data.format_desc.wino_desc.wino_format;
eles.push_back(desc.data.format_desc.wino_desc.wino_format);
break;
case dnnl_format_kind_rnn_packed:
hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.format;
eles.push_back(desc.data.format_desc.rnn_packed_desc.format);
hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.n_parts;
eles.push_back(desc.data.format_desc.rnn_packed_desc.n_parts);
for (int i = 0; i < desc.data.format_desc.rnn_packed_desc.n_parts; ++i) {
hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.parts[i];
hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.part_pack_size[i];
hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.pack_part[i];
eles.push_back(desc.data.format_desc.rnn_packed_desc.parts[i]);
eles.push_back(desc.data.format_desc.rnn_packed_desc.part_pack_size[i]);
eles.push_back(desc.data.format_desc.rnn_packed_desc.pack_part[i]);
}
hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.n;
hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.ldb;
hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.offset_compensation;
hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.size;
eles.push_back(desc.data.format_desc.rnn_packed_desc.n);
eles.push_back(desc.data.format_desc.rnn_packed_desc.ldb);
eles.push_back(desc.data.format_desc.rnn_packed_desc.offset_compensation);
eles.push_back(desc.data.format_desc.rnn_packed_desc.size);
break;
default:
// nothing to add
break;
}
}
void AddSign(const dnnl::memory& mem) {
auto desc = mem.get_desc();
AddSign(desc);
}
void AddSign(const std::vector<dnnl::memory::desc>& arrs) {
for (auto& arr : arrs) {
AddSign(arr);
}
}
#endif
void AddSign(const std::string& s) {
uint64_t key = static_cast<uint64_t>(std::hash<std::string>{}(s));
eles.push_back(key);
}
void AddSign(const std::vector<NDArray>& arrs) {
for (auto& arr : arrs) {
AddSign(arr);
}
}
void AddSign(const NDArray& arr) {
#if MXNET_USE_ONEDNN == 1
if (arr.IsDNNLData()) {
AddSign(*(arr.GetDNNLData()));
} else {
#endif
hash = hash * 2 + arr.dtype();
eles.push_back(arr.dtype());
AddSign(arr.shape());
#if MXNET_USE_ONEDNN == 1
}
#endif
}
void AddSign(const mxnet::ShapeVector& shapes) {
for (auto& shape : shapes) {
AddSign(shape);
}
}
void AddSign(const mxnet::TShape& shape) {
for (int i = 0; i < shape.ndim(); i++) {
hash = hash * 2 + shape[i];
eles.push_back(shape[i]);
}
}
void AddSign(int val) {
hash = hash * 2 + val;
eles.push_back(val);
}
void AddSign(const std::vector<float>& vec) {
for (auto& val : vec) {
AddSign(val);
}
}
void AddSign(float val) {
hash = dmlc::HashCombine(hash, val);
eles.push_back(val);
}
bool operator==(const OpSignature& sign) const {
if (hash != sign.hash)
return false;
if (eles.size() != sign.eles.size())
return false;
for (size_t i = 0; i < eles.size(); i++)
if (eles[i] != sign.eles[i])
return false;
return true;
}
uint64_t GetHash() const {
return hash;
}
};
struct OpHash {
size_t operator()(const OpSignature& sign) const {
return sign.GetHash();
}
};
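/*
 * Example (illustrative): using OpSignature together with OpHash as the key of a
 * per-thread cache of prepared resources (CachedKernel, CreateCachedKernel, and
 * param.num_hidden are hypothetical):
 *
 *   static thread_local std::unordered_map<OpSignature, CachedKernel, OpHash> cache;
 *
 *   OpSignature key;
 *   key.Reserve(inputs.size() + 1);
 *   key.AddSign(inputs);            // shapes and dtypes (and layout under oneDNN)
 *   key.AddSign(param.num_hidden);  // hypothetical integer parameter
 *   auto it = cache.find(key);
 *   if (it == cache.end())
 *     it = cache.emplace(key, CreateCachedKernel(inputs, param)).first;
 */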
template <typename ParamType>
class ParamOpSign : public OpSignature {
const ParamType param;
static size_t hash(const ParamType& param) {
std::hash<ParamType> fn;
return fn(param);
}
public:
explicit ParamOpSign(const ParamType& _param) : OpSignature(hash(_param)), param(_param) {}
bool operator==(const ParamOpSign<ParamType>& sign) const {
const OpSignature& this_upper = *this;
const OpSignature& other_upper = sign;
return this_upper == other_upper && param == sign.param;
}
};
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_OPERATOR_COMMON_H_