/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file nnvm/op_attr_types.h
* \brief Data structures that can appear in operator attributes.
*/
#ifndef NNVM_OP_ATTR_TYPES_H_
#define NNVM_OP_ATTR_TYPES_H_
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "base.h"
#include "layout.h"
#include "node.h"
#include "tuple.h"
namespace nnvm {
// These types are optional attributes in each operator.
// Each attribute can be required by some passes.
/*!
 * \brief Return the list of input argument names of the operator.
*
* \param attrs The attributes of the node.
* \return list of inputs
* \note Register under "FListInputNames", default return {"data"}.
*
* FListInputNames enables automatic variable creation for missing arguments.
*/
using FListInputNames = std::function<std::vector<std::string>(const NodeAttrs& attrs)>;
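// A minimal illustrative sketch (not part of this header): registering
// FListInputNames through nnvm's Op registry. The operator name "my_conv2d"
// and its argument names are hypothetical.
//
//   NNVM_REGISTER_OP(my_conv2d)
//   .set_num_inputs(3)
//   .set_attr<FListInputNames>("FListInputNames",
//     [](const NodeAttrs& attrs) {
//       return std::vector<std::string>{"data", "weight", "bias"};
//     });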
/*!
 * \brief Return the number of outputs visible to the user.
*
* \param attrs The attributes of the node.
*
* \note Register under "FNumVisibleOutputs", default not registered.
 * This can be used to hide certain outputs from the user,
 * while the hidden outputs can still pass information from
 * the forward pass to the gradient pass.
*/
using FNumVisibleOutputs = std::function<uint32_t(const NodeAttrs& attrs)>;
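// Illustrative sketch (hypothetical operator "my_batch_norm"): an operator with
// three internal outputs that exposes only the first one to the user; the
// hidden outputs remain available to the gradient pass.
//
//   NNVM_REGISTER_OP(my_batch_norm)
//   .set_num_outputs(3)
//   .set_attr<FNumVisibleOutputs>("FNumVisibleOutputs",
//     [](const NodeAttrs& attrs) { return 1u; });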
/*!
 * \brief Return the list of output argument names of the operator.
*
* \param attrs The attributes of the node.
 * \return list of outputs
* \note Register under "FListOutputNames", default return {"outputs"}.
*
 * FListOutputNames enables customized naming of operator outputs.
*/
using FListOutputNames = std::function<std::vector<std::string>(const NodeAttrs& attrs)>;
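// Illustrative sketch (hypothetical operator "my_moments"): naming the outputs
// so they can be retrieved by name instead of the default "outputs".
//
//   NNVM_REGISTER_OP(my_moments)
//   .set_num_outputs(2)
//   .set_attr<FListOutputNames>("FListOutputNames",
//     [](const NodeAttrs& attrs) {
//       return std::vector<std::string>{"mean", "var"};
//     });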
/*!
 * \brief Get the list of inputs that the operator will mutate in place.
* \param attrs The attributes of the node.
* \return list of input indices it mutates.
*
* \note Register under "FMutateInputs", default return false
* FMutateInputs enables mutation order handling correctly.
*/
using FMutateInputs = std::function<std::vector<uint32_t>(const NodeAttrs& attrs)>;
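// Illustrative sketch (hypothetical operator "my_assign"): declaring that
// input 0 is written in place, so passes can order its reads and writes safely.
//
//   NNVM_REGISTER_OP(my_assign)
//   .set_attr<FMutateInputs>("FMutateInputs",
//     [](const NodeAttrs& attrs) {
//       return std::vector<uint32_t>{0};  // input 0 is mutated
//     });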
/*!
* \brief Inference function of certain type.
 * \tparam AttrType The type of the attribute to be inferred.
* \return whether all attributes are inferred.
*/
template <typename AttrType>
using FInferNodeEntryAttr = std::function<bool(
const NodeAttrs& attrs, std::vector<AttrType>* in_attrs, std::vector<AttrType>* out_attrs)>;
/*!
* \brief Get attribute dictionary from node.
*
* \param attrs The attributes of the node.
* \return The attribute dict.
* \note Register under "FUpdateAttrDict"
*/
using FGetAttrDict =
std::function<std::unordered_map<std::string, std::string>(const NodeAttrs& attrs)>;
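// Illustrative sketch (hypothetical operator "my_op"): the common case is to
// expose the node's parsed keyword arguments, which NodeAttrs already stores
// as a string dictionary.
//
//   NNVM_REGISTER_OP(my_op)
//   .set_attr<FGetAttrDict>("FGetAttrDict",
//     [](const NodeAttrs& attrs) { return attrs.dict; });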
/*!
* \brief Shape inference function.
* Update the shapes given the input shape information.
* TShape.ndim() == 0 means the shape is still unknown.
*
* \note Register under "FInferShape",
* by default do not update any shapes.
*
 * FInferShape is needed by the shape inference pass.
*/
using FInferShape = FInferNodeEntryAttr<TShape>;
/*!
* \brief Type inference function.
* Update the type given the known type information.
*
* \note Register under "FInferType",
* by default set all the output types to 0.
*/
using FInferType = FInferNodeEntryAttr<int>;
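// Illustrative sketch covering both FInferShape and FInferType (hypothetical
// elementwise operator "my_relu"): the output simply inherits the input's shape
// and dtype. Using -1 as the "unknown dtype" marker is an assumption borrowed
// from common nnvm frontends, not something mandated by this header.
//
//   NNVM_REGISTER_OP(my_relu)
//   .set_attr<FInferShape>("FInferShape",
//     [](const NodeAttrs& attrs, std::vector<TShape>* in, std::vector<TShape>* out) {
//       if ((*in)[0].ndim() == 0) return false;  // input shape still unknown
//       (*out)[0] = (*in)[0];                    // output shares the input shape
//       return true;
//     })
//   .set_attr<FInferType>("FInferType",
//     [](const NodeAttrs& attrs, std::vector<int>* in, std::vector<int>* out) {
//       if ((*in)[0] == -1) return false;        // dtype not known yet
//       (*out)[0] = (*in)[0];                    // output shares the input dtype
//       return true;
//     });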
/*!
 * \brief Whether this op is an explicit backward operator.
* If TIsBackward is true:
* - The first control_deps of the node points to the corresponding forward operator.
*
* \note Register under "TIsBackward"
* This enables easier shape/type inference for backward operators.
*/
using TIsBackward = bool;
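// Illustrative sketch (hypothetical operator "_backward_my_conv2d"): marker
// attributes such as TIsBackward are registered as plain values rather than
// functions.
//
//   NNVM_REGISTER_OP(_backward_my_conv2d)
//   .set_attr<TIsBackward>("TIsBackward", true);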
/*!
* \brief Whether this op is a ghost node.
* If TIsGhost is true:
* - The node with this op will not be visible in the indexed graph.
*
* \note Register under "TIsGhost"
* This enables shape/type inference for backward nodes when
* fusion is present.
*/
using TIsGhost = bool;
/*!
* \brief Get possible inplace options.
* This function enables optimization to reuse memory of inputs in output.
* \param attrs The attributes of the node
 * \return list of pairs that map input index to output index,
 * indicating possible in-place operations.
*
* \note Register under "FInplaceOption", by default no inplace can happen.
*/
using FInplaceOption = std::function<std::vector<std::pair<int, int> >(const NodeAttrs& attrs)>;
/*!
 * \brief Get whether each inplace option is an identity.
* This function enables inplace optimization even when input reference count
* is greater than one.
* \param attrs The attributes of the node
 * \return list of bools indicating whether the corresponding pair from FInplaceOption
 * is an identity.
*
* \note Register under "FInplaceIdentity", by default no identities.
*/
using FInplaceIdentity = std::function<std::vector<bool>(const NodeAttrs& attrs)>;
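// Illustrative sketch covering FInplaceOption together with FInplaceIdentity
// (hypothetical operator "my_reshape"): output 0 may reuse the memory of
// input 0, and that reuse is a pure identity because no bytes change.
//
//   NNVM_REGISTER_OP(my_reshape)
//   .set_attr<FInplaceOption>("FInplaceOption",
//     [](const NodeAttrs& attrs) {
//       return std::vector<std::pair<int, int> >{{0, 0}};  // input 0 -> output 0
//     })
//   .set_attr<FInplaceIdentity>("FInplaceIdentity",
//     [](const NodeAttrs& attrs) {
//       return std::vector<bool>{true};  // the single option above is an identity
//     });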
/*!
 * \brief Get the list of inputs whose content is not actually used by the operator.
 * These are dummy inputs, used for example by zeros_like and ones_like.
*
* \param attrs The attributes of the node
 * \return list of input indices that are not used by the operator.
*
* \note Register under "FIgnoreInputs".
*/
using FIgnoreInputs = std::function<std::vector<uint32_t>(const NodeAttrs& attrs)>;
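// Illustrative sketch (hypothetical operator "my_zeros_like"): the input only
// supplies shape/type information, so its content can be ignored.
//
//   NNVM_REGISTER_OP(my_zeros_like)
//   .set_attr<FIgnoreInputs>("FIgnoreInputs",
//     [](const NodeAttrs& attrs) {
//       return std::vector<uint32_t>{0};  // input 0's data is never read
//     });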
/*!
 * \brief Get the gradient node of the op node.
 * This function generates the backward graph of the node.
 * \param nodeptr The node to take the gradient of
 * \param out_grads Gradients of the current node's outputs
* \return gradients of the inputs
*
* \note Register under "FGradient"
*/
using FGradient = std::function<std::vector<NodeEntry>(const ObjectPtr& nodeptr,
const std::vector<NodeEntry>& out_grads)>;
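// Illustrative sketch (hypothetical operator "my_identity"): for an identity
// op the gradient of the input is just the incoming output gradient, so the
// backward graph needs no extra nodes.
//
//   NNVM_REGISTER_OP(my_identity)
//   .set_attr<FGradient>("FGradient",
//     [](const ObjectPtr& nodeptr, const std::vector<NodeEntry>& out_grads) {
//       return std::vector<NodeEntry>{out_grads[0]};  // pass the gradient through
//     });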
/*!
 * \brief Set the attributes of an input variable.
* Usually used for setting initialization or weight decay.
* \param attrs The attributes of this node.
 * \param var The input variable
 * \param index The index of var among all inputs
 *
 * \note Register under "FSetInputVarAttrOnCompose"
*/
using FSetInputVarAttrOnCompose =
std::function<void(const NodeAttrs& attrs, ObjectPtr var, const int index)>;
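// Illustrative sketch (hypothetical operator "my_batch_norm"; the attribute key
// "__hypothetical_init__" and its value are made up for illustration): tag the
// variable bound to input 2 with a default initializer when composing.
//
//   NNVM_REGISTER_OP(my_batch_norm)
//   .set_attr<FSetInputVarAttrOnCompose>("FSetInputVarAttrOnCompose",
//     [](const NodeAttrs& attrs, ObjectPtr var, const int index) {
//       if (index == 2 && var->attrs.dict.count("__hypothetical_init__") == 0) {
//         var->attrs.dict["__hypothetical_init__"] = "zeros";
//       }
//     });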
/*!
 * \brief Infer & correct function of node layout. See \p Layout for the layout convention.
* \param attrs The attribute of the node.
* \param ilayouts Given the input layouts produced by ancestor nodes,
* it should be filled by layouts that the node requests.
* If the requested layout is different from what ancestor produces,
* a __layout_transform__ operator will be inserted automatically.
* \param last_ilayouts The input layouts requested by the node
* at the last infer pass (if any).
* This can be useful when an operator wants to keep
* the input layout the same as the original one.
* For example, after the pass of AlterOpLayout,
 * transpose(input, axis=[1, 2, 3, 0]) may receive an input in NCHW16c layout,
 * which it cannot operate on with axis=[1, 2, 3, 0].
 * The last input layouts let it know which layout it originally inferred,
* i.e., the layout in the imported model.
* \param olayouts Inferred output layouts.
* \return success flag.
*/
using FCorrectLayout =
std::function<bool(const NodeAttrs& attrs, std::vector<Layout>* ilayouts,
const std::vector<Layout>* last_ilayouts, std::vector<Layout>* olayouts)>;
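// Illustrative sketch (hypothetical operator "my_elemwise_add"): request every
// input in the layout of the first input and propagate that layout to all
// outputs; last_ilayouts is ignored in this simple case.
//
//   NNVM_REGISTER_OP(my_elemwise_add)
//   .set_attr<FCorrectLayout>("FCorrectLayout",
//     [](const NodeAttrs& attrs, std::vector<Layout>* ilayouts,
//        const std::vector<Layout>* last_ilayouts, std::vector<Layout>* olayouts) {
//       const Layout& in = (*ilayouts)[0];
//       for (size_t i = 1; i < ilayouts->size(); ++i) (*ilayouts)[i] = in;
//       for (size_t i = 0; i < olayouts->size(); ++i) (*olayouts)[i] = in;
//       return true;
//     });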
/*!
* \brief Get a list of inputs that represent graphs instead of data.
* Normally, input symbols are considered as data to the operator. However,
* control flow operators and high-order functions need to interpret symbols
* as graphs.
* \param attrs The attributes of this node.
 * \return a list of input indices that are interpreted as symbols by the operator.
*
* \note Register under "FInputGraph".
*/
using FInputGraph = std::function<std::vector<uint32_t>(const NodeAttrs& attrs)>;
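// Illustrative sketch (hypothetical control-flow operator "my_foreach"): input 0
// is interpreted as a subgraph symbol rather than as data.
//
//   NNVM_REGISTER_OP(my_foreach)
//   .set_attr<FInputGraph>("FInputGraph",
//     [](const NodeAttrs& attrs) {
//       return std::vector<uint32_t>{0};  // input 0 carries a graph, not a tensor
//     });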
} // namespace nnvm
#endif // NNVM_OP_ATTR_TYPES_H_