blob: 626d0b9cbab59814d63e019f71aaee236ed1dc0d [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/arithmetic/pattern_match.h
*
* \brief Internal tool for expression-template based pattern matching.
*
* It helps to simplify pattern matching and rewrites.
* All the patterns are generated via expression template during compile time,
* so the result code should be as efficient as manually written pattern match code.
*
* The code below shows how to use the pattern matcher.
*
* \code
*
* // max(x + z, y + z) => max(x, y) + z
* arith::PVar<PrimExpr> x, y, z;
*
* // The following code tries to match the declared pattern.
* // Match will fill the result of match into PVar if successful.
* // Note that z occurs twice in the pattern,
* // an equality check is performed to ensure each occurrence of z
* // is equivalent to each other.
* if (max(x + z, y + z).Match(expr)) {
* // Eval evaluates a pattern with the current matched value.
* // The filled value is valid until the next call to Match.
* return (max(x, y) + z).Eval();
* }
*
* tvm::tir::Var tx, ty;
* arith::PVar<IntImm> c;
* arith::PVar<Var> v;
* // We can match IntImm and Var, both of which are
* // special-case containers of PrimExpr.
* TVM_FFI_ICHECK((v * c).Match(tx * 3));
* TVM_FFI_ICHECK_EQ(c.Eval()->value, 3);
* // cannot match c to ty
* TVM_FFI_ICHECK(!(v * c).Match(tx * ty));
*
* \endcode
*
* \note The pattern matcher is not threadsafe,
* do not use the same PVar in multiple threads.
*
* Please be aware that the filled value in a PVar
* can be overridden in the next call to Match.
*/
#ifndef TVM_ARITH_PATTERN_MATCH_H_
#define TVM_ARITH_PATTERN_MATCH_H_
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <cmath>
#include <tuple>
#include "const_fold.h"
namespace tvm {
namespace arith {
/*!
* \brief Base class of all the patterns.
*
* There are two major member functions supported by each pattern.
* - Match: checks if value matches the pattern.
* - Eval: construct a new value based on matched values in PVar.
*
* We use curiously recurring template pattern to construct
* expression templates.
*
* \tparam Derived The type of the derived class.
*/
template <typename Derived>
class Pattern {
 public:
  /*!
   * \brief How this pattern is stored when nested inside a larger pattern.
   *
   * Defaults to Derived, i.e. nest by value. A derived class may shadow
   * this alias with `const Derived&` to request nesting by reference.
   *
   * The trick of the Nested typedef originates from Eigen.
   *
   * \note Intermediate expressions are nested by value,
   *       PVars are nested by reference.
   */
  using Nested = Derived;

  /*! \return The derived instance of the current class (CRTP downcast). */
  const Derived& derived() const { return *static_cast<const Derived*>(this); }

  /*!
   * \brief Check if value matches the current pattern, with no extra condition.
   *
   * A successful match populates the PVars of the pattern. The values
   * stay valid until the next call to Match.
   *
   * \param value The value to be matched against.
   *
   * \return whether value matches the pattern.
   */
  template <typename NodeType>
  inline bool Match(const NodeType& value) const {
    // Delegate to the conditional overload with a condition that always holds.
    const auto always_true = []() { return true; };
    return Match(value, always_true);
  }

  /*!
   * \brief Check if value matches the current pattern and an extra condition.
   *
   * The pattern is reset through InitMatch_() first, then matched
   * structurally through Match_(). The condition is evaluated only
   * after the structural match succeeded, so it may read the filled PVars.
   *
   * \param value The value to be matched against.
   *
   * \param cond A callable taking no arguments that performs additional
   *        validation and returns true if the match passes. Typically a
   *        lambda written in terms of the filled PVars.
   *
   * \return whether value matches the pattern and cond() holds.
   */
  template <typename NodeType, typename Condition>
  bool Match(const NodeType& value, Condition cond) const {
    const Derived& self = derived();
    self.InitMatch_();
    if (!self.Match_(value)) return false;
    return static_cast<bool>(cond());
  }
};
/*!
* \brief Default deep equality checker
* \tparam T The type of the two values being compared.
*/
template <typename T>
class PEqualChecker {
 public:
  /*!
   * \brief Compare two values with the type's own operator==.
   * \param lhs The first value.
   * \param rhs The second value.
   * \return The result of `lhs == rhs`.
   */
  bool operator()(const T& lhs, const T& rhs) const {
    const bool equal = (lhs == rhs);
    return equal;
  }
};
template <>
class PEqualChecker<PrimExpr> {
 public:
  /*!
   * \brief Compare two expressions for equality.
   *
   * Reference identity is tried first; tir::ExprDeepEqual is only
   * consulted when the two handles are not the same object.
   *
   * \return true if the handles are identical or ExprDeepEqual accepts them.
   */
  bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const {
    return lhs.same_as(rhs) || tir::ExprDeepEqual()(lhs, rhs);
  }
};
template <>
class PEqualChecker<IntImm> {
 public:
  /*!
   * \brief Compare two integer immediates by their `value` field.
   * \note Only `value` is compared; no other field of the nodes is inspected.
   */
  bool operator()(const IntImm& lhs, const IntImm& rhs) const {
    const auto lhs_value = lhs->value;
    const auto rhs_value = rhs->value;
    return lhs_value == rhs_value;
  }
};
template <>
class PEqualChecker<FloatImm> {
 public:
  /*!
   * \brief Compare two floating point immediates by their `value` field.
   *
   * The values are considered equal when their absolute difference is
   * below 1e-20. Only `value` is compared.
   */
  bool operator()(const FloatImm& lhs, const FloatImm& rhs) const {
    const auto gap = std::fabs(lhs->value - rhs->value);
    return gap < 1e-20;
  }
};
template <>
class PEqualChecker<tir::Var> {
 public:
  /*!
   * \brief Compare two variables by reference identity.
   * \return The result of `lhs.same_as(rhs)`.
   */
  bool operator()(const tir::Var& lhs, const tir::Var& rhs) const {
    const bool same_object = lhs.same_as(rhs);
    return same_object;
  }
};
/*!
* \brief Pattern variable container.
*
* PVar is used as a "hole" in the pattern that can be matched.
*
* \tparam T the type of the hole.
*
* \note PVar is not thread safe.
* Do not use the same PVar in multiple threads.
*/
template <typename T>
class PVar : public Pattern<PVar<T>> {
 public:
  // PVars are nested by reference, so every occurrence of the same
  // PVar inside a pattern shares one value_/filled_ pair.
  using Nested = const PVar<T>&;

  /*! \brief Forget the previously matched value. */
  void InitMatch_() const { filled_ = false; }

  /*!
   * \brief Match a value of exactly type T.
   *
   * The first occurrence captures the value. Later occurrences succeed
   * only if PEqualChecker<T> accepts the captured and the new value.
   *
   * \param value The candidate value.
   * \return whether the value is accepted.
   */
  bool Match_(const T& value) const {
    if (filled_) {
      return PEqualChecker<T>()(value_, value);
    }
    value_ = value;
    filled_ = true;
    return true;
  }

  /*!
   * \brief Match a value given through a base reference type of T.
   *
   * The value is accepted only if `as<typename T::ContainerType>()`
   * yields a non-null node, which is then matched as a T.
   *
   * \param value The candidate value.
   * \return whether the value is accepted.
   */
  template <typename NodeRefType,
            typename = typename std::enable_if<std::is_base_of<NodeRefType, T>::value>::type>
  bool Match_(const NodeRefType& value) const {
    const auto* ptr = value.template as<typename T::ContainerType>();
    if (ptr == nullptr) return false;
    return Match_(ffi::GetRef<T>(ptr));
  }

  /*! \return The matched value; checks that the variable has been filled. */
  T Eval() const {
    TVM_FFI_ICHECK(filled_);
    return value_;
  }

  /*!
   * \brief Return the matched value, or a fallback when nothing was matched.
   * \param default_value The value returned when the variable is not filled.
   */
  T EvalOr(const T& default_value) const {
    if (filled_) return value_;
    return default_value;
  }

 protected:
  /*! \brief The matched value */
  mutable T value_;
  /*! \brief whether the variable has been filled */
  mutable bool filled_{false};
};
/*!
* \brief Wrapper for pattern variable container with extra match logic.
*
* \tparam Derived the type of derived class.
* \tparam T the type of the hole.
*/
template <typename Derived, typename T>
class PVarWithCheck : public arith::Pattern<PVarWithCheck<Derived, T>> {
 public:
  // Store by reference in the expression.
  using Nested = const PVarWithCheck<Derived, T>&;

  /*! \brief Reset the wrapped pattern variable. */
  void InitMatch_() const { pvar_.InitMatch_(); }

  /*!
   * \brief Match a value of type T.
   *
   * The extra check of the derived class (`Derived::Match_`) runs first.
   * The wrapped PVar is consulted only when that check passes.
   *
   * \param value The candidate value.
   * \return whether both the derived check and the PVar accept the value.
   */
  bool Match_(const T& value) const {
    const Derived& self = *static_cast<const Derived*>(this);
    return self.Match_(value) && pvar_.Match_(value);
  }

  /*!
   * \brief Match a value given through a base reference type of T.
   *
   * The value is accepted only if `as<typename T::ContainerType>()`
   * yields a non-null node, which is then matched as a T.
   */
  template <typename NodeRefType,
            typename = typename std::enable_if<std::is_base_of<NodeRefType, T>::value>::type>
  bool Match_(const NodeRefType& value) const {
    const auto* ptr = value.template as<typename T::ContainerType>();
    if (ptr == nullptr) return false;
    return Match_(ffi::GetRef<T>(ptr));
  }

  /*! \return The value held by the wrapped pattern variable. */
  T Eval() const { return pvar_.Eval(); }

 protected:
  /*! \brief The wrapped pattern variable that stores the matched value. */
  arith::PVar<T> pvar_;
};
/*!
* \brief Pattern variable container with expr type check.
*
* \tparam T the type of the hole.
* \tparam DType the Pattern type of dtype.
*/
// NOTE(review): the defaulted third template parameter names the
// std::enable_if<...> struct itself rather than its nested `::type`,
// so it never rejects any T. Confirm the intended constraint before
// tightening it, since that could break existing instantiations.
template <typename T, typename DType,
          typename = std::enable_if<std::is_base_of<T, PrimExpr>::value>>
class PVarWithDataType : public PVarWithCheck<PVarWithDataType<T, DType>, T> {
 public:
  /*!
   * \brief Construct the variable from the pattern its dtype must match.
   * \param dtype The dtype pattern, stored as `DType::Nested`.
   */
  explicit PVarWithDataType(const DType& dtype) : dtype_(dtype) {}

  /*!
   * \brief Extra check used by PVarWithCheck: match `value->dtype`
   *        against the dtype pattern.
   */
  bool Match_(const T& value) const {
    const auto& candidate_dtype = value->dtype;
    return dtype_.Match_(candidate_dtype);
  }

 protected:
  /*! \brief The pattern that the dtype of the value has to match. */
  typename DType::Nested dtype_;
};
/*!
* \brief Pattern variable container for data type with lanes.
*/
class PVecDataType : public PVarWithCheck<PVecDataType, DataType> {
 public:
  /*! \brief construct vector dtype placeholder with element type check */
  explicit PVecDataType(const DataType& elem_dtype) : elem_dtype_(elem_dtype) {}

  /*!
   * \brief Extra check used by PVarWithCheck.
   * \return whether the type code of dtype equals the type code of the
   *         element dtype. Only `code()` is compared here.
   */
  bool Match_(const DataType& dtype) const {
    const auto actual_code = dtype.code();
    const auto wanted_code = elem_dtype_.code();
    return actual_code == wanted_code;
  }

 protected:
  /*! \brief The dtype whose type code a candidate has to share. */
  DataType elem_dtype_;
};
/*!
* \brief Constant Pattern variable container.
*
* \tparam T the type of the hole.
*/
template <typename T>
class PConst : public Pattern<PConst<T>> {
 public:
  /*! \brief Wrap a constant; intentionally implicit so values convert to patterns. */
  PConst(T value)  // NOLINT(*)
      : value_(value) {}

  /*! \brief Nothing to reset: a constant holds no match state. */
  void InitMatch_() const {}

  /*! \return whether PEqualChecker<T> considers value equal to the constant. */
  bool Match_(const T& value) const {
    PEqualChecker<T> is_equal{};
    return is_equal(value_, value);
  }

  /*! \return The wrapped constant. */
  T Eval() const { return value_; }

 private:
  /*! \brief The constant this pattern stands for. */
  const T value_;
};
/*!
* \brief Pattern binary expression.
* \tparam OpType The AST noderef type.
* \tparam TA The pattern type of the first operand.
* \tparam TB The pattern type of the second operand.
*/
template <typename OpType, typename TA, typename TB>
class PBinaryExpr : public Pattern<PBinaryExpr<OpType, TA, TB>> {
 public:
  /*!
   * \brief Build the pattern from its two operand patterns.
   * \param a The pattern of the first operand.
   * \param b The pattern of the second operand.
   */
  PBinaryExpr(const TA& a, const TB& b) : a_(a), b_(b) {}

  /*! \brief Reset both operand patterns. */
  void InitMatch_() const {
    a_.InitMatch_();
    b_.InitMatch_();
  }

  /*!
   * \brief Match a node against `OpType(a, b)`.
   *
   * The node has to be an `OpType::ContainerType`. The first operand is
   * matched before the second, and the second is skipped when the first fails.
   */
  bool Match_(const ObjectRef& node) const {
    using NodeType = typename OpType::ContainerType;
    const NodeType* op = node.as<NodeType>();
    if (op == nullptr) return false;
    return a_.Match_(op->a) && b_.Match_(op->b);
  }

  /*!
   * \brief Build the expression from the operands' matched values.
   * \return The result of TryConstFold<OpType> when it produces one,
   *         otherwise a new `OpType(lhs, rhs)`.
   */
  PrimExpr Eval() const {
    PrimExpr lhs = a_.Eval();
    PrimExpr rhs = b_.Eval();
    auto folded = TryConstFold<OpType>(lhs, rhs);
    if (folded) return folded.value();
    return OpType(lhs, rhs);
  }

 private:
  typename TA::Nested a_;
  typename TB::Nested b_;
};
/*!
 * \brief Integer constant pattern whose dtype is taken from another pattern.
 * \tparam TA The pattern type that supplies the dtype in Eval.
 */
template <typename TA>
class PConstWithTypeLike : public Pattern<PConstWithTypeLike<TA>> {
 public:
  /*!
   * \param ref The pattern whose evaluated dtype is used by Eval.
   * \param value The integer constant to match and to produce.
   */
  PConstWithTypeLike(const TA& ref, int64_t value) : ref_(ref), value_(value) {}

  /*! \brief Nothing to reset: the reference pattern is not matched through this class. */
  void InitMatch_() const {}

  /*!
   * \brief Match an integer immediate with the stored value.
   * \note Only the `value` field of the IntImmNode is compared.
   */
  bool Match_(const ObjectRef& node) const {
    const tir::IntImmNode* imm = node.as<tir::IntImmNode>();
    return imm != nullptr && imm->value == value_;
  }

  /*! \return A constant of the stored value with the dtype of `ref_.Eval()`. */
  PrimExpr Eval() const {
    const auto target_dtype = ref_.Eval().dtype();
    return tir::make_const(target_dtype, value_);
  }

 private:
  typename TA::Nested ref_;
  int64_t value_;
};
/*!
 * \brief Define the pattern-building overloads of a binary operation.
 *
 * Expands to three overloads of FuncName that all build a
 * PBinaryExpr<NodeName, ...>:
 *  - (pattern, pattern)
 *  - (pattern, int64_t): the integer becomes a PConstWithTypeLike that
 *    refers to the pattern operand
 *  - (int64_t, pattern): same, with the constant as the first operand
 *
 * \param FuncName The function or operator name to define.
 * \param NodeName The AST noderef type passed to PBinaryExpr.
 * \param CheckStep A statement pasted at the start of every overload;
 *        it may refer to the pattern operand `a`. May be empty.
 */
#define TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, CheckStep) \
template <typename TA, typename TB> \
inline PBinaryExpr<NodeName, TA, TB> FuncName(const Pattern<TA>& a, const Pattern<TB>& b) { \
CheckStep; \
return PBinaryExpr<NodeName, TA, TB>(a.derived(), b.derived()); \
} \
template <typename TA> \
inline PBinaryExpr<NodeName, TA, PConstWithTypeLike<TA>> FuncName(const Pattern<TA>& a, \
int64_t b) { \
CheckStep; \
return FuncName(a, PConstWithTypeLike<TA>(a.derived(), b)); \
} \
template <typename TA> \
inline PBinaryExpr<NodeName, PConstWithTypeLike<TA>, TA> FuncName(int64_t b, \
const Pattern<TA>& a) { \
CheckStep; \
return FuncName(PConstWithTypeLike<TA>(a.derived(), b), a); \
}
/*! \brief Shorthand for TVM_PATTERN_BINARY_OP_EX with an empty CheckStep. */
#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, )
// raise ambiguity error for operator overload of / and %
// (the CheckStep calls DivAmbiguityError on the pattern operand)
TVM_PATTERN_BINARY_OP_EX(operator/, tir::Div, DivAmbiguityError(a));
TVM_PATTERN_BINARY_OP_EX(operator%, tir::Mod, DivAmbiguityError(a));
// arithmetic expressions
TVM_PATTERN_BINARY_OP(operator+, tir::Add);
TVM_PATTERN_BINARY_OP(operator-, tir::Sub);
TVM_PATTERN_BINARY_OP(operator*, tir::Mul);
TVM_PATTERN_BINARY_OP(min, tir::Min);
TVM_PATTERN_BINARY_OP(max, tir::Max);
// div and truncdiv both build tir::Div patterns; truncmod builds tir::Mod
TVM_PATTERN_BINARY_OP(div, tir::Div);
TVM_PATTERN_BINARY_OP(truncdiv, tir::Div);
TVM_PATTERN_BINARY_OP(truncmod, tir::Mod);
TVM_PATTERN_BINARY_OP(floordiv, tir::FloorDiv);
TVM_PATTERN_BINARY_OP(floormod, tir::FloorMod);
// comparison and logical expressions
TVM_PATTERN_BINARY_OP(operator>, tir::GT);
TVM_PATTERN_BINARY_OP(operator>=, tir::GE);
TVM_PATTERN_BINARY_OP(operator<, tir::LT);
TVM_PATTERN_BINARY_OP(operator<=, tir::LE);
TVM_PATTERN_BINARY_OP(operator==, tir::EQ);
TVM_PATTERN_BINARY_OP(operator!=, tir::NE);
TVM_PATTERN_BINARY_OP(operator&&, tir::And);
TVM_PATTERN_BINARY_OP(operator||, tir::Or);
/*!
* \brief Pattern not expression.
* \tparam TA The pattern type of the negated operand.
*/
template <typename TA>
class PNotExpr : public Pattern<PNotExpr<TA>> {
 public:
  /*! \param value The pattern of the operand being negated. */
  explicit PNotExpr(const TA& value) : value_(value) {}

  /*! \brief Reset the operand pattern. */
  void InitMatch_() const { value_.InitMatch_(); }

  /*! \brief Match a tir::NotNode whose operand `a` matches the operand pattern. */
  bool Match_(const ObjectRef& node) const {
    const tir::NotNode* op = node.as<tir::NotNode>();
    if (op == nullptr) return false;
    return value_.Match_(op->a);
  }

  /*! \return `tir::Not` applied to the evaluated operand. */
  PrimExpr Eval() const {
    auto operand = value_.Eval();
    return tir::Not(operand);
  }

 private:
  typename TA::Nested value_;
};
/*!
 * \brief Construct a logical-not pattern.
 * \param value The pattern of the operand being negated.
 * \return The result pattern.
 */
template <typename TA>
inline PNotExpr<TA> operator!(const Pattern<TA>& value) {
  const TA& operand = value.derived();
  return PNotExpr<TA>(operand);
}
// select
/*!
* \brief Pattern select expression.
* \tparam TCond The pattern type of the condition.
* \tparam TA The pattern type of the true operand.
* \tparam TB The pattern type of the false operand.
*/
template <typename TCond, typename TA, typename TB>
class PSelectExpr : public Pattern<PSelectExpr<TCond, TA, TB>> {
 public:
  /*!
   * \param condition The pattern of the condition.
   * \param true_value The pattern of the value taken when the condition is true.
   * \param false_value The pattern of the value taken when the condition is false.
   */
  PSelectExpr(const TCond& condition, const TA& true_value, const TB& false_value)
      : condition_(condition), true_value_(true_value), false_value_(false_value) {}

  /*! \brief Reset all three sub-patterns. */
  void InitMatch_() const {
    condition_.InitMatch_();
    true_value_.InitMatch_();
    false_value_.InitMatch_();
  }

  /*!
   * \brief Match a tir::SelectNode field by field.
   *
   * Fields are matched in the order condition, true_value, false_value;
   * matching stops at the first field that fails.
   */
  bool Match_(const ObjectRef& node) const {
    const tir::SelectNode* op = node.as<tir::SelectNode>();
    if (op == nullptr) return false;
    return condition_.Match_(op->condition) &&    //
           true_value_.Match_(op->true_value) &&  //
           false_value_.Match_(op->false_value);
  }

  /*! \return A tir::Select built from the evaluated sub-patterns. */
  PrimExpr Eval() const {
    return tir::Select(condition_.Eval(), true_value_.Eval(), false_value_.Eval());
  }

 private:
  typename TCond::Nested condition_;
  typename TA::Nested true_value_;
  typename TB::Nested false_value_;
};
/*!
* \brief Construct a select pattern.
*
* \param condition The condition expression.
* \param true_value The value when condition is true.
* \param false_value The value when condition is false.
*
* \return The result pattern.
*
* \tparam TCond The pattern type of the condition.
* \tparam TA The pattern type of the true operand.
* \tparam TB The pattern type of the false operand.
*/
template <typename TCond, typename TA, typename TB>
inline PSelectExpr<TCond, TA, TB> select(const Pattern<TCond>& condition,
                                         const Pattern<TA>& true_value,
                                         const Pattern<TB>& false_value) {
  // Unwrap each CRTP base to its concrete pattern before nesting it.
  using Result = PSelectExpr<TCond, TA, TB>;
  return Result(condition.derived(), true_value.derived(), false_value.derived());
}
/*!
* \brief Pattern cast expression.
* \tparam DType The Pattern type of dtype.
* \tparam TA The pattern type of the first operand.
*/
template <typename DType, typename TA>
class PCastExpr : public Pattern<PCastExpr<DType, TA>> {
 public:
  /*!
   * \param dtype The pattern of the target data type.
   * \param value The pattern of the value being cast.
   */
  PCastExpr(const DType& dtype, const TA& value) : dtype_(dtype), value_(value) {}

  /*! \brief Reset the dtype pattern and the value pattern. */
  void InitMatch_() const {
    dtype_.InitMatch_();
    value_.InitMatch_();
  }

  /*!
   * \brief Match a tir::CastNode.
   *
   * The node's dtype is matched first; the value is matched only when
   * the dtype was accepted.
   */
  bool Match_(const ObjectRef& node) const {
    const tir::CastNode* op = node.as<tir::CastNode>();
    if (op == nullptr) return false;
    return dtype_.Match_(op->dtype) && value_.Match_(op->value);
  }

  /*! \return A tir::Cast built from the evaluated dtype and value. */
  PrimExpr Eval() const { return tir::Cast(dtype_.Eval(), value_.Eval()); }

 private:
  typename DType::Nested dtype_;
  typename TA::Nested value_;
};
/*!
* \brief Construct a cast pattern.
*
* \param dtype The target data type, can be PVar<DataType> or PConst<DataType>.
* \param value The pattern of the value being cast.
*
* \return The result pattern.
*
* \tparam DType The pattern type of type.
* \tparam TA The pattern type of value.
*/
template <typename DType, typename TA>
inline PCastExpr<DType, TA> cast(const Pattern<DType>& dtype, const Pattern<TA>& value) {
  // Unwrap each CRTP base to its concrete pattern before nesting it.
  using Result = PCastExpr<DType, TA>;
  return Result(dtype.derived(), value.derived());
}
/*!
* \brief Pattern ramp expression.
* \tparam TBase The pattern type of the base.
* \tparam TStride The pattern type of the stride.
* \tparam TLanes The pattern type of the lanes.
*/
template <typename TBase, typename TStride, typename TLanes>
class PRampExpr : public Pattern<PRampExpr<TBase, TStride, TLanes>> {
 public:
  /*!
   * \param base The pattern of the base.
   * \param stride The pattern of the stride.
   * \param lanes The pattern of the lanes.
   */
  PRampExpr(const TBase& base, const TStride& stride, const TLanes& lanes)
      : base_(base), stride_(stride), lanes_(lanes) {}

  /*! \brief Reset all three sub-patterns. */
  void InitMatch_() const {
    base_.InitMatch_();
    stride_.InitMatch_();
    lanes_.InitMatch_();
  }

  /*!
   * \brief Match a tir::RampNode field by field.
   *
   * Fields are matched in the order base, stride, lanes; matching
   * stops at the first field that fails.
   */
  bool Match_(const ObjectRef& node) const {
    const tir::RampNode* op = node.as<tir::RampNode>();
    if (op == nullptr) return false;
    return base_.Match_(op->base) &&      //
           stride_.Match_(op->stride) &&  //
           lanes_.Match_(op->lanes);
  }

  /*! \return A tir::Ramp built from the evaluated sub-patterns. */
  PrimExpr Eval() const { return tir::Ramp(base_.Eval(), stride_.Eval(), lanes_.Eval()); }

 private:
  typename TBase::Nested base_;
  typename TStride::Nested stride_;
  typename TLanes::Nested lanes_;
};
/*!
* \brief Construct a ramp pattern.
*
* \param base The base pattern.
* \param stride The stride pattern.
* \param lanes The lanes pattern.
*
* \return The result pattern.
*
* \tparam TBase The pattern type of the base.
* \tparam TStride The pattern type of the stride.
* \tparam TLanes The pattern type of the lanes.
*/
template <typename TBase, typename TStride, typename TLanes>
inline PRampExpr<TBase, TStride, TLanes> ramp(const Pattern<TBase>& base,
                                              const Pattern<TStride>& stride,
                                              const Pattern<TLanes>& lanes) {
  // Unwrap each CRTP base to its concrete pattern before nesting it.
  using Result = PRampExpr<TBase, TStride, TLanes>;
  return Result(base.derived(), stride.derived(), lanes.derived());
}
/*!
 * \brief Construct a ramp pattern with constant stride and lanes.
 *
 * Both integers are wrapped in a PConstWithTypeLike that refers to the
 * base pattern.
 *
 * \param base The base pattern.
 * \param stride The constant stride.
 * \param lanes The constant number of lanes.
 *
 * \return The result pattern.
 */
template <typename TBase>
inline PRampExpr<TBase, PConstWithTypeLike<TBase>, PConstWithTypeLike<TBase>> ramp(
    const Pattern<TBase>& base, int stride, int lanes) {
  using TConst = PConstWithTypeLike<TBase>;
  const TBase& base_pattern = base.derived();
  return PRampExpr<TBase, TConst, TConst>(base_pattern, TConst(base_pattern, stride),
                                          TConst(base_pattern, lanes));
}
/*!
* \brief Pattern broadcast expression.
* \tparam TA The pattern type of the value.
* \tparam TLanes The pattern type of the lanes.
*/
template <typename TA, typename TLanes>
class PBroadcastExpr : public Pattern<PBroadcastExpr<TA, TLanes>> {
 public:
  /*!
   * \param value The pattern of the broadcast value.
   * \param lanes The pattern of the lanes.
   */
  PBroadcastExpr(const TA& value, const TLanes& lanes) : value_(value), lanes_(lanes) {}

  /*! \brief Reset the value pattern and the lanes pattern. */
  void InitMatch_() const {
    value_.InitMatch_();
    lanes_.InitMatch_();
  }

  /*!
   * \brief Match a tir::BroadcastNode.
   *
   * The value is matched first; lanes are matched only when the value
   * was accepted.
   */
  bool Match_(const ObjectRef& node) const {
    const tir::BroadcastNode* op = node.as<tir::BroadcastNode>();
    if (op == nullptr) return false;
    return value_.Match_(op->value) && lanes_.Match_(op->lanes);
  }

  /*! \return A tir::Broadcast built from the evaluated value and lanes. */
  PrimExpr Eval() const { return tir::Broadcast(value_.Eval(), lanes_.Eval()); }

 private:
  typename TA::Nested value_;
  typename TLanes::Nested lanes_;
};
/*!
* \brief Construct a broadcast pattern.
*
* \param value The value pattern.
* \param lanes The lanes pattern.
*
* \return The result pattern.
*
* \tparam TA The pattern type of the value.
* \tparam TLanes The pattern type of the lanes.
*/
template <typename TA, typename TLanes>
inline PBroadcastExpr<TA, TLanes> broadcast(const Pattern<TA>& value,
                                            const Pattern<TLanes>& lanes) {
  // Unwrap each CRTP base to its concrete pattern before nesting it.
  using Result = PBroadcastExpr<TA, TLanes>;
  return Result(value.derived(), lanes.derived());
}
// internal namespace
namespace detail {
// implementation details for CallExpr
/*!
 * \brief Recursive step of tuple_for_each: visit element I, then continue.
 * \tparam stop Whether the end of the tuple has been reached.
 * \tparam I The index of the element visited by this step.
 * \tparam F The functor type, called as f(index, element).
 */
template <bool stop, std::size_t I, typename F>
struct tuple_for_each_dispatcher {
  template <typename TTuple>
  static void run(F& f, const TTuple& tuple) {  // NOLINT(*)
    constexpr std::size_t kNext = I + 1;
    constexpr bool kDone = (kNext == std::tuple_size<TTuple>::value);
    f(I, std::get<I>(tuple));
    tuple_for_each_dispatcher<kDone, kNext, F>::run(f, tuple);
  }
};

/*! \brief Terminating step of tuple_for_each: nothing left to visit. */
template <std::size_t I, typename F>
struct tuple_for_each_dispatcher<true, I, F> {
  template <typename TTuple>
  static void run(F&, const TTuple&) {}  // NOLINT(*)
};

/*!
 * \brief Call f(index, element) for every element of a tuple, in index order.
 * \param f The functor to invoke; taken by reference so it can accumulate state.
 * \param tuple The tuple to traverse. An empty tuple results in no calls.
 */
template <typename F, typename TTuple>
inline void tuple_for_each(F& f, const TTuple& tuple) {  // NOLINT(*)
  constexpr bool kEmpty = (std::tuple_size<TTuple>::value == 0);
  tuple_for_each_dispatcher<kEmpty, 0, F>::run(f, tuple);
}
/*! \brief tuple_for_each functor that resets every argument pattern of a call. */
struct PCallExprInitMatchFunctor {
  /*!
   * \brief Reset one argument pattern.
   * \param pattern The argument pattern; its position in the tuple is not used.
   */
  template <typename T>
  void operator()(size_t /*index*/, const T& pattern) const {
    pattern.InitMatch_();
  }
};
/*! \brief tuple_for_each functor that matches argument patterns against a call's args. */
struct PCallExprMatchFunctor {
  /*! \brief The call whose arguments are being matched. */
  const tir::CallNode* call_;
  /*! \brief Whether every argument visited so far has matched. */
  bool matched_{true};

  explicit PCallExprMatchFunctor(const tir::CallNode* call) : call_(call) {}

  /*!
   * \brief Match the i-th argument of the call against its pattern.
   *
   * Once one argument has failed, later patterns are no longer evaluated.
   *
   * \param i The argument index.
   * \param pattern The pattern for that argument.
   */
  template <typename T>
  void operator()(size_t i, const T& pattern) {
    if (!matched_) return;
    matched_ = pattern.Match_(call_->args[i]);
  }
};
/*! \brief tuple_for_each functor that collects the evaluated argument patterns. */
struct PCallExprEvalArgsFunctor {
  /*! \brief The evaluated arguments, appended in visiting order. */
  ffi::Array<PrimExpr> args_;

  /*!
   * \brief Evaluate one argument pattern and append the result.
   * \param pattern The argument pattern; its position in the tuple is not used.
   */
  template <typename T>
  void operator()(size_t /*index*/, const T& pattern) {
    auto evaluated = pattern.Eval();
    args_.push_back(evaluated);
  }
};
} // namespace detail
/*!
* \brief Pattern CallExpr expression.
* \tparam Op The operator functor class.
* \tparam TArgs The arguments.
* \note Op functor contains the name of the function and
* the implementation of Eval.
*/
template <typename Op, typename... TArgs>
class PCallExpr : public Pattern<PCallExpr<Op, TArgs...>> {
 public:
  /*! \param args The patterns of the call arguments, in argument order. */
  explicit PCallExpr(const TArgs&... args) : args_(args...) {}

  /*! \brief Reset every argument pattern. */
  void InitMatch_() const {
    detail::PCallExprInitMatchFunctor reset_each;
    detail::tuple_for_each(reset_each, args_);
  }

  /*!
   * \brief Match a tir::CallNode against this call pattern.
   *
   * The argument count is compared first, then the operator against
   * `Op::GetOp()`, and finally each argument against its pattern.
   */
  bool Match_(const ObjectRef& node) const {
    const tir::CallNode* call = node.as<tir::CallNode>();
    if (call == nullptr) return false;
    const bool same_arity = (call->args.size() == sizeof...(TArgs));
    if (!same_arity) return false;
    if (!call->op.same_as(Op::GetOp())) return false;
    detail::PCallExprMatchFunctor match_each(call);
    detail::tuple_for_each(match_each, args_);
    return match_each.matched_;
  }

  /*! \return The result of `Op::Eval` applied to the evaluated arguments. */
  PrimExpr Eval() const {
    detail::PCallExprEvalArgsFunctor collect;
    detail::tuple_for_each(collect, args_);
    return Op::Eval(collect.args_);
  }

 private:
  std::tuple<typename TArgs::Nested...> args_;
};
// arithmetic intrinsics
/*!
 * \brief Define a two-argument intrinsic call pattern.
 *
 * Expands to an operator functor struct OpName usable with PCallExpr
 * (Eval builds a tir::Call with the dtype of the first argument,
 * GetOp returns tir::builtin::IntrinOpName()), and a function FuncName
 * that builds PCallExpr<OpName, TA, TB> from two patterns.
 *
 * \param FuncName The function or operator name to define.
 * \param OpName The name of the generated functor struct.
 * \param IntrinOpName The name of the builtin in tir::builtin.
 */
#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinOpName) \
struct OpName { \
static PrimExpr Eval(ffi::Array<PrimExpr> args) { \
return tir::Call(args[0].dtype(), GetOp(), args); \
} \
static const Op& GetOp() { return tir::builtin::IntrinOpName(); } \
}; \
template <typename TA, typename TB> \
inline PCallExpr<OpName, TA, TB> FuncName(const Pattern<TA>& a, const Pattern<TB>& b) { \
return PCallExpr<OpName, TA, TB>(a.derived(), b.derived()); \
}
// shift and bitwise operators on patterns
TVM_PATTERN_BINARY_INTRIN(operator<<, PLeftShiftOp, shift_left);
TVM_PATTERN_BINARY_INTRIN(operator>>, PRightShiftOp, shift_right);
TVM_PATTERN_BINARY_INTRIN(operator&, PBitwiseAndOp, bitwise_and);
TVM_PATTERN_BINARY_INTRIN(operator|, PBitwiseOrOp, bitwise_or);
TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, bitwise_xor);
// unary intrinsics
/*!
 * \brief Define a one-argument intrinsic call pattern.
 *
 * Expands to an operator functor struct OpName usable with PCallExpr
 * (Eval builds a tir::Call with the dtype of the first argument,
 * GetOp returns tir::builtin::IntrinOpName()), and a function FuncName
 * that builds PCallExpr<OpName, TA> from one pattern.
 *
 * \param FuncName The function or operator name to define.
 * \param OpName The name of the generated functor struct.
 * \param IntrinOpName The name of the builtin in tir::builtin.
 */
#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinOpName) \
struct OpName { \
static PrimExpr Eval(ffi::Array<PrimExpr> args) { \
return tir::Call(args[0].dtype(), GetOp(), args); \
} \
static const Op& GetOp() { return tir::builtin::IntrinOpName(); } \
}; \
template <typename TA> \
inline PCallExpr<OpName, TA> FuncName(const Pattern<TA>& a) { \
return PCallExpr<OpName, TA>(a.derived()); \
}
// bitwise not on a pattern
TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, bitwise_not);
// if_then_else
/*! \brief Operator functor of the if_then_else intrinsic, for use with PCallExpr. */
struct PIfThenElseOp {
  /*!
   * \brief Build the if_then_else call from its evaluated arguments.
   * \param args The call arguments; the result takes the dtype of args[1].
   */
  static PrimExpr Eval(ffi::Array<PrimExpr> args) {
    const auto result_dtype = args[1].dtype();
    return tir::Call(result_dtype, GetOp(), args);
  }

  /*! \return The operator a matched call has to refer to. */
  static const Op& GetOp() { return tir::builtin::if_then_else(); }
};
/*!
* \brief Construct a if_then_else pattern.
*
* \param cond The condition expression.
* \param true_value The value when condition is true.
* \param false_value The value when condition is false.
*
* \return The result pattern.
*
* \tparam TCond The pattern type of the condition.
* \tparam TA The pattern type of the true operand.
* \tparam TB The pattern type of the false operand.
*/
template <typename TCond, typename TA, typename TB>
inline PCallExpr<PIfThenElseOp, TCond, TA, TB> if_then_else(const Pattern<TCond>& cond,
                                                            const Pattern<TA>& true_value,
                                                            const Pattern<TB>& false_value) {
  // Unwrap each CRTP base to its concrete pattern before nesting it.
  using Result = PCallExpr<PIfThenElseOp, TCond, TA, TB>;
  return Result(cond.derived(), true_value.derived(), false_value.derived());
}
// vscale
/*! \brief Operator functor of the zero-argument vscale intrinsic, for use with PCallExpr. */
struct PVscaleOp {
  /*! \return A call to the vscale builtin with dtype int32 and no arguments. */
  static PrimExpr Eval() { return tir::Call(DataType::Int(32), GetOp(), {}); }

  /*!
   * \brief Overload matching the way PCallExpr::Eval invokes its operator functor.
   *
   * PCallExpr::Eval always calls `Op::Eval(args)` with the array of
   * evaluated arguments. Without this overload `PCallExpr<PVscaleOp>`
   * can be matched but its Eval() does not compile.
   *
   * \param args The evaluated call arguments; has to be empty.
   * \return Same as Eval().
   */
  static PrimExpr Eval(const ffi::Array<PrimExpr>& args) {
    TVM_FFI_ICHECK(args.empty());
    return Eval();
  }

  /*! \return The operator a matched call has to refer to. */
  static const Op& GetOp() { return tir::builtin::vscale(); }
};
/*!
 * \brief Proxy that matches a value against several alternative patterns.
 * \tparam TPattern The types of the alternative patterns.
 */
template <typename... TPattern>
class PMatchesOneOf {
 public:
  /*! \param patterns The alternatives, tried in the order given. */
  explicit PMatchesOneOf(const TPattern&... patterns) : patterns_{patterns...} {}

  /*! \brief Check if value matches one of the patterns.
   *
   * PVars are populated from the first alternative that matches and
   * stay valid until the next call to Match.
   *
   * \param value The value to be matched against.
   *
   * \return Whether value matches one of the patterns.
   */
  template <typename NodeType>
  inline bool Match(const NodeType& value) const {
    const auto accept_all = []() { return true; };
    return Match(value, accept_all);
  }

  /*! \brief Check if value matches one of the patterns under a condition.
   *
   * PVars are populated from the first alternative that matches and
   * passes the condition, and stay valid until the next call to Match.
   *
   * \param value The value to be matched against.
   *
   * \param cond A callable that performs additional validation,
   * returning true if the match passes. Typically a lambda written in
   * terms of the filled PVars. It is called once for each alternative
   * that matches structurally; when it returns false the next
   * alternative is attempted.
   *
   * \return Whether value matches one of the patterns.
   */
  template <typename NodeType, typename Condition>
  inline bool Match(const NodeType& value, Condition cond) const {
    return MatchImpl(value, cond, std::index_sequence_for<TPattern...>());
  }

 private:
  /*!
   * \brief Try the alternatives left to right, stopping at the first success.
   *
   * The fold over `||` short-circuits and yields false when there are
   * no alternatives at all.
   */
  template <typename NodeType, typename Condition, size_t... Indices>
  inline bool MatchImpl(const NodeType& value, Condition cond,
                        std::index_sequence<Indices...>) const {
    return (std::get<Indices>(patterns_).Match(value, cond) || ...);
  }

  // Hold the patterns by const&. This follows the same usage as
  // `PVar`, which occurs as `const PVar<T>&` when it appears inside
  // other patterns. Because the `PVar<T>::value_` field is mutable, it
  // can still be updated through these const references. So long as
  // the call to `Match()` occurs within the same expression that
  // created the patterns, this avoids accidental copies without
  // creating dangling references. This may be improved in the future
  // by use of `constexpr` constructors/operators, allowing more
  // typical value semantics.
  std::tuple<const TPattern&...> patterns_;
};
/*! \brief Return a proxy object that returns true after the first match
*
* In the RewriteSimplifier, there are often several expressions that
* simplify to the same resulting expression. This utility allows
* them to be specified as a single rule, reducing duplication of the
* result/condition of a rewrite.
*/
// Participates in overload resolution only when every argument type
// derives from its own Pattern<> CRTP base.
template <typename... TPattern>
inline std::enable_if_t<(std::is_base_of_v<Pattern<TPattern>, TPattern> && ... && true),
                        PMatchesOneOf<TPattern...>>
matches_one_of(const TPattern&... patterns) {
  using Matcher = PMatchesOneOf<TPattern...>;
  return Matcher(patterns...);
}
} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_PATTERN_MATCH_H_