| /*! |
| * Copyright (c) 2018 by Contributors |
| * \file tvm/ir_operator.h |
| * \brief Common operators defined for Expr. |
| * |
| * \note Most of the operator defined here perform simple constant folding |
| * when the type is int32 or int64 for simplifying the index expressions. |
| */ |
| #ifndef TVM_IR_OPERATOR_H_ |
| #define TVM_IR_OPERATOR_H_ |
| |
| #include <algorithm> |
| #include <type_traits> |
| #include "expr.h" |
| #include "ir.h" |
| |
| namespace tvm { |
/*!
 * \brief Make a const value with certain data type.
 * \param t The target type.
 * \param value The input value
 * \return the result expression.
 * \tparam ValueType The constant value type; restricted to POD types
 *         through enable_if.
 * \note When t has more than one lane, a scalar constant of t.element_of()
 *       is created and broadcast to t.lanes() lanes.
 */
template<typename ValueType,
         typename = typename std::enable_if<std::is_pod<ValueType>::value>::type>
inline Expr make_const(Type t, ValueType value);
/*!
 * \brief Make a const zero expr.
 * \param t The target type.
 * \return the result expression.
 * \note For handle types the zero is a reinterpret of a UInt(64) zero
 *       constant to t.
 */
inline Expr make_zero(Type t);
| /*! |
| * \brief Make a constant true expression. |
| * \param lanes The number of lanes in the bool |
| * \return The result expression. |
| */ |
| inline Expr const_true(int lanes = 1) { |
| return make_const(UInt(1, lanes), 1); |
| } |
| /*! |
| * \brief Make a constant false expression. |
| * \param lanes The number of lanes in the bool |
| * \return The result expression. |
| */ |
| inline Expr const_false(int lanes = 1) { |
| return make_const(UInt(1, lanes), 0); |
| } |
| /*! |
| * \brief Get x as constant int expression. |
| * \param x The expression |
| * \return the address to the int expression, |
| * return nullptr, if x is not IntImm. |
| */ |
| inline const int64_t* as_const_int(const Expr& x) { |
| if (!x.defined()) return nullptr; |
| if (const ir::IntImm* op = x.as<ir::IntImm>()) { |
| return &(op->value); |
| } else { |
| return nullptr; |
| } |
| } |
| |
| /*! |
| * \brief Get x as constant uint expression. |
| * \param x The expression |
| * \return the address to the int expression, |
| * return nullptr, if x is not UIntImm. |
| */ |
| inline const uint64_t* as_const_uint(const Expr& x) { |
| if (!x.defined()) return nullptr; |
| if (const ir::UIntImm* op = x.as<ir::UIntImm>()) { |
| return &(op->value); |
| } else { |
| return nullptr; |
| } |
| } |
| |
/*!
 * \brief Check whether x is a constant integer expression equal to value.
 * \param x The input argument
 * \param value the value to be compared against.
 * \return whether x is an IntImm, a UIntImm, or a Broadcast of either,
 *         whose stored value equals value.
 * \note For UIntImm the comparison is made after converting value to
 *       uint64_t.
 */
inline bool is_const_int(const Expr& x, int64_t value);

/*!
 * \brief Check whether stmt is nop.
 * \param stmt The input statement
 * \return whether stmt is nop: true when stmt is undefined, or when it is
 *         an Evaluate node whose value satisfies is_const.
 */
inline bool is_no_op(const Stmt& stmt);
| |
| /*! |
| * \brief Check whether x is a constant integer 1 |
| * \param x The input argument. |
| * \note This only return true for integer types. |
| * \return whether x is constant 1 |
| */ |
| inline bool is_one(const Expr& x) { |
| return is_const_int(x, 1); |
| } |
| |
| /*! |
| * \brief Check whether x is a constant integer 0 |
| * \param x The input argument |
| * \return whether x is constant 0 |
| * \note This only return true for integer types. |
| */ |
| inline bool is_zero(const Expr& x) { |
| return is_const_int(x, 0); |
| } |
| |
/*!
 * \brief Check whether x is a constant.
 * \param x The input argument
 * \note This only return true for integer types.
 * \return whether x is an IntImm or UIntImm, or a Broadcast whose value
 *         is an IntImm or UIntImm.
 */
inline bool is_const(const Expr& x);

/*!
 * \brief Check whether x is a constant power of two
 * If x is a power of two, write the power to the shift.
 *
 * \param x The input expression.
 * \param shift The output shift if x is a power of two.
 * \return whether x is a constant power of two
 */
TVM_DLL bool is_const_power_of_two_integer(const Expr& x, int* shift);
| |
/*!
 * \brief Cast value to type t.
 *
 * \param t the target type.
 * \param value The value
 * \return The result expression.
 * \note This function may return value if the type is the same.
 */
TVM_DLL Expr cast(const Type& t, Expr value);
/*!
 * \brief Perform a reinterpret cast of value to type t.
 *
 * \param t the target type.
 * \param value The value
 * \return The result expression.
 * \note This function may return value if the type is the same.
 */
TVM_DLL Expr reinterpret(const Type& t, Expr value);
/*!
 * \brief add operator
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL Expr operator+(Expr a, Expr b);
/*!
 * \brief subtraction operator
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL Expr operator-(Expr a, Expr b);
/*!
 * \brief negation.
 *
 * \param a input.
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL Expr operator-(Expr a);
/*!
 * \brief multiplication operator
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL Expr operator*(Expr a, Expr b);
/*!
 * \brief division operator
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL Expr operator/(Expr a, Expr b);
/*!
 * \brief mod operator
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL Expr operator%(Expr a, Expr b);
/*!
 * \brief left shift operator
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL Expr operator<<(Expr a, Expr b);
/*!
 * \brief right shift operator
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL Expr operator>>(Expr a, Expr b);
/*!
 * \brief greater than
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL Expr operator>(Expr a, Expr b);
/*!
 * \brief greater than or equal
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL Expr operator>=(Expr a, Expr b);
/*!
 * \brief less than
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL Expr operator<(Expr a, Expr b);
/*!
 * \brief less than or equal
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL Expr operator<=(Expr a, Expr b);
/*!
 * \brief equal
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL Expr operator==(Expr a, Expr b);
/*!
 * \brief not equal
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL Expr operator!=(Expr a, Expr b);
/*!
 * \brief logical and
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note This operator does eager constant folding.
 */
TVM_DLL Expr operator&&(Expr a, Expr b);
/*!
 * \brief logical or
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note This operator does eager constant folding.
 */
TVM_DLL Expr operator||(Expr a, Expr b);
/*!
 * \brief logical not
 *
 * \param a The operand.
 * \return The result expression.
 * \note This operator does eager constant folding.
 */
TVM_DLL Expr operator!(Expr a);
/*!
 * \brief take maximum of two values
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL Expr max(Expr a, Expr b);
/*!
 * \brief take minimum of two values
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL Expr min(Expr a, Expr b);
/*!
 * \brief take bitwise and of two values
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL Expr operator&(Expr a, Expr b);
/*!
 * \brief take bitwise or of two values
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL Expr operator|(Expr a, Expr b);
/*!
 * \brief take bitwise xor of two values
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL Expr operator^(Expr a, Expr b);
/*!
 * \brief take bitwise negation of a single value
 *
 * \param a the input expression.
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL Expr operator~(Expr a);
/*!
 * \brief Conditional expression.
 *
 * \param cond The condition
 * \param true_value The value when results are true.
 * \param false_value The value when results are false.
 * \return The result expression.
 * \note this function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL Expr if_then_else(Expr cond, Expr true_value, Expr false_value);
/*!
 * \brief Mark condition as likely.
 * \param cond The condition
 * \return The marked expression.
 */
TVM_DLL Expr likely(Expr cond);
/*!
 * \brief Calculate power(x, y)
 * \param x The left operand.
 * \param y The right operand.
 * \return The result expression.
 */
TVM_DLL Expr pow(Expr x, Expr y);
/*!
 * \brief Calculate absolute value of x.
 * \param x The input data
 *
 * \return The absolute value of input data x
 */
TVM_DLL Expr abs(Expr x);
| |
/*!
 * \brief sum of source expression over axis
 * \param source The source expression.
 * \param axis List of iteration variables that will be used for reduction.
 * \return The result expression.
 */
TVM_DLL Expr sum(Expr source, Array<IterVar> axis);

/*!
 * \brief max of source expression over axis
 * \param source The source expression.
 * \param axis List of iteration variables that will be used for reduction.
 * \return The result expression.
 */
TVM_DLL Expr max(Expr source, Array<IterVar> axis);

/*!
 * \brief min of source expression over axis
 * \param source The source expression.
 * \param axis List of iteration variables that will be used for reduction.
 * \return The result expression.
 */
TVM_DLL Expr min(Expr source, Array<IterVar> axis);

/*!
 * \brief product of source expression over axis
 * \param source The source expression.
 * \param axis List of iteration variables that will be used for reduction.
 * \return The result expression.
 */
TVM_DLL Expr prod(Expr source, Array<IterVar> axis);
| |
| // Intrinsic operators |
| #define TVM_DECLARE_INTRIN_UNARY(OpName) \ |
| inline Expr OpName(Expr x) { \ |
| return ir::Call::make(x.type(), #OpName, {x}, ir::Call::PureIntrinsic); \ |
| } \ |
| |
| TVM_DECLARE_INTRIN_UNARY(exp); |
| TVM_DECLARE_INTRIN_UNARY(tanh); |
| TVM_DECLARE_INTRIN_UNARY(sigmoid); |
| TVM_DECLARE_INTRIN_UNARY(sqrt); |
| TVM_DECLARE_INTRIN_UNARY(log); |
| TVM_DECLARE_INTRIN_UNARY(floor); |
| TVM_DECLARE_INTRIN_UNARY(ceil); |
| TVM_DECLARE_INTRIN_UNARY(round); |
| TVM_DECLARE_INTRIN_UNARY(trunc); |
| TVM_DECLARE_INTRIN_UNARY(popcount); |
| |
| |
| // Implementation details after this |
| inline bool is_const(const Expr& x) { |
| if (x.as<ir::IntImm>() || x.as<ir::UIntImm>()) { |
| return true; |
| } else if (const auto* op = x.as<ir::Broadcast>()) { |
| const Expr& val = op->value; |
| if (val.as<ir::IntImm>() || val.as<ir::UIntImm>()) { |
| return true; |
| } |
| } |
| return false; |
| } |
| |
| inline bool is_positive_const(const Expr& a) { |
| if (const ir::IntImm* op = a.as<ir::IntImm>()) { |
| return op->value > 0; |
| } else if (const ir::UIntImm* op = a.as<ir::UIntImm>()) { |
| return op->value > 0; |
| } else { |
| return false; |
| } |
| } |
| |
| inline bool is_negative_const(const Expr& a) { |
| if (const ir::IntImm* op = a.as<ir::IntImm>()) { |
| return op->value < 0; |
| } else { |
| return false; |
| } |
| } |
| |
| inline bool is_const_int(const Expr& x, int64_t value) { |
| if (const auto* op = x.as<ir::IntImm>()) { |
| return op->value == value; |
| } else if (const auto* op = x.as<ir::UIntImm>()) { |
| return op->value == static_cast<uint64_t>(value); |
| } else if (const auto* op = x.as<ir::Broadcast>()) { |
| const Expr& val = op->value; |
| if (const auto* opv = val.as<ir::IntImm>()) { |
| return opv->value == value; |
| } else if (const auto* opv = val.as<ir::UIntImm>()) { |
| return opv->value == static_cast<uint64_t>(value); |
| } |
| } |
| return false; |
| } |
| |
| inline bool is_no_op(const Stmt& stmt) { |
| if (!stmt.defined()) return true; |
| if (const auto* op = stmt.as<ir::Evaluate>()) { |
| return is_const(op->value); |
| } |
| return false; |
| } |
| |
| template<typename ValueType> |
| inline Expr MakeConstScalar(Type t, ValueType value) { |
| if (t.is_int()) return ir::IntImm::make(t, static_cast<int64_t>(value)); |
| if (t.is_uint()) return ir::UIntImm::make(t, static_cast<uint64_t>(value)); |
| if (t.is_float()) return ir::FloatImm::make(t, static_cast<double>(value)); |
| LOG(FATAL) << "cannot make const for type " << t; |
| return Expr(); |
| } |
| |
| template<typename ValueType, typename> |
| inline Expr make_const(Type t, ValueType value) { |
| if (t.lanes() == 1) { |
| return MakeConstScalar(t, value); |
| } else { |
| return ir::Broadcast::make( |
| MakeConstScalar(t.element_of(), value), t.lanes()); |
| } |
| } |
| |
| inline Expr make_zero(Type t) { |
| if (t.is_handle()) { |
| return reinterpret(t, make_const(UInt(64), 0)); |
| } |
| return make_const(t, 0); |
| } |
| |
// additional const expression overloading

/*!
 * \brief Define compound assignment operator Name (e.g. operator+=) in terms
 *        of the binary operator OpFunc, i.e. a = OpFunc(a, b).
 * \note The generated operator returns a reference to the updated left
 *       operand, following the usual C++ convention for compound assignment.
 *       This avoids copying the handle and lets chained forms such as
 *       (a += b) *= c compile and update a itself.
 */
#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc)                         \
  inline Expr& Name(Expr& a, Expr b) {                                      \
    a = OpFunc(a, b);                                                       \
    return a;                                                               \
  }
| |
/*!
 * \brief Define mixed scalar/Expr overloads of binary operator Name.
 *        An int operand is turned into a constant of the other operand's
 *        type via make_const; a float operand goes through Expr(float).
 */
#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name)                           \
  inline Expr Name(const Expr& a, int b) {                                  \
    return Name(a, make_const(a.type(), b));                                \
  }                                                                         \
  inline Expr Name(int a, const Expr& b) {                                  \
    return Name(make_const(b.type(), a), b);                                \
  }                                                                         \
  inline Expr Name(const Expr& a, float b) {                                \
    return Name(a, Expr(b));                                                \
  }                                                                         \
  inline Expr Name(float a, const Expr& b) {                                \
    return Name(Expr(a), b);                                                \
  }
| |
/*!
 * \brief Define mixed bool/Expr overloads of logical operator Name.
 *        The bool operand becomes a UInt(1) constant (the type const_true /
 *        const_false use) whose lane count matches the Expr operand. This
 *        mirrors how the int overloads take their type from the Expr operand.
 * \note make_const is used instead of Expr(b): no bool constructor of Expr
 *       is visible here, so Expr(b) would presumably resolve to an integer
 *       constructor and yield a non-boolean immediate. Verify against expr.h.
 */
#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name)                      \
  inline Expr Name(const Expr& a, bool b) {                                 \
    return Name(a, make_const(UInt(1, a.type().lanes()), b));               \
  }                                                                         \
  inline Expr Name(bool a, const Expr& b) {                                 \
    return Name(make_const(UInt(1, b.type().lanes()), a), b);               \
  }
| |
/*!
 * \brief Define int/Expr overloads of an integer-only binary operator Name.
 *        The int operand becomes a constant of the Expr operand's type.
 */
#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name)                          \
  inline Expr Name(int a, const Expr& b) {                                  \
    return Name(make_const(b.type(), a), b);                                \
  }                                                                         \
  inline Expr Name(const Expr& a, int b) {                                  \
    return Name(a, make_const(a.type(), b));                                \
  }
| |
| |
// compound assignment ops: a op= b is a = a op b
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator+=, operator+);
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator-=, operator-);
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator*=, operator*);
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator/=, operator/);
// arithmetic, min/max and comparison ops with int/float scalar operands
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator+);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator-);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator*);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator/);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(max);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(min);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>);  // NOLINT(*)
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>=);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<);  // NOLINT(*)
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<=);
// integer related ops: only int scalar operands are accepted
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator%);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator>>);  // NOLINT(*)
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator<<);  // NOLINT(*)
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator&);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator|);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator^);
// logical ops: bool scalar operands
TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator&&);
TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator||);
| |
| } // namespace tvm |
| #endif // TVM_IR_OPERATOR_H_ |