blob: af5b23ed65520e7ecd5c5cac90983aec448ba096 [file] [log] [blame]
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/ir_operator.h
* \brief Common operators defined for Expr.
*
* \note Most of the operator defined here perform simple constant folding
* when the type is int32 or int64 for simplifying the index expressions.
*/
#ifndef TVM_IR_OPERATOR_H_
#define TVM_IR_OPERATOR_H_
#include <algorithm>
#include <type_traits>
#include "expr.h"
#include "ir.h"
namespace tvm {
/*!
 * \brief Make a const value with certain data type.
 *
 * For a scalar type the result is an IntImm, UIntImm or FloatImm depending
 * on whether t is an int, uint or float type. For a vector type the scalar
 * constant of t's element type is broadcast to t.lanes() lanes. Any other
 * type kind is a fatal error.
 * \param t The target type.
 * \param value The input value; it is static_cast to the storage type of
 *  the chosen immediate (int64_t, uint64_t or double).
 * \return the result expression.
 * \tparam ValueType The constant value type; restricted to POD types by
 *  the enable_if default template argument.
 */
template<typename ValueType,
typename = typename std::enable_if<std::is_pod<ValueType>::value>::type>
inline Expr make_const(Type t, ValueType value);
/*!
 * \brief Make a const zero expr.
 * \param t The target type.
 * \return the result expression. For handle types this is a UInt(64) zero
 *  reinterpreted as t; for every other type it is make_const(t, 0).
 */
inline Expr make_zero(Type t);
/*!
 * \brief Build the boolean constant `true`.
 *
 * Booleans are represented as UInt(1). When lanes > 1 the scalar constant
 * is broadcast across all lanes.
 * \param lanes Number of vector lanes of the result (default 1).
 * \return The constant expression.
 */
inline Expr const_true(int lanes = 1) {
  const auto bool_type = UInt(1, lanes);
  return make_const(bool_type, 1);
}
/*!
 * \brief Build the boolean constant `false`.
 *
 * Booleans are represented as UInt(1). When lanes > 1 the scalar constant
 * is broadcast across all lanes.
 * \param lanes Number of vector lanes of the result (default 1).
 * \return The constant expression.
 */
inline Expr const_false(int lanes = 1) {
  const auto bool_type = UInt(1, lanes);
  return make_const(bool_type, 0);
}
/*!
* \brief Get x as constant int expression.
* \param x The expression
* \return the address to the int expression,
* return nullptr, if x is not IntImm.
*/
inline const int64_t* as_const_int(const Expr& x) {
if (!x.defined()) return nullptr;
if (const ir::IntImm* op = x.as<ir::IntImm>()) {
return &(op->value);
} else {
return nullptr;
}
}
/*!
* \brief Get x as constant uint expression.
* \param x The expression
* \return the address to the int expression,
* return nullptr, if x is not UIntImm.
*/
inline const uint64_t* as_const_uint(const Expr& x) {
if (!x.defined()) return nullptr;
if (const ir::UIntImm* op = x.as<ir::UIntImm>()) {
return &(op->value);
} else {
return nullptr;
}
}
/*!
 * \brief Check whether x is a constant integer expression with a given value.
 * \param x The input argument
 * \param value the value to be compared against.
 * \return true when x is an IntImm or UIntImm, or a Broadcast of one, whose
 *  value equals value. For UIntImm the comparison is made after casting
 *  value to uint64_t. Float immediates never match.
 */
inline bool is_const_int(const Expr& x, int64_t value);
/*!
 * \brief Check whether stmt is nop.
 * \param stmt The input statement
 * \return true if stmt is undefined, or is an Evaluate whose value
 *  satisfies is_const; false otherwise.
 */
inline bool is_no_op(const Stmt& stmt);
/*!
 * \brief Test whether x is the integer constant 1.
 * \param x The expression to test.
 * \return true when x is an IntImm/UIntImm (or a Broadcast of one) holding 1.
 * \note Float constants never match; only integer types are recognized.
 */
inline bool is_one(const Expr& x) {
  constexpr int64_t kOne = 1;
  return is_const_int(x, kOne);
}
/*!
 * \brief Test whether x is the integer constant 0.
 * \param x The expression to test.
 * \return true when x is an IntImm/UIntImm (or a Broadcast of one) holding 0.
 * \note Float constants never match; only integer types are recognized.
 */
inline bool is_zero(const Expr& x) {
  constexpr int64_t kZero = 0;
  return is_const_int(x, kZero);
}
/*!
 * \brief Check whether x is a constant.
 * \param x The input expression.
 * \note This only return true for integer types: x must be an IntImm, a
 *  UIntImm, or a Broadcast whose value is one of those. Float immediates
 *  are not considered constant here.
 * \return whether x is constant
 */
inline bool is_const(const Expr& x);
/*!
 * \brief Check whether x is a constant power of two
 * If x is power of two, write the power to the shift.
 *
 * \param x The input expression.
 * \param shift The output shift if x is power of two.
 * \return whether x is constant power of two
 */
TVM_DLL bool is_const_power_of_two_integer(const Expr& x, int* shift);
/*!
 * \brief cast value to type.
 *
 * \param t the target type.
 * \param value The value
 * \return The result expression.
 * \note This function may return value unchanged if its type is already t.
 */
TVM_DLL Expr cast(const Type& t, Expr value);
/*!
 * \brief perform reinterpret cast value to type.
 *
 * Unlike cast, this reinterprets the bits of value as type t rather than
 * converting the numeric value.
 * \param t the target type.
 * \param value The value
 * \return The result expression.
 * \note This function may return value unchanged if its type is already t.
 */
TVM_DLL Expr reinterpret(const Type& t, Expr value);
/*!
 * \brief add operator
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL Expr operator+(Expr a, Expr b);
/*!
 * \brief subtraction operator
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL Expr operator-(Expr a, Expr b);
/*!
 * \brief negation.
 *
 * \param a input.
 * \return The result expression.
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL Expr operator-(Expr a);
/*!
 * \brief multiplication operator
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL Expr operator*(Expr a, Expr b);
/*!
 * \brief division operator
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL Expr operator/(Expr a, Expr b);
/*!
 * \brief mod operator
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL Expr operator%(Expr a, Expr b);
/*!
 * \brief left shift operator
 *
 * \param a left operand
 * \param b right operand (shift amount)
 * \return The result expression.
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL Expr operator<<(Expr a, Expr b);
/*!
 * \brief right shift operator
 *
 * \param a left operand
 * \param b right operand (shift amount)
 * \return The result expression.
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL Expr operator>>(Expr a, Expr b);
/*!
 * \brief greater
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL Expr operator>(Expr a, Expr b);
/*!
 * \brief greater_equal
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL Expr operator>=(Expr a, Expr b);
/*!
 * \brief less
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL Expr operator<(Expr a, Expr b);
/*!
 * \brief less_equal
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL Expr operator<=(Expr a, Expr b);
/*!
 * \brief equal
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL Expr operator==(Expr a, Expr b);
/*!
 * \brief not_equal
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL Expr operator!=(Expr a, Expr b);
/*!
 * \brief logical and
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note This operator does eager constant folding.
 */
TVM_DLL Expr operator&&(Expr a, Expr b);
/*!
 * \brief logical or
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note This operator does eager constant folding.
 */
TVM_DLL Expr operator||(Expr a, Expr b);
/*!
 * \brief logical not
 *
 * \param a The input operand.
 * \return The result expression.
 * \note This operator does eager constant folding.
 */
TVM_DLL Expr operator!(Expr a);
/*!
 * \brief take maximum of two values
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL Expr max(Expr a, Expr b);
/*!
 * \brief take minimum of two values
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL Expr min(Expr a, Expr b);
/*!
 * \brief take bitwise and of two values
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL Expr operator&(Expr a, Expr b);
/*!
 * \brief take bitwise or of two values
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL Expr operator|(Expr a, Expr b);
/*!
 * \brief take bitwise xor of two values
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL Expr operator^(Expr a, Expr b);
/*!
 * \brief take bitwise negation of a value
 *
 * \param a the input expression.
 * \return The result expression.
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL Expr operator~(Expr a);
/*!
 * \brief Conditional expression.
 *
 * \param cond The condition
 * \param true_value The value when results are true.
 * \param false_value The value when results are false.
 * \return The result expression.
 * \note this function does eager constant folding for
 * index types (int32, int64) when possible.
 */
TVM_DLL Expr if_then_else(Expr cond, Expr true_value, Expr false_value);
/*!
 * \brief Mark condition as likely.
 * \param cond The condition
 * \return The marked expression.
 */
TVM_DLL Expr likely(Expr cond);
/*!
 * \brief Calculate power(x, y)
 * \param x The left operand (base).
 * \param y The right operand (exponent).
 * \return The result expression.
 */
TVM_DLL Expr pow(Expr x, Expr y);
/*!
 * \brief Calculate absolute value of x.
 * \param x The input data
 *
 * \return The absolute value of input data x
 */
TVM_DLL Expr abs(Expr x);
/*!
 * \brief sum of source expression over axis
 * \param source The source expression.
 * \param axis List of iteration variables that will be used for reduction.
 * \return The reduction expression.
 */
TVM_DLL Expr sum(Expr source, Array<IterVar> axis);
/*!
 * \brief max of source expression over axis
 * \param source The source expression.
 * \param axis List of iteration variables that will be used for reduction.
 * \return The reduction expression.
 */
TVM_DLL Expr max(Expr source, Array<IterVar> axis);
/*!
 * \brief min of source expression over axis
 * \param source The source expression.
 * \param axis List of iteration variables that will be used for reduction.
 * \return The reduction expression.
 */
TVM_DLL Expr min(Expr source, Array<IterVar> axis);
/*!
 * \brief product of source expression over axis
 * \param source The source expression.
 * \param axis List of iteration variables that will be used for reduction.
 * \return The reduction expression.
 */
TVM_DLL Expr prod(Expr source, Array<IterVar> axis);
// Intrinsic operators
/*!
 * \brief Define an inline unary intrinsic wrapper named OpName.
 *
 * The generated function `Expr OpName(Expr x)` builds an ir::Call of kind
 * PureIntrinsic whose name is the stringized OpName and whose result type
 * is x.type().
 * \note The macro body must end at the closing brace with no trailing
 *  line-continuation. A trailing backslash would splice the next source
 *  line (the first instantiation) into the macro definition. That
 *  instantiation would then never be defined, and every later expansion
 *  would emit it as stray text.
 */
#define TVM_DECLARE_INTRIN_UNARY(OpName)                                \
  inline Expr OpName(Expr x) {                                          \
    return ir::Call::make(x.type(), #OpName, {x}, ir::Call::PureIntrinsic); \
  }

TVM_DECLARE_INTRIN_UNARY(exp);
TVM_DECLARE_INTRIN_UNARY(tanh);
TVM_DECLARE_INTRIN_UNARY(sigmoid);
TVM_DECLARE_INTRIN_UNARY(sqrt);
TVM_DECLARE_INTRIN_UNARY(log);
TVM_DECLARE_INTRIN_UNARY(floor);
TVM_DECLARE_INTRIN_UNARY(ceil);
TVM_DECLARE_INTRIN_UNARY(round);
TVM_DECLARE_INTRIN_UNARY(trunc);
TVM_DECLARE_INTRIN_UNARY(popcount);
// Implementation details after this
inline bool is_const(const Expr& x) {
if (x.as<ir::IntImm>() || x.as<ir::UIntImm>()) {
return true;
} else if (const auto* op = x.as<ir::Broadcast>()) {
const Expr& val = op->value;
if (val.as<ir::IntImm>() || val.as<ir::UIntImm>()) {
return true;
}
}
return false;
}
inline bool is_positive_const(const Expr& a) {
if (const ir::IntImm* op = a.as<ir::IntImm>()) {
return op->value > 0;
} else if (const ir::UIntImm* op = a.as<ir::UIntImm>()) {
return op->value > 0;
} else {
return false;
}
}
inline bool is_negative_const(const Expr& a) {
if (const ir::IntImm* op = a.as<ir::IntImm>()) {
return op->value < 0;
} else {
return false;
}
}
inline bool is_const_int(const Expr& x, int64_t value) {
if (const auto* op = x.as<ir::IntImm>()) {
return op->value == value;
} else if (const auto* op = x.as<ir::UIntImm>()) {
return op->value == static_cast<uint64_t>(value);
} else if (const auto* op = x.as<ir::Broadcast>()) {
const Expr& val = op->value;
if (const auto* opv = val.as<ir::IntImm>()) {
return opv->value == value;
} else if (const auto* opv = val.as<ir::UIntImm>()) {
return opv->value == static_cast<uint64_t>(value);
}
}
return false;
}
/*!
 * \brief Test whether stmt has no effect.
 * \param stmt The statement to test; it may be undefined.
 * \return true for an undefined statement or an Evaluate of an is_const
 *  expression; false otherwise.
 */
inline bool is_no_op(const Stmt& stmt) {
  if (stmt.defined()) {
    const auto* eval = stmt.as<ir::Evaluate>();
    return eval != nullptr && is_const(eval->value);
  }
  return true;
}
/*!
 * \brief Build a scalar immediate of type t holding value.
 *
 * int types yield IntImm, uint types yield UIntImm and float types yield
 * FloatImm. Any other type kind is reported via LOG(FATAL). The trailing
 * `return Expr()` only satisfies the compiler (LOG(FATAL) is presumably
 * non-returning).
 * \tparam ValueType Type of the literal being converted.
 */
template<typename ValueType>
inline Expr MakeConstScalar(Type t, ValueType value) {
  if (t.is_int()) {
    return ir::IntImm::make(t, static_cast<int64_t>(value));
  } else if (t.is_uint()) {
    return ir::UIntImm::make(t, static_cast<uint64_t>(value));
  } else if (t.is_float()) {
    return ir::FloatImm::make(t, static_cast<double>(value));
  }
  LOG(FATAL) << "cannot make const for type " << t;
  return Expr();
}
/*!
 * \brief Definition of make_const (see declaration for the contract).
 *
 * Vector types are built by broadcasting the scalar constant of the element
 * type; scalar types delegate straight to MakeConstScalar.
 */
template<typename ValueType, typename>
inline Expr make_const(Type t, ValueType value) {
  if (t.lanes() != 1) {
    return ir::Broadcast::make(
        MakeConstScalar(t.element_of(), value), t.lanes());
  }
  return MakeConstScalar(t, value);
}
/*!
 * \brief Definition of make_zero.
 *
 * Handle types have no immediate of their own, so a UInt(64) zero is
 * reinterpreted as t; other types use make_const(t, 0).
 */
inline Expr make_zero(Type t) {
  return t.is_handle() ? reinterpret(t, make_const(UInt(64), 0))
                       : make_const(t, 0);
}
// additional const expression overloading
/*
 * TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) defines compound assignment
 * `Name` in terms of the binary operator `OpFunc`: the left-hand Expr is
 * rebound to OpFunc(lhs, rhs) and a copy of the new value is returned.
 */
#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc)                      \
  inline Expr Name(Expr& lhs, Expr rhs) {                               \
    lhs = OpFunc(lhs, rhs);                                             \
    return lhs;                                                         \
  }
/*
 * TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name) adds overloads of the binary
 * operator `Name` that accept a float or int literal on either side.
 * Float literals go through the Expr(float) constructor; int literals are
 * materialized with make_const using the type of the Expr operand.
 */
#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name)                       \
  inline Expr Name(const Expr& lhs, float rhs) {                        \
    return Name(lhs, Expr(rhs));                                        \
  }                                                                     \
  inline Expr Name(float lhs, const Expr& rhs) {                        \
    return Name(Expr(lhs), rhs);                                        \
  }                                                                     \
  inline Expr Name(int lhs, const Expr& rhs) {                          \
    return Name(make_const(rhs.type(), lhs), rhs);                      \
  }                                                                     \
  inline Expr Name(const Expr& lhs, int rhs) {                          \
    return Name(lhs, make_const(lhs.type(), rhs));                      \
  }
/*
 * TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) adds overloads of the
 * logical operator `Name` that accept a C++ bool literal on either side.
 *
 * The literal is materialized as a UInt(1) constant, the boolean
 * representation used by const_true/const_false. Its lane count matches
 * the Expr operand, so a vector operand gets a broadcast boolean, as the
 * int-literal overloads do via make_const(a.type(), ...).
 * Going through Expr(bool) instead would convert the literal like an
 * ordinary number and not produce a UInt(1) boolean.
 * NOTE(review): confirm Expr has no bool-specific constructor in expr.h.
 */
#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name)                  \
  inline Expr Name(const Expr& a, bool b) {                             \
    return Name(a, make_const(UInt(1, a.type().lanes()), b));           \
  }                                                                     \
  inline Expr Name(bool a, const Expr& b) {                             \
    return Name(make_const(UInt(1, b.type().lanes()), a), b);           \
  }
/*
 * TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) adds overloads of the
 * integer-only operator `Name` that accept an int literal on either side.
 * The literal is materialized with make_const using the type of the Expr
 * operand.
 */
#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name)                      \
  inline Expr Name(const Expr& lhs, int rhs) {                          \
    return Name(lhs, make_const(lhs.type(), rhs));                      \
  }                                                                     \
  inline Expr Name(int lhs, const Expr& rhs) {                          \
    return Name(make_const(rhs.type(), lhs), rhs);                      \
  }
// Compound assignment: `a op= b` rebinds a to (a op b) and returns the new value.
// Only +=, -=, *= and /= are provided.
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator+=, operator+);
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator-=, operator-);
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator*=, operator*);
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator/=, operator/);
// Arithmetic, min/max and ordering comparisons with an int or float literal
// on either side.
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator+);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator-);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator*);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator/);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(max);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(min);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>); // NOLINT(*)
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>=);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<); // NOLINT(*)
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<=);
// integer related ops: an int literal on either side takes the type of the
// Expr operand.
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator%);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator>>); // NOLINT(*)
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator<<); // NOLINT(*)
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator&);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator|);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator^);
// logical ops: accept a bool literal on either side.
TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator&&);
TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator||);
} // namespace tvm
#endif // TVM_IR_OPERATOR_H_