blob: 6c2c6b2cce6975ec042112a0b4113bd473731524 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file op_common.h
* \brief A set of utilities and common functionality
* for relay ops.
*/
#ifndef TVM_RELAY_OP_OP_COMMON_H_
#define TVM_RELAY_OP_OP_COMMON_H_
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>

#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "../transforms/infer_layout_utils.h"
#include "type_relations.h"
namespace tvm {
namespace relay {
/*!
 * \brief Define the positional "make" function for a unary elementwise op and
 *        register the op itself.
 *
 * The expansion performs two registrations:
 *  - the global function "relay.op._make." OpName, which wraps its single
 *    argument in a Call to the op with empty attributes and no type arguments
 *    (the Op handle is looked up once and cached in a function-local static);
 *  - the op entry, declared with one tensor input named "data", the
 *    "Identity" type relation, the kElemWise pattern, TOpIsStateful = false,
 *    and ElemwiseArbitraryLayout as its layout-inference rule.
 *
 * Only a positional constructor is exposed on purpose; keyword arguments and
 * default values are supplied by rewrapping in the frontend.
 *
 * \param OpName The registry name of the op. It must be a string literal,
 *        because it is concatenated with "relay.op._make." at compile time.
 */
#define RELAY_REGISTER_UNARY_OP(OpName)                                           \
  TVM_REGISTER_GLOBAL("relay.op._make." OpName)                                  \
      .set_body_typed([](Expr input) -> Call {                                   \
        static const Op& unary_op = Op::Get(OpName);                             \
        return Call(unary_op, {input}, Attrs(), {});                             \
      });                                                                        \
  RELAY_REGISTER_OP(OpName)                                                      \
      .set_num_inputs(1)                                                         \
      .add_argument("data", "Tensor", "The input tensor.")                       \
      .add_type_rel("Identity", IdentityRel)                                     \
      .set_attr<TOpPattern>("TOpPattern", kElemWise)                             \
      .set_attr<TOpIsStateful>("TOpIsStateful", false)                           \
      .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
/*!
 * \brief Define the positional "make" function for a broadcasting binary op
 *        and register the op itself.
 *
 * The expansion performs two registrations:
 *  - the global function "relay.op._make." OpName, which wraps its two
 *    arguments (left, then right) in a Call to the op with empty attributes
 *    and no type arguments (the Op handle is cached in a function-local
 *    static on first use);
 *  - the op entry, declared with tensor inputs "lhs" and "rhs", the
 *    "Broadcast" type relation, the kBroadcast pattern, TOpIsStateful = false,
 *    and BinaryBroadcastLayout as its layout-inference rule.
 *
 * Only a positional constructor is exposed on purpose; keyword arguments and
 * default values are supplied by rewrapping in the frontend.
 *
 * \param OpName The registry name of the op. It must be a string literal,
 *        because it is concatenated with "relay.op._make." at compile time.
 */
#define RELAY_REGISTER_BINARY_OP(OpName)                                          \
  TVM_REGISTER_GLOBAL("relay.op._make." OpName)                                  \
      .set_body_typed([](Expr left, Expr right) -> Call {                        \
        static const Op& binary_op = Op::Get(OpName);                            \
        return Call(binary_op, {left, right}, Attrs(), {});                      \
      });                                                                        \
  RELAY_REGISTER_OP(OpName)                                                      \
      .set_num_inputs(2)                                                         \
      .add_argument("lhs", "Tensor", "The left hand side tensor.")               \
      .add_argument("rhs", "Tensor", "The right hand side tensor.")              \
      .add_type_rel("Broadcast", BroadcastRel)                                   \
      .set_attr<TOpPattern>("TOpPattern", kBroadcast)                            \
      .set_attr<TOpIsStateful>("TOpIsStateful", false)                           \
      .set_attr<FInferCorrectLayout>("FInferCorrectLayout", BinaryBroadcastLayout)
/*!
 * \brief Define the positional "make" function for a broadcasting comparison
 *        op and register the op itself.
 *
 * Identical in shape to RELAY_REGISTER_BINARY_OP, except that the op uses the
 * "BroadcastComp" type relation (BroadcastCompRel) instead of "Broadcast".
 * The expansion registers:
 *  - the global function "relay.op._make." OpName, building a Call to the op
 *    from (left, right) with empty attributes and no type arguments;
 *  - the op entry with tensor inputs "lhs" and "rhs", the kBroadcast pattern,
 *    TOpIsStateful = false, and BinaryBroadcastLayout for layout inference.
 *
 * \param OpName The registry name of the op, as a string literal.
 */
#define RELAY_REGISTER_CMP_OP(OpName)                                             \
  TVM_REGISTER_GLOBAL("relay.op._make." OpName)                                  \
      .set_body_typed([](Expr left, Expr right) -> Call {                        \
        static const Op& cmp_op = Op::Get(OpName);                               \
        return Call(cmp_op, {left, right}, Attrs(), {});                         \
      });                                                                        \
  RELAY_REGISTER_OP(OpName)                                                      \
      .set_num_inputs(2)                                                         \
      .add_argument("lhs", "Tensor", "The left hand side tensor.")               \
      .add_argument("rhs", "Tensor", "The right hand side tensor.")              \
      .add_type_rel("BroadcastComp", BroadcastCompRel)                           \
      .set_attr<TOpPattern>("TOpPattern", kBroadcast)                            \
      .set_attr<TOpIsStateful>("TOpIsStateful", false)                           \
      .set_attr<FInferCorrectLayout>("FInferCorrectLayout", BinaryBroadcastLayout)
/*! \brief A helper class for matching and rewriting operators. */
template <typename R>
class OpMatch {
public:
using MatchFunc =
std::function<R(const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_args)>;
/*! \brief Match an operator with the given name.
* \param op_name The name of the operator to match.
* \param func The function to execute when it matches.
* \return A self-reference for builder style API.
*/
inline OpMatch& Match(const std::string& op_name, MatchFunc func) {
auto op = Op::Get(op_name);
match_map_.insert({op, func});
return *this;
}
/*! \brief Rewrite a call operation based on the operator and the registered
* match functions.
* \param call The call to rewrite.
* \return The result of rewriting.
*/
inline R operator()(const Call& call) {
auto it = match_map_.find(Downcast<Op>(call->op));
if (it != match_map_.end()) {
return it->second(call->args, call->attrs, call->type_args);
} else {
if (default_ != nullptr) {
return default_(call->args, call->attrs, call->type_args);
} else {
LOG(FATAL) << "unexpected operation " << call->op;
}
}
}
private:
/*! \brief The match function map. */
std::unordered_map<Op, MatchFunc, ObjectPtrHash, ObjectPtrEqual> match_map_;
/*! \brief An optional default case. */
MatchFunc default_;
};
/*! \brief A utility function to get padding width from a 1 or 2 ints tuple. */
inline void GetPaddingWidth(const Array<IndexExpr>& padding, IndexExpr* pad_w) {
if (padding.size() == 1) {
*pad_w = padding[0] * 2;
} else if (padding.size() == 2) {
*pad_w = padding[0] + padding[1];
} else {
ICHECK_EQ(padding.size(), 4) << " Expected padding size of 1 or 2, found " << padding.size();
}
}
/*! \brief A utility function to get padding height and width from a 1, 2, 4 ints tuple. */
inline void GetPaddingHeightWidth(const Array<IndexExpr>& padding, IndexExpr* pad_h,
IndexExpr* pad_w) {
if (padding.size() == 1) {
*pad_h = padding[0] * 2;
*pad_w = padding[0] * 2;
} else if (padding.size() == 2) {
*pad_h = padding[0] * 2;
*pad_w = padding[1] * 2;
} else if (padding.size() == 4) {
*pad_h = padding[0] + padding[2];
*pad_w = padding[1] + padding[3];
} else {
ICHECK_EQ(padding.size(), 4) << " Padding size should be 1, 2 or 4, but got " << padding.size();
}
}
/*! \brief A utility function to get padding depth, height and width from a 1, 3, 6 ints tuple. */
inline void GetPaddingDepthHeightWidth(const Array<IndexExpr>& padding, IndexExpr* pad_d,
IndexExpr* pad_h, IndexExpr* pad_w) {
if (padding.size() == 1) {
*pad_d = padding[0] * 2;
*pad_h = padding[0] * 2;
*pad_w = padding[0] * 2;
} else if (padding.size() == 3) {
*pad_d = padding[0] * 2;
*pad_h = padding[1] * 2;
*pad_w = padding[2] * 2;
} else if (padding.size() == 6) {
*pad_d = padding[0] + padding[3];
*pad_h = padding[1] + padding[4];
*pad_w = padding[2] + padding[5];
} else {
ICHECK_EQ(padding.size(), 6) << " Padding size should be 1, 3 or 6, but got " << padding.size();
}
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_OP_COMMON_H_