blob: 2245db0dbb93d64a58fa3a77bc750c96c8fa3d4a [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file operator.h
* \brief Operator interface of mxnet.
* \author Naiyan Wang
*/
#ifndef MXNET_OPERATOR_H_
#define MXNET_OPERATOR_H_
#include <dmlc/base.h>
#include <dmlc/json.h>
#include <dmlc/logging.h>
#include <dmlc/registry.h>
#include <nnvm/node.h>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "./base.h"
#include "./resource.h"
#include "./op_attr_types.h"
namespace mxnet {
/*!
 * \brief Operator interface.
 * An Operator is the basic unit of computation in an optimized mxnet computation graph.
 * This interface relies on pre-allocated memory in TBlob: the caller needs to set
 * the memory region of each TBlob correctly before calling Forward and Backward.
 *
 * Operators are created by an OperatorProperty (see OperatorProperty::CreateOperator
 * and OperatorProperty::CreateOperatorEx).
 * To add a new operator (aka. a layer of a neural net) to mxnet, a developer needs to
 * create a new OperatorProperty and its corresponding Operator.
 *
 * \sa TBlob, TShape, OperatorProperty
 */
class Operator {
 public:
  /*! \brief virtual destructor, so that operators can be deleted through a base pointer */
  virtual ~Operator() {}
  /*!
   * \brief Perform a forward operation of the Operator, saving the outputs to TBlobs.
   * \param ctx runtime context available to this call
   * \param in_data array of input data; it is const
   * \param req the request types of the saving operation; can only be kWriteTo or kWriteInplace.
   * \param out_data array of output data. The memory of each TBlob in out_data
   *        must be pre-allocated according to the shapes computed by InferShape.
   * \param aux_states auxiliary states of the operator. Normally an operator doesn't
   *        need them; special cases such as BatchNorm do.
   * \sa OpReqType, OpContext
   */
  virtual void Forward(const OpContext &ctx,
                       const std::vector<TBlob> &in_data,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &out_data,
                       const std::vector<TBlob> &aux_states) = 0;
  /*!
   * \brief Perform a backward operation, writing the gradients to in_grad.
   *
   * \note
   *  Convention:
   *    out_grad.size() == OperatorProperty.NumVisibleOutputs()
   *    out_data.size() == OperatorProperty.NumOutputs()
   *  out_data can contain additional invisible returns that remember the
   *  state carried over from the Forward pass, for example the mask in dropout.
   *  Gradients are only passed in for the visible returns.
   *
   * \par
   *  Not all the TBlobs in the arguments will be available
   *  if you override DeclareBackwardDependency of the corresponding OperatorProperty class.
   *  Only the dependencies you declared will be available at their corresponding positions;
   *  the rest of the entries are dummy placeholders (you will get a nullptr).
   *  You are safe if you keep the default DeclareBackwardDependency,
   *  but declaring only what you need gives the engine more chances for optimization.
   *
   *  The default implementation aborts with LOG(FATAL) "Backward is not implemented";
   *  operators that support gradients must override it.
   *
   * \param ctx runtime context available to this call
   * \param out_grad the gradients with respect to the (visible) outputs of the Operator.
   * \param in_data the array of input data.
   * \param out_data the array of output data.
   * \param req request types of the saving operation; can be any type.
   * \param in_grad the array of gradients we need to write to.
   * \param aux_states auxiliary states of the operator. Normally an operator doesn't need them.
   * \sa OperatorProperty, OpReqType, OpContext
   */
  virtual void Backward(const OpContext &ctx,
                        const std::vector<TBlob> &out_grad,
                        const std::vector<TBlob> &in_data,
                        const std::vector<TBlob> &out_data,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &in_grad,
                        const std::vector<TBlob> &aux_states) {
    LOG(FATAL) << "Backward is not implemented";
  }
  /*!
   * \return [Deprecated] execution type of the operator; always ExecType::kSync.
   *  Declared final so it cannot be overridden here; override
   *  OperatorProperty::exec_type instead.
   */
  virtual ExecType exec_type() const final {  // NOLINT(*) exec_type has been moved to OperatorProperty
    return ExecType::kSync;
  }
};
#if DMLC_USE_CXX11
// OperatorProperty may use C++11 features, while Operator does not rely on them.
/*!
 * \brief OperatorProperty is an object that stores all information about an Operator.
 * It also contains methods to generate context (device) specific operators.
 *
 * It also contains various functions that can be optionally overridden to
 * provide optimization opportunities for the computation engine.
 */
class OperatorProperty {
 public:
  /*!
   * \brief virtual destructor
   */
  virtual ~OperatorProperty() {}
  /*!
   * \brief Initialize the OperatorProperty by setting its parameters.
   *  This function needs to be called before all other functions.
   * \param kwargs the keyword arguments parameters
   */
  virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) = 0;
  /*!
   * \brief Get a map representation of internal parameters.
   *  This can be used by Init to recover the state of the OperatorProperty.
   */
  virtual std::map<std::string, std::string> GetParams() const = 0;
  /*!
   * \brief Get the names of the input arguments of the Operator.
   * \return vector of argument names; {"data"} by default.
   */
  virtual std::vector<std::string> ListArguments() const {
    return {"data"};
  }
  /*!
   * \brief Get the names of the output values of the Operator.
   * \return names of output values; {"output"} by default.
   */
  virtual std::vector<std::string> ListOutputs() const {
    return {"output"};
  }
  /*!
   * \brief Get the names of the auxiliary states of the Operator.
   * \return names of auxiliary states; empty by default.
   */
  virtual std::vector<std::string> ListAuxiliaryStates() const {
    return {};
  }
  /*! \return number of real return values of the Operator, i.e. ListOutputs().size() */
  virtual int NumOutputs() const {
    return static_cast<int>(this->ListOutputs().size());
  }
  /*!
   * \brief Get the number of visible return values during Symbol creation.
   *  If NumVisibleOutputs() = k and NumOutputs() = n,
   *  the first k returns will be presented in the resulting symbol.
   *
   *  The rest of the returns can be used as auxiliary states for Backward.
   *  For example, Dropout returns [data, mask] with NumVisibleOutputs() == 1,
   *  so when the user calls sym = Dropout(input), only data is presented in sym.
   *  All the returns are still presented in the out_data parameter of Backward if requested.
   *
   * \return number of visible return values; NumOutputs() by default.
   */
  virtual int NumVisibleOutputs() const {
    return NumOutputs();
  }
  /*!
   * \brief Infer the shapes of outputs and unknown input arguments.
   * \param in_shape the shapes of the input arguments of the operator;
   *     this should be of the same length as the vector returned by ListArguments.
   *     in_shape allows unknown elements, which are marked by shape.ndim() == 0.
   *     For unknown shapes, InferShape will try to fill in the correct shape in in_shape.
   *     For known shapes, InferShape will check shape consistency.
   *
   *     Common practice: set the shape of the data input; the weight's shape can
   *     usually be inferred.
   *
   * \param out_shape the shapes of the outputs of the operator;
   *     InferShape will modify the vector to fill in the output TShapes.
   * \param aux_shape the shapes of the auxiliary states of the operator;
   *     InferShape will modify the vector to fill in the auxiliary TShapes.
   * \return true if shape inference is successful, false if there is not enough information.
   * \throws dmlc::Error if the known arg_shapes are inconsistent.
   */
  virtual bool InferShape(std::vector<TShape> *in_shape,
                          std::vector<TShape> *out_shape,
                          std::vector<TShape> *aux_shape) const = 0;
  /*!
   * \brief Infer the data types of outputs and unknown input arguments.
   * \param in_type the types of the input arguments of the operator;
   *     it must not be longer than the vector returned by ListArguments.
   *     in_type allows unknown elements, which are marked by the value -1.
   *     For unknown types, InferType will try to fill in the correct type in in_type.
   *     For known types, InferType will check type consistency.
   *
   *     The default implementation only supports mshadow::default_type_flag:
   *     any known input type other than it raises an error, and every input,
   *     output and auxiliary state type is set to mshadow::default_type_flag.
   *
   * \param out_type the types of the outputs of the operator;
   *     InferType overwrites the vector with the output types.
   * \param aux_type the types of the auxiliary states of the operator;
   *     InferType overwrites the vector with the auxiliary state types.
   * \return true if type inference is successful, false if there is not enough information.
   * \throws dmlc::Error if the known arg_types are inconsistent.
   */
  virtual bool InferType(std::vector<int> *in_type,
                         std::vector<int> *out_type,
                         std::vector<int> *aux_type) const {
    const size_t n_in = this->ListArguments().size();
    CHECK_LE(in_type->size(), n_in);
    for (size_t i = 0; i < in_type->size(); ++i) {
      CHECK(in_type->at(i) == mshadow::default_type_flag ||
            in_type->at(i) == -1) << "Unsupported data type " << in_type->at(i);
    }
    // Only the default type is supported: overwrite every slot with it.
    in_type->assign(n_in, mshadow::default_type_flag);
    out_type->assign(this->ListOutputs().size(), mshadow::default_type_flag);
    aux_type->assign(this->ListAuxiliaryStates().size(), mshadow::default_type_flag);
    return true;
  }
  /*!
   * \brief Copy this OperatorProperty.
   * \return a pointer to the copied OperatorProperty
   */
  virtual OperatorProperty* Copy() const = 0;
  /*!
   * \brief Create an Operator on a specific context.
   */
  virtual Operator* CreateOperator(Context ctx) const = 0;
  /*!
   * \brief Create an Operator on a specific context and input shape/type.
   *  The default implementation runs InferType and InferShape (failing with
   *  CHECK if either returns false), then calls CreateOperator(ctx).
   * \param ctx context of this operator
   * \param in_shape shapes of the input ndarrays
   * \param in_type dtypes of the input ndarrays
   * \return the created operator
   */
  virtual Operator* CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
                                     std::vector<int> *in_type) const {
    std::vector<int> out_type(this->ListOutputs().size());
    std::vector<int> aux_type(this->ListAuxiliaryStates().size());
    std::vector<TShape> out_shape(out_type.size());
    std::vector<TShape> aux_shape(aux_type.size());
    CHECK(InferType(in_type, &out_type, &aux_type));
    CHECK(InferShape(in_shape, &out_shape, &aux_shape));
    return CreateOperator(ctx);
  }
  /*!
   * \brief Return the type string of the Operator.
   *  Subclasses override this function.
   * \return The type string.
   */
  virtual std::string TypeString() const = 0;
  //--------------------------------------------------------
  // All the functions below are optional to override.
  //--------------------------------------------------------
  /*!
   * \brief Declare additional resources required in the forward pass.
   *  These additional resources will be presented in OpContext.requested
   *  in the same order as the returned ResourceRequests.
   * \param in_shape The input shapes of the operator, corresponding to shapes of in_data.
   * \return Additional resource requests; none by default.
   */
  virtual std::vector<ResourceRequest> ForwardResource(
      const std::vector<TShape> &in_shape) const {
    return std::vector<ResourceRequest>();
  }
  /*!
   * \brief Declare additional resources required in the backward pass.
   *  These additional resources will be presented in OpContext.requested
   *  in the same order as the returned ResourceRequests.
   * \param in_shape The input shapes of the operator, corresponding to shapes of in_data.
   * \return Additional resource requests; none by default.
   */
  virtual std::vector<ResourceRequest> BackwardResource(
      const std::vector<TShape> &in_shape) const {
    return std::vector<ResourceRequest>();
  }
  /*!
   * \brief Declare the input requirements of the Backward pass.
   *
   *  Only the returned list of variables will be used in Backward.
   *  This function is used for memory optimization.
   *  It is advised to override it and return only what is actually needed.
   *  If this function is not overridden, all the variables will be valid in Backward.
   *
   * \code
   *  // The following code declares that Backward needs out_grad[0], in_data[0], in_data[1]
   *  vector<int> DeclareBackwardDependency(const vector<int> &out_grad,
   *                                        const vector<int> &in_data,
   *                                        const vector<int> &out_data) const {
   *    return {out_grad[0], in_data[0], in_data[1]};
   *  }
   * \endcode
   * \param out_grad gradients of outputs in the backward pass.
   * \param in_data the input data in the forward pass.
   * \param out_data the output data in the forward pass.
   * \return an integer vector indicating the input requirements
   * \sa BackwardInputs
   */
  virtual std::vector<int> DeclareBackwardDependency(
      const std::vector<int> &out_grad,
      const std::vector<int> &in_data,
      const std::vector<int> &out_data) const {
    // By default everything is required (out_grad, then in_data, then out_data).
    // Override this function to get better memory performance.
    std::vector<int> ret;
    ret.reserve(out_grad.size() + in_data.size() + out_data.size());
    ret.insert(ret.end(), out_grad.begin(), out_grad.end());
    ret.insert(ret.end(), in_data.begin(), in_data.end());
    ret.insert(ret.end(), out_data.begin(), out_data.end());
    return ret;
  }
  /*!
   * \brief Get possible forward in-place options.
   *  This function enables optimizations that reuse the memory of inputs for outputs.
   *  Only override it when necessary; by default in-place is disabled.
   *
   *  The void* type of out_data distinguishes the two sides of each mapping:
   *  the compiler will report an error if in_data and out_data are swapped in a pair.
   *
   * \code
   *  // The following code says out_data[0] can share data with in_data[0]
   *  vector<pair<int, void*> > ForwardInplaceOption(const vector<int> &in_data,
   *                                                 const vector<void*> &out_data) const {
   *    return {{in_data[0], out_data[0]}};
   *  }
   * \endcode
   * \param in_data The input data in the forward pass.
   * \param out_data The output data in the forward pass.
   * \return list of pairs mapping input->output,
   *  indicating possible in-place operations; empty by default.
   */
  virtual std::vector<std::pair<int, void*> > ForwardInplaceOption(
      const std::vector<int> &in_data,
      const std::vector<void*> &out_data) const {
    return std::vector<std::pair<int, void*> >();
  }
  /*!
   * \brief Get possible backward in-place options.
   *  This function enables optimizations that reuse the memory of inputs for in_grad.
   *  Only override it when necessary; by default in-place is disabled.
   *
   *  The void* type of in_grad distinguishes the two sides of each mapping:
   *  the compiler will report an error if the input and in_grad entries of a pair are swapped.
   *
   * \code
   *  // The following code says in_grad[0] can share data with in_data[0]
   *  vector<pair<int, void*> > BackwardInplaceOption(
   *                  const std::vector<int> &out_grad,
   *                  const std::vector<int> &in_data,
   *                  const std::vector<int> &out_data,
   *                  const std::vector<void*> &in_grad) const {
   *    return {{in_data[0], in_grad[0]}};
   *  }
   * \endcode
   * \param out_grad Gradients of outputs in the backward pass.
   * \param in_data The input data in the forward pass.
   * \param out_data The output data in the forward pass.
   * \param in_grad Gradients of inputs in the backward pass.
   * \return list of pairs mapping input->in_grad,
   *  indicating possible in-place operations; empty by default.
   */
  virtual std::vector<std::pair<int, void*> > BackwardInplaceOption(
      const std::vector<int> &out_grad,
      const std::vector<int> &in_data,
      const std::vector<int> &out_data,
      const std::vector<void*> &in_grad) const {
    return std::vector<std::pair<int, void*> >();
  }
  /*!
   * \brief Get the Backward input dependencies for generic types of data.
   *  Normally T can be a pointer to Symbol::DataEntry, or NDArray.
   *  This function selects the entries of T according to DeclareBackwardDependency.
   *
   * \param out_grad gradients of outputs in the backward pass.
   * \param in_data the input data in the forward pass.
   * \param out_data the output data in the forward pass.
   * \tparam T the generic type parameter; it must be copy-constructible.
   * \return vector of inputs the Backward operation depends on.
   * \throws dmlc::Error if DeclareBackwardDependency returns an out-of-range index.
   * \sa DeclareBackwardDependency
   */
  template<typename T>
  inline std::vector<T> BackwardInputs(const std::vector<T> &out_grad,
                                       const std::vector<T> &in_data,
                                       const std::vector<T> &out_data) const {
    // Number every entry by its position in the concatenation
    // [out_grad..., in_data..., out_data...]; DeclareBackwardDependency
    // answers in terms of these positions.
    int counter = 0;
    std::vector<int> out_grad_index(out_grad.size());
    std::vector<int> in_data_index(in_data.size());
    std::vector<int> out_data_index(out_data.size());
    for (size_t i = 0; i < out_grad_index.size(); ++i) {
      out_grad_index[i] = counter++;
    }
    for (size_t i = 0; i < in_data_index.size(); ++i) {
      in_data_index[i] = counter++;
    }
    for (size_t i = 0; i < out_data_index.size(); ++i) {
      out_data_index[i] = counter++;
    }
    std::vector<T> all_data;
    all_data.reserve(out_grad.size() + in_data.size() + out_data.size());
    all_data.insert(all_data.end(), out_grad.begin(), out_grad.end());
    all_data.insert(all_data.end(), in_data.begin(), in_data.end());
    all_data.insert(all_data.end(), out_data.begin(), out_data.end());
    const std::vector<int> ret_index = this->DeclareBackwardDependency(
        out_grad_index, in_data_index, out_data_index);
    // Built with push_back so T does not need to be default-constructible.
    std::vector<T> ret;
    ret.reserve(ret_index.size());
    for (const int idx : ret_index) {
      // Guard against a buggy override returning an index outside all_data.
      CHECK(idx >= 0 && static_cast<size_t>(idx) < all_data.size())
          << "DeclareBackwardDependency of " << this->TypeString()
          << " returned out-of-range index " << idx;
      ret.push_back(all_data[idx]);
    }
    return ret;
  }
  /*!
   * \brief Create an OperatorProperty.
   * \param type_name the type string of the OperatorProperty
   * \return a newly constructed OperatorProperty
   */
  static OperatorProperty *Create(const char* type_name);
  /*! \return execution type of the operator; ExecType::kSync by default */
  virtual ExecType exec_type() const {
    return ExecType::kSync;
  }
};
/*!
 * \brief Factory function type producing a newly allocated OperatorProperty;
 *  the caller takes ownership of the returned pointer.
 */
using OperatorPropertyFactory = std::function<OperatorProperty *()>;
/*!
 * \brief Registry entry for OperatorProperty factory functions.
 */
struct OperatorPropertyReg
    : public dmlc::FunctionRegEntryBase<OperatorPropertyReg,
                                        OperatorPropertyFactory> {
  /*!
   * \brief Set key_var_num_args.
   *  When this is set, the API caller is required to pass in an
   *  argument with key=key_var_num_args.c_str(), and value=num_args,
   *  where num_args is the number of positional arguments passed to the function.
   *
   *  This is used to pass in the number of positional arguments
   *  for operators that can take a variable number of inputs.
   *  Most operators do not need to set this property.
   *
   * \param key the key name to be set
   * \return reference to this entry, for chaining
   */
  inline OperatorPropertyReg& set_key_var_num_args(const std::string &key) {  // NOLINT(*)
    this->key_var_num_args = key;
    return *this;
  }
  /*!
   * \brief Check that the TypeString of the registered OperatorProperty
   *  matches the registered name.
   *  A temporary property is created through the factory (body) for the
   *  comparison and destroyed before returning.
   * \return reference to this entry, for chaining
   */
  inline OperatorPropertyReg& check_name() {
    // unique_ptr releases the temporary even if TypeString() or a CHECK throws.
    std::unique_ptr<OperatorProperty> p(this->body());
    CHECK(p != nullptr)
        << "Factory of \"" << this->name << "\" returned nullptr";
    const std::string type = p->TypeString();
    CHECK_EQ(this->name, type)
        << "Register Name and TypeString mismatch, name=\"" << this->name << "\","
        << " but TypeString=\"" << type <<"\"";
    return *this;
  }
  /*! \brief The key num_args name. */
  std::string key_var_num_args;
};
//---------------------------------------------------------------------------------
// The following part is the API registration of operators.
// See also MXNET_REGISTER_SIMPLE_OP in operator_util.h for registering simple ops.
//---------------------------------------------------------------------------------
/*!
 * \brief Macro to register an OperatorProperty.
 *
 * Registers OperatorPropertyType under \a name in the OperatorPropertyReg registry,
 * sets a factory body that heap-allocates a new OperatorPropertyType, sets the
 * return type to "NDArray-or-Symbol", and checks (via check_name) that the
 * property's TypeString matches \a name. More registry setters can be chained.
 *
 * \code
 * // example of registering a fully connected operator
 * MXNET_REGISTER_OP_PROPERTY(FullyConnected, FullyConnectedOpProp)
 * .describe("Fully connected layer");
 *
 * \endcode
 */
#define MXNET_REGISTER_OP_PROPERTY(name, OperatorPropertyType) \
  DMLC_REGISTRY_REGISTER(::mxnet::OperatorPropertyReg, OperatorPropertyReg, name) \
  .set_body([]() { return new OperatorPropertyType(); }) \
  .set_return_type("NDArray-or-Symbol") \
  .check_name()
#endif // DMLC_USE_CXX11
} // namespace mxnet
#endif // MXNET_OPERATOR_H_