[Relax][Transform] Add SelectNode handling in SymbolicMatcher (#17368)
This PR added support for handling SelectNode in the SymbolicMatcher
class by modifying the VisitExpr_ function to match the true_value
and false_value expressions between the current SelectNode and the
other expression. If the other expression is not a SelectNode, the
matching condition is updated to ensure the current SelectNode
expression is equivalent to the other expression.
diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index 612e145..fe24764 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -139,6 +139,16 @@
}
}
+ void VisitExpr_(const SelectNode* op, const PrimExpr& other) {
+ const auto* rhs = other.as<SelectNode>();
+ if (rhs) {
+ VisitExpr(op->true_value, rhs->true_value);
+ VisitExpr(op->false_value, rhs->false_value);
+ } else {
+ must_prove_ = must_prove_ && (GetRef<PrimExpr>(op) == other);
+ }
+ }
+
arith::Analyzer* analyzer_;
Map<tir::Var, PrimExpr>* var_remap_;
PrimExpr must_prove_ = Bool(true);