blob: 218e9d218a66ec7acbcc2c6dc83d2c6d0288a5c0 [file] [log] [blame]
/*!
* Copyright (c) 2017 by Contributors
* \file compute_expr.h
* \brief Utility integer expression with quick eager simplification.
* This is weaker than Simplify but can be done Eagerly.
*/
#ifndef TVM_ARITHMETIC_COMPUTE_EXPR_H_
#define TVM_ARITHMETIC_COMPUTE_EXPR_H_
#include <tvm/ir.h>
#include <arithmetic/Interval.h>
#include <limits>
namespace tvm {
namespace arith {
/*!
 * \brief Build the expression for a binary operator applied to two operands.
 *
 * This is the generic fallback: it forwards both operands to the static
 * OP::make factory. Operators that have an explicit specialization
 * elsewhere in this header do not go through this path.
 *
 * \param lhs The left operand.
 * \param rhs The right operand.
 * \tparam OP The operator node type; it must expose a static
 *  make(lhs, rhs) whose result is convertible to Expr.
 * \return The expression produced by OP::make(lhs, rhs).
 */
template<typename OP>
inline Expr ComputeExpr(Expr lhs, Expr rhs) {
  Expr combined = OP::make(lhs, rhs);
  return combined;
}
/*!
 * \brief Compute a reduction over a list of values with Op.
 *
 * Forward declaration; the definition appears later in this header.
 *
 * \param values The input values.
 * \param empty_value The value returned when values is empty. It may be an
 *  undefined Expr(), in which case the CHECK in the definition fails and an
 *  error is raised.
 * \tparam Op The computation operator, applied through ComputeExpr<Op>.
 * \return The result of the reduction.
 */
template<typename Op>
inline Expr ComputeReduce(
const Array<Expr>& values, Expr empty_value);
/*!
 * \brief Try to read a scalar 64-bit integer constant out of an expression.
 * \param e The expression to inspect.
 * \param out Receives the constant; written only when the function
 *  returns true.
 * \return true when e is not of vector type and as_const_int(e) yields a
 *  non-null pointer; false otherwise.
 */
inline bool GetConst(Expr e, int64_t* out) {
  // Vector-typed expressions are never reported as a constant here.
  if (!e.type().is_vector()) {
    if (const int64_t* value = as_const_int(e)) {
      *out = *value;
      return true;
    }
  }
  return false;
}
/*!
 * \brief Get a small constant int, i.e. a constant that fits in a plain int.
 * \param e The expression to inspect.
 * \param out Receives the constant; written only when the function
 *  returns true.
 * \return true when GetConst succeeds on e and the value lies within
 *  [std::numeric_limits<int>::min(), std::numeric_limits<int>::max()];
 *  false otherwise.
 */
inline bool GetConstInt(Expr e, int* out) {
  int64_t v1 = 0;
  if (!GetConst(e, &v1)) return false;
  // Reject values outside the int range on BOTH ends. Checking only the
  // upper bound would let a large negative constant be narrowed by the
  // static_cast below and reported as a successful (but wrong) result.
  if (v1 < static_cast<int64_t>(std::numeric_limits<int>::min()) ||
      v1 > static_cast<int64_t>(std::numeric_limits<int>::max())) {
    return false;
  }
  *out = static_cast<int>(v1);
  return true;
}
/*!
 * \brief Specialization for ir::Add: combines the operands with the
 *  overloaded + operator on Expr instead of calling ir::Add::make.
 * \param a The left operand.
 * \param b The right operand.
 * \return The expression a + b.
 */
template<>
inline Expr ComputeExpr<ir::Add>(Expr a, Expr b) {
  Expr sum = a + b;
  return sum;
}
/*!
 * \brief Specialization for ir::Sub: combines the operands with the
 *  overloaded - operator on Expr instead of calling ir::Sub::make.
 * \param a The left operand (minuend).
 * \param b The right operand (subtrahend).
 * \return The expression a - b.
 */
template<>
inline Expr ComputeExpr<ir::Sub>(Expr a, Expr b) {
  Expr difference = a - b;
  return difference;
}
/*!
 * \brief Specialization for ir::Mul: combines the operands with the
 *  overloaded * operator on Expr instead of calling ir::Mul::make.
 * \param a The left operand.
 * \param b The right operand.
 * \return The expression a * b.
 */
template<>
inline Expr ComputeExpr<ir::Mul>(Expr a, Expr b) {
  Expr product = a * b;
  return product;
}
/*!
 * \brief Specialization for ir::Div: combines the operands with the
 *  overloaded / operator on Expr instead of calling ir::Div::make.
 * \param a The left operand (dividend).
 * \param b The right operand (divisor).
 * \return The expression a / b.
 */
template<>
inline Expr ComputeExpr<ir::Div>(Expr a, Expr b) {
  Expr quotient = a / b;
  return quotient;
}
/*!
 * \brief Specialization for ir::Mod: combines the operands with the
 *  overloaded % operator on Expr instead of calling ir::Mod::make.
 * \param a The left operand (dividend).
 * \param b The right operand (divisor).
 * \return The expression a % b.
 */
template<>
inline Expr ComputeExpr<ir::Mod>(Expr a, Expr b) {
  Expr remainder = a % b;
  return remainder;
}
/*!
 * \brief Specialization for ir::Max: delegates to
 *  HalideIR::Internal::Interval::make_max instead of ir::Max::make.
 * \param a The first operand.
 * \param b The second operand.
 * \return The expression returned by Interval::make_max(a, b).
 */
template<>
inline Expr ComputeExpr<ir::Max>(Expr a, Expr b) {
  using HalideIR::Internal::Interval;
  return Interval::make_max(a, b);
}
/*!
 * \brief Specialization for ir::Min: delegates to
 *  HalideIR::Internal::Interval::make_min instead of ir::Min::make.
 * \param a The first operand.
 * \param b The second operand.
 * \return The expression returned by Interval::make_min(a, b).
 */
template<>
inline Expr ComputeExpr<ir::Min>(Expr a, Expr b) {
  using HalideIR::Internal::Interval;
  return Interval::make_min(a, b);
}
/*!
 * \brief Fold a list of expressions left-to-right with ComputeExpr<Op>.
 *
 * For a non-empty list the result is
 * ComputeExpr<Op>(...ComputeExpr<Op>(values[0], values[1])..., values[n-1]);
 * a single-element list yields that element unchanged.
 *
 * \param values The input values.
 * \param empty_value The value returned when values is empty; it is
 *  CHECKed to be defined before being returned.
 * \tparam Op The computation operator.
 * \return The folded expression, or empty_value for an empty list.
 */
template<typename Op>
inline Expr ComputeReduce(const Array<Expr>& values, Expr empty_value) {
  if (values.size() != 0U) {
    // Seed the accumulator with the first element, then fold in the rest.
    Expr accumulated = values[0];
    size_t index = 1;
    while (index < values.size()) {
      accumulated = ComputeExpr<Op>(accumulated, values[index]);
      ++index;
    }
    return accumulated;
  }
  // Nothing to reduce: fall back to the caller-supplied value.
  CHECK(empty_value.defined());
  return empty_value;
}
} // namespace arith
} // namespace tvm
#endif // TVM_ARITHMETIC_COMPUTE_EXPR_H_