blob: 5a8094dbac80bbe2e148c38717f7d7c520552eaa [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file op_attr_types.h
* \brief Additional operator attributes
* besides the ones provided by NNVM
*/
#ifndef MXNET_OP_ATTR_TYPES_H_
#define MXNET_OP_ATTR_TYPES_H_
#include <mshadow/tensor.h>
#include <nnvm/op_attr_types.h>
#include <vector>
#include <functional>
#include <string>
#include "./base.h"
#include "./ndarray.h"
#include "./engine.h"
#include "./resource.h"
namespace mxnet {
using nnvm::NodeAttrs;
/*!
* \brief Operation request type passed to Forward and Backward.
* Enumerators have no explicit values, so they are numbered
* 0 (kNullOp) through 3 (kAddTo) in declaration order.
*/
enum OpReqType {
/*! \brief no operation, do not write anything */
kNullOp,
/*! \brief write the result (e.g. a gradient) to the provided space */
kWriteTo,
/*!
* \brief perform an in-place write.
* This option only happens when the
* target shares memory with one of the input arguments.
*/
kWriteInplace,
/*! \brief add to the provided space */
kAddTo
};
/*!
 * \brief Everything an operator may need while running Forward or Backward.
 *  This is a superset of RunContext and is the single structure used to
 *  bookkeep the information handed to Forward and Backward.
 *  The member order is part of the interface: keep it stable for callers
 *  that brace-initialize this struct.
 * \sa Resource
 */
struct OpContext {
  /*! \brief true if there is a backward phase to compute gradients */
  bool need_grad;
  /*! \brief true if this is the training phase */
  bool is_train;
  /*! \brief RunContext related resources */
  RunContext run_ctx;
  /*! \brief callback invoked when the operation completes, used by asynchronous ops */
  engine::CallbackOnComplete async_on_complete;
  /*! \brief resources requested by the operator */
  std::vector<Resource> requested;

  /*!
   * \brief Get the mshadow stream from the run context.
   * \tparam xpu the device type of the stream
   * \return the mshadow stream, exactly as returned by run_ctx.get_stream<xpu>()
   */
  template<typename xpu>
  mshadow::Stream<xpu>* get_stream() const {
    return this->run_ctx.get_stream<xpu>();
  }

#if MXNET_USE_CUDA
  /*!
   * \brief Get the auxiliary GPU stream auto-syncing object from the run context.
   *  Only available when MXNET_USE_CUDA is enabled.
   * \return the aux stream auto-syncing object, as returned by
   *         run_ctx.get_gpu_aux_stream()
   */
  SyncedGPUAuxStream get_gpu_aux_stream() const {
    return this->run_ctx.get_gpu_aux_stream();
  }
#endif
};
/*!
* \brief The execution type of the operator.
* Enumerators have no explicit values, so they are numbered from 0 in
* declaration order.
*/
enum class ExecType {
/*! \brief Forward/Backward are synchronous calls */
kSync,
/*!
* \brief Forward/Backward are asynchronous and
* will call OpContext.async_on_complete when the operation finishes.
*/
kAsync,
/*!
* \brief Cross-device copy operation. This is a special operator that indicates it will copy
* across devices. For example, the input and output of this type of operator can potentially
* reside on different devices. In the current implementation, a copy operator is specially
* handled by an executor. This flag is used for special-case treatment and future extension of
* different copy ops.
*/
kCrossDeviceCopy,
/*!
* \brief A subgraph execution that should happen in the main thread, instead of
* in the execution engine.
*/
kSubgraphExec,
};
/*!
* \brief The dispatch mode of the operator.
* kUndefined is pinned to -1, so the remaining enumerators are numbered
* from 0 in declaration order.
*/
enum class DispatchMode {
/*! \brief no dispatch mode assigned; the only enumerator with an explicit value */
kUndefined = -1,
/*! \brief dispatch on FCompute or FStatefulCompute */
kFCompute,
/*! \brief dispatch on FComputeEx or FStatefulComputeEx, if available */
kFComputeEx,
/*! \brief dispatch on FCompute or FStatefulCompute, and perform storage fallback */
kFComputeFallback,
/*! \brief special dispatch mode for variables */
kVariable,
};
/*!
* \brief The quantization type of the operator.
* kNone is pinned to 0; the following enumerators continue from there.
*/
enum class QuantizeType {
/*! \brief this operator does not support quantization */
kNone = 0,
/*! \brief this operator benefits hugely from quantization, thus it must be quantized */
kMust,
/*! \brief this operator supports quantization; whether it is quantized is decided by its connections */
kSupport,
};
/*!
 * \brief Operator state. This is a pointer type, its content is mutable
 *  even if OpStatePtr is const.
 *
 *  Copies of an OpStatePtr share the same underlying state through a
 *  std::shared_ptr. A default-constructed OpStatePtr is empty; get_var()
 *  and get_state() must not be called on an empty instance because they
 *  dereference the internal pointer without checking it.
 */
class OpStatePtr {
 public:
  /*!
   * \brief Create a OpStatePtr with state of type T.
   *  A new engine variable is obtained from Engine::Get()->NewVariable() and
   *  stored next to the state. When the last owner goes away, the custom
   *  deleter hands the variable to Engine::Get()->DeleteVariable() and then
   *  deletes the state as a T.
   * \tparam T concrete type of the state; it must be the same type later
   *         passed to get_state<T>(), since no runtime type check is done.
   * \param args Arguments passed to T's constructor.
   * \return an OpStatePtr owning the newly created state.
   */
  template<typename T, typename... Args>
  static OpStatePtr Create(Args&&... args) {
    OpStatePtr ret;
    auto state = new T(std::forward<Args>(args)...);
    auto var = Engine::Get()->NewVariable();
    ret.ptr_.reset(
      new OpState(var, state),
      [](OpState* p) {
        // Release the engine variable first, then the type-erased state.
        Engine::Get()->DeleteVariable([](RunContext s) {}, Context::CPU(), p->var);
        // state was created as a T* and stored as void*, so static_cast
        // is the matching (and sufficient) conversion back.
        delete static_cast<T*>(p->state);
        delete p;
      });
    return ret;
  }

  /*! \brief Get engine variable associated with this state */
  engine::VarHandle get_var() const {
    return ptr_->var;
  }

  /*!
   * \brief Get state of type T
   * \tparam T type the state was created with in Create<T>(); not verified.
   * \return mutable reference to the state, even through a const OpStatePtr.
   */
  template<typename T>
  T& get_state() const {
    return *static_cast<T*>(ptr_->state);
  }

  /*! \brief clear state; this instance gives up its share of the state */
  void reset() {
    ptr_.reset();
  }

  /*!
   * \brief checks whether the managed object is managed only by the current
   *  OpStatePtr instance
   * \return true if exactly one owner exists; false otherwise, including
   *         when this instance is empty.
   */
  bool unique() const {
    // std::shared_ptr::unique() is deprecated in C++17 and removed in C++20;
    // use_count() == 1 is its documented equivalent.
    return ptr_.use_count() == 1;
  }

  /*! \brief Whether a state is held (true) or this instance is empty (false) */
  explicit operator bool() const {
    return static_cast<bool>(ptr_);
  }

 private:
  /*! \brief state structure: an engine variable plus the type-erased state */
  struct OpState {
    /*! \brief engine variable associated with the state */
    engine::VarHandle var;
    /*! \brief type-erased pointer to the user state object */
    void* state;

    OpState(engine::VarHandle var_, void* state_) : var(var_), state(state_) {}
    OpState(const OpState& other) = delete;
    OpState& operator=(const OpState& other) = delete;
  };

  /*! \brief shared pointer to state */
  std::shared_ptr<OpState> ptr_;
};
/*!
* \brief Create the state of a layer-style, forward/backward operator.
* This makes it easy to write operators that contain state.
* OpStatePtr is a pointer type, its content is mutable even if
* OpStatePtr is constant.
*
* This is not the only way to register an op execution function;
* simpler or more specialized operator forms can be registered instead.
*
* The function receives the node attributes, a Context, the input shapes
* and the input types (as ints), and returns the created OpStatePtr.
*
* \note Register under "FCreateOpState"
* NOTE(review): this note previously read "FCreateLayerOp", which does not
* match the alias name the way every other note in this file does; the key
* is assumed to equal the alias name - TODO confirm against op registrations.
*/
using FCreateOpState = std::function<OpStatePtr (const NodeAttrs& attrs,
Context ctx,
const mxnet::ShapeVector& in_shape,
const std::vector<int>& in_type)>;
/*!
* \brief Whether the operator always produces the same
* output given the same input.
* This enables certain optimizations
* like common subexpression elimination.
* The attribute is a plain bool value, not a function.
*
* \note Register under "THasDeterministicOutput"
*/
using THasDeterministicOutput = bool;
/*!
* \brief Execution mode of this operator.
* The function maps the attributes of a node to the ExecType it should
* be run with.
*/
using FExecType = std::function<ExecType (const NodeAttrs& attrs)>;
/*!
* \brief Register a compute function for a stateful operator.
* OpStatePtr is a pointer type, its content is mutable even if
* OpStatePtr is constant.
* Inputs and outputs are passed as TBlob, with one OpReqType per request
* entry in req.
*
* \note Register under "FStatefulCompute<cpu>" and "FStatefulCompute<gpu>"
*/
using FStatefulCompute = std::function<void (const OpStatePtr& state,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs)>;
/*!
* \brief Register a compute function for a stateful operator using the NDArray interface.
* OpStatePtr is a pointer type, its content is mutable even if
* OpStatePtr is constant.
* Same shape as FStatefulCompute, except that inputs and outputs are
* NDArray instead of TBlob.
*
* \note Register under "FStatefulComputeEx<cpu>" and "FStatefulComputeEx<gpu>"
*/
using FStatefulComputeEx = std::function<void (const OpStatePtr& state,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs)>;
/*!
* \brief The resource request from the operator.
* An operator could register FResourceRequestEx, or FResourceRequest, or neither.
* The function maps the node attributes to the list of ResourceRequest
* entries the operator asks for.
*
* \note Register under "FResourceRequest"
*/
using FResourceRequest = std::function<
std::vector<ResourceRequest> (const NodeAttrs& n)>;
/*!
* \brief The resource request from the operator.
* An operator could register FResourceRequestEx, or FResourceRequest, or neither.
* If an operator registers both FResourceRequestEx and FResourceRequest,
* FResourceRequest is ignored.
* Compared with FResourceRequest, the function additionally receives the
* device mask and the dispatch mode.
*
* \note Register under "FResourceRequestEx"
*/
using FResourceRequestEx = std::function<
std::vector<ResourceRequest> (const NodeAttrs& n,
const int dev_mask,
const DispatchMode dispatch_mode)>;
/*!
* \brief Register an operator called as an NDArray function.
* Inputs are passed read-only; outputs are passed through a pointer to a
* vector that the function may modify.
*
* \note Register under "FNDArrayFunction"
*/
using FNDArrayFunction = std::function<void (const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
std::vector<NDArray>* outputs)>;
/*!
* \brief Register a compute function for a simple, stateless, forward-only operator.
* Inputs and outputs are passed as TBlob, with the write requests given
* in req.
*
* \note Register under "FCompute<cpu>" and "FCompute<gpu>"
*/
using FCompute = std::function<void (const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs)>;
/*!
* \brief Register an NDArray compute function for a simple, stateless, forward-only operator.
* Same shape as FCompute, except that inputs and outputs are NDArray
* instead of TBlob.
*
* \note Register under "FComputeEx<cpu>" and "FComputeEx<gpu>"
* Dispatched only when the inferred dispatch mode is DispatchMode::kFComputeEx
*/
using FComputeEx = std::function<void (const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs)>;
/*!
* \brief Register a storage and dispatch mode inference function based on
* storage types of the inputs and outputs, and the dev_mask for the operator.
* dispatch_mode, in_attrs and out_attrs are passed by pointer so the
* function can write to them; the function returns a bool.
* NOTE(review): presumably the return value reports whether inference
* succeeded - confirm against the callers.
*
* \note Register under "FInferStorageType"
*/
using FInferStorageType = std::function<bool (const NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs)>;
/*!
* \brief Register a function that reports the QuantizeType of a node
* (kNone, kMust or kSupport) based on the attrs of the node.
* \note Register under "FQuantizable" for non-quantized operators
* NOTE(review): this comment previously duplicated the FQuantizedOp
* description and key, which does not match this alias's signature; the key
* is assumed to equal the alias name - TODO confirm against op registrations.
*/
using FQuantizable = std::function<QuantizeType (const NodeAttrs& attrs)>;
/*!
* \brief Register a quantized node creation function based on the attrs of the node.
* The function returns the created node as an nnvm::ObjectPtr.
* \note Register under "FQuantizedOp" for non-quantized operators
*/
using FQuantizedOp = std::function<nnvm::ObjectPtr (const NodeAttrs& attrs)>;
/*!
* \brief Register a function to determine if the output of a quantized operator
* needs to be requantized. This is usually used for operators
* taking int8 data types while accumulating in int32, e.g. quantized_conv.
* \note Register under "FNeedRequantize" for non-quantized operators
* NOTE(review): the brief talks about quantized operators while the note says
* "non-quantized operators"; one of the two is likely a copy-paste slip -
* TODO confirm which kind of operator registers this attribute.
*/
using FNeedRequantize = std::function<bool (const NodeAttrs& attrs)>;
/*!
* \brief Register a function to determine if the input of a quantized operator
* needs to be quantized. This is usually used for quantized operators
* which can handle fp32 inputs directly.
* The function is queried per input through index and also receives the
* quantization granularity as a string (passed by value).
* NOTE(review): given the "Avoid" in the name, a true result presumably means
* the input should NOT be quantized - TODO confirm the polarity with callers.
*/
using FAvoidQuantizeInput = std::function<bool (const NodeAttrs& attrs,
const size_t index,
const std::string quantize_granularity)>;
/*!
* \brief Register a function to determine if the input of a quantized operator
* needs to be quantized asymmetrically.
* The function is queried per input through index.
*/
using FNeedAsymQuantizeInput = std::function<bool (const NodeAttrs& attrs,
const size_t index)>;
/*!
* \brief Register a function to determine if the output of a quantized operator
* needs to be dequantized. This is usually used for quantized operators
* which can produce fp32 outputs directly.
* The function is queried per output through index.
* NOTE(review): given the "Avoid" in the name, a true result presumably means
* the output should NOT be dequantized - TODO confirm the polarity with callers.
*/
using FAvoidDequantizeOutput = std::function<bool (const NodeAttrs& attrs,
const size_t index)>;
/*!
* \brief Register a function to determine if the input of a quantized operator
* needs to be calibrated. This is usually used for quantized operators
* which need calibration on their input.
* The function returns a vector of ints.
* NOTE(review): presumably these are the indices of the inputs that need
* calibration - TODO confirm against the callers.
*/
using FNeedCalibrateInput = std::function<std::vector<int> (const NodeAttrs& attrs)>;
/*!
* \brief Register a function to determine if the output of a quantized operator
* needs to be calibrated. This is usually used for quantized operators
* which need calibration on their output.
* The function returns a vector of ints.
* NOTE(review): presumably these are the indices of the outputs that need
* calibration - TODO confirm against the callers.
*/
using FNeedCalibrateOutput = std::function<std::vector<int> (const NodeAttrs& attrs)>;
#if MXNET_USE_CUDA
/*!
* \brief Register a function to determine if
* the operator implementation is compatible
* with CUDA graphs. This requires the execution
* to stay the same as long as the shape and type
* of the input stay the same.
* The function receives the node attributes and whether the operator runs
* in training mode. Only declared when MXNET_USE_CUDA is enabled.
*/
using FIsCUDAGraphsCompatible = std::function<bool (const NodeAttrs& attrs, const bool is_train)>;
#endif
} // namespace mxnet
#endif // MXNET_OP_ATTR_TYPES_H_