blob: 4c5ea5bfd2d0e932d5614509c10f458f5effca4c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file expr_functor.cc
*/
#include <tvm/tir/expr_functor.h>
#include "functor_common.h"
namespace tvm {
namespace tir {
// A variable reference is a leaf: the base visitor has no children to descend into.
void ExprVisitor::VisitExpr_(const VarNode* /*op*/) {
  // intentionally empty
}
void ExprVisitor::VisitExpr_(const SizeVarNode* op) {
this->VisitExpr_(static_cast<const VarNode*>(op));
}
// Any is a leaf node; nothing to visit.
void ExprVisitor::VisitExpr_(const AnyNode* /*op*/) {
  // intentionally empty
}
void ExprVisitor::VisitExpr_(const LoadNode* op) {
this->VisitExpr(op->index);
this->VisitExpr(op->predicate);
}
void ExprVisitor::VisitExpr_(const BufferLoadNode* op) {
VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
}
void ExprVisitor::VisitExpr_(const ProducerLoadNode* op) {
VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
}
void ExprVisitor::VisitExpr_(const LetNode* op) {
this->VisitExpr(op->value);
this->VisitExpr(op->body);
}
void ExprVisitor::VisitExpr_(const CallNode* op) {
VisitArray(op->args, [this](const PrimExpr& e) { this->VisitExpr(e); });
}
#define DEFINE_BINOP_VISIT_(OP) \
void ExprVisitor::VisitExpr_(const OP* op) { \
this->VisitExpr(op->a); \
this->VisitExpr(op->b); \
}
DEFINE_BINOP_VISIT_(AddNode);
DEFINE_BINOP_VISIT_(SubNode);
DEFINE_BINOP_VISIT_(MulNode);
DEFINE_BINOP_VISIT_(DivNode);
DEFINE_BINOP_VISIT_(ModNode);
DEFINE_BINOP_VISIT_(FloorDivNode);
DEFINE_BINOP_VISIT_(FloorModNode);
DEFINE_BINOP_VISIT_(MinNode);
DEFINE_BINOP_VISIT_(MaxNode);
DEFINE_BINOP_VISIT_(EQNode);
DEFINE_BINOP_VISIT_(NENode);
DEFINE_BINOP_VISIT_(LTNode);
DEFINE_BINOP_VISIT_(LENode);
DEFINE_BINOP_VISIT_(GTNode);
DEFINE_BINOP_VISIT_(GENode);
DEFINE_BINOP_VISIT_(AndNode);
DEFINE_BINOP_VISIT_(OrNode);
// Immediate constants have no sub-expressions, so the base visitor does nothing for them.
void ExprVisitor::VisitExpr_(const IntImmNode* /*op*/) {
  // leaf
}
void ExprVisitor::VisitExpr_(const FloatImmNode* /*op*/) {
  // leaf
}
void ExprVisitor::VisitExpr_(const StringImmNode* /*op*/) {
  // leaf
}
void ExprVisitor::VisitExpr_(const ReduceNode* op) {
VisitArray(op->axis, [this](const IterVar& r) {
this->VisitExpr(r->dom->min);
this->VisitExpr(r->dom->extent);
});
VisitArray(op->source, [this](const PrimExpr& e) { this->VisitExpr(e); });
if (!op->init.empty()) {
VisitArray(op->init, [this](const PrimExpr& e) { this->VisitExpr(e); });
}
this->VisitExpr(op->condition);
}
void ExprVisitor::VisitExpr_(const CastNode* op) { this->VisitExpr(op->value); }
void ExprVisitor::VisitExpr_(const NotNode* op) { this->VisitExpr(op->a); }
void ExprVisitor::VisitExpr_(const SelectNode* op) {
this->VisitExpr(op->condition);
this->VisitExpr(op->true_value);
this->VisitExpr(op->false_value);
}
void ExprVisitor::VisitExpr_(const RampNode* op) {
this->VisitExpr(op->base);
this->VisitExpr(op->stride);
}
void ExprVisitor::VisitExpr_(const ShuffleNode* op) {
VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
VisitArray(op->vectors, [this](const PrimExpr& e) { this->VisitExpr(e); });
}
void ExprVisitor::VisitExpr_(const BroadcastNode* op) { this->VisitExpr(op->value); }
// Variables are leaves: the default mutator returns the original node untouched.
PrimExpr ExprMutator::VisitExpr_(const VarNode* op) {
  PrimExpr self = GetRef<PrimExpr>(op);
  return self;
}
// SizeVar is mutated exactly like Var: upcast and dispatch (virtually) to the VarNode overload.
PrimExpr ExprMutator::VisitExpr_(const SizeVarNode* op) {
  const VarNode* as_var = op;
  return this->VisitExpr_(as_var);
}
// Any is a leaf: returned unchanged.
PrimExpr ExprMutator::VisitExpr_(const AnyNode* op) {
  PrimExpr self = GetRef<PrimExpr>(op);
  return self;
}
// Mutates the index and then the predicate. The original node is returned when neither
// changed; otherwise a new Load is built with the same dtype and buffer_var.
PrimExpr ExprMutator::VisitExpr_(const LoadNode* op) {
  PrimExpr new_index = this->VisitExpr(op->index);
  PrimExpr new_predicate = this->VisitExpr(op->predicate);
  const bool changed = !new_index.same_as(op->index) || !new_predicate.same_as(op->predicate);
  if (changed) {
    return Load(op->dtype, op->buffer_var, new_index, new_predicate);
  }
  return GetRef<PrimExpr>(op);
}
// Mutates each index. The original node is kept when no index changed; otherwise a new
// BufferLoad on the same buffer is built.
PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) {
  auto mutate_index = [this](const PrimExpr& index) { return this->VisitExpr(index); };
  Array<PrimExpr> new_indices = MutateArray(op->indices, mutate_index);
  if (!new_indices.same_as(op->indices)) {
    return BufferLoad(op->buffer, new_indices);
  }
  return GetRef<PrimExpr>(op);
}
// Mutates each index. The original node is kept when no index changed; otherwise a new
// ProducerLoad on the same producer is built.
PrimExpr ExprMutator::VisitExpr_(const ProducerLoadNode* op) {
  auto mutate_index = [this](const PrimExpr& index) { return this->VisitExpr(index); };
  Array<PrimExpr> new_indices = MutateArray(op->indices, mutate_index);
  if (!new_indices.same_as(op->indices)) {
    return ProducerLoad(op->producer, new_indices);
  }
  return GetRef<PrimExpr>(op);
}
// Mutates the bound value, then the body. The bound variable op->var is reused as-is.
PrimExpr ExprMutator::VisitExpr_(const LetNode* op) {
  PrimExpr new_value = this->VisitExpr(op->value);
  PrimExpr new_body = this->VisitExpr(op->body);
  if (!new_value.same_as(op->value) || !new_body.same_as(op->body)) {
    return Let(op->var, new_value, new_body);
  }
  return GetRef<PrimExpr>(op);
}
// Mutates every argument. The call is rebuilt with the same dtype and callee only when some
// argument changed.
PrimExpr ExprMutator::VisitExpr_(const CallNode* op) {
  auto mutate_arg = [this](const PrimExpr& arg) { return this->VisitExpr(arg); };
  Array<PrimExpr> new_args = MutateArray(op->args, mutate_arg);
  if (!new_args.same_as(op->args)) {
    return Call(op->dtype, op->op, new_args);
  }
  return GetRef<PrimExpr>(op);
}
// Immediate constants are leaves: the default mutator returns them unchanged.
PrimExpr ExprMutator::VisitExpr_(const IntImmNode* op) {
  return GetRef<PrimExpr>(op);
}
PrimExpr ExprMutator::VisitExpr_(const FloatImmNode* op) {
  return GetRef<PrimExpr>(op);
}
PrimExpr ExprMutator::VisitExpr_(const StringImmNode* op) {
  return GetRef<PrimExpr>(op);
}
// Every binary node is mutated the same way: mutate `a`, then `b`, and rebuild the node via
// its reference type OP only if an operand changed.
#define DEFINE_BIOP_EXPR_MUTATE_(OP)                         \
  PrimExpr ExprMutator::VisitExpr_(const OP##Node* node) {   \
    PrimExpr lhs = this->VisitExpr(node->a);                 \
    PrimExpr rhs = this->VisitExpr(node->b);                 \
    if (!lhs.same_as(node->a) || !rhs.same_as(node->b)) {    \
      return OP(lhs, rhs);                                   \
    }                                                        \
    return GetRef<PrimExpr>(node);                           \
  }

// Arithmetic
DEFINE_BIOP_EXPR_MUTATE_(Add)
DEFINE_BIOP_EXPR_MUTATE_(Sub)
DEFINE_BIOP_EXPR_MUTATE_(Mul)
DEFINE_BIOP_EXPR_MUTATE_(Div)
DEFINE_BIOP_EXPR_MUTATE_(Mod)
DEFINE_BIOP_EXPR_MUTATE_(FloorDiv)
DEFINE_BIOP_EXPR_MUTATE_(FloorMod)
DEFINE_BIOP_EXPR_MUTATE_(Min)
DEFINE_BIOP_EXPR_MUTATE_(Max)
// Comparisons
DEFINE_BIOP_EXPR_MUTATE_(EQ)
DEFINE_BIOP_EXPR_MUTATE_(NE)
DEFINE_BIOP_EXPR_MUTATE_(LT)
DEFINE_BIOP_EXPR_MUTATE_(LE)
DEFINE_BIOP_EXPR_MUTATE_(GT)
DEFINE_BIOP_EXPR_MUTATE_(GE)
// Logical
DEFINE_BIOP_EXPR_MUTATE_(And)
DEFINE_BIOP_EXPR_MUTATE_(Or)
#undef DEFINE_BIOP_EXPR_MUTATE_
// Mutates a reduction in this order: each axis domain (min, extent), the sources, the init
// values, then the condition. A new Reduce (same combiner and value_index) is built only when
// at least one of those changed.
PrimExpr ExprMutator::VisitExpr_(const ReduceNode* op) {
  // Rewrites one reduction axis's domain, reusing the IterVar when min and extent are untouched.
  auto mutate_axis = [this](const IterVar& iv) {
    const Range& dom = iv->dom;
    PrimExpr new_min = this->VisitExpr(dom->min);
    PrimExpr new_extent = this->VisitExpr(dom->extent);
    if (!new_min.same_as(dom->min) || !new_extent.same_as(dom->extent)) {
      return IterVar(Range::FromMinExtent(new_min, new_extent), iv->var, iv->iter_type,
                     iv->thread_tag);
    }
    return iv;
  };
  auto mutate_expr = [this](const PrimExpr& e) { return this->VisitExpr(e); };

  Array<IterVar> new_axis = MutateArray(op->axis, mutate_axis);
  Array<PrimExpr> new_source = MutateArray(op->source, mutate_expr);
  Array<PrimExpr> new_init = MutateArray(op->init, mutate_expr);
  PrimExpr new_condition = this->VisitExpr(op->condition);

  const bool unchanged = new_axis.same_as(op->axis) && new_source.same_as(op->source) &&
                         new_init.same_as(op->init) && new_condition.same_as(op->condition);
  if (unchanged) {
    return GetRef<PrimExpr>(op);
  }
  return Reduce(op->combiner, new_source, new_axis, new_condition, op->value_index, new_init);
}
// Mutates the cast's operand; rebuilds the Cast with the same target dtype only if it changed.
PrimExpr ExprMutator::VisitExpr_(const CastNode* op) {
  PrimExpr new_value = this->VisitExpr(op->value);
  if (!new_value.same_as(op->value)) {
    return Cast(op->dtype, new_value);
  }
  return GetRef<PrimExpr>(op);
}
// Mutates the negated operand; rebuilds the Not only if the operand changed.
PrimExpr ExprMutator::VisitExpr_(const NotNode* op) {
  PrimExpr new_operand = this->VisitExpr(op->a);
  if (!new_operand.same_as(op->a)) {
    return Not(new_operand);
  }
  return GetRef<PrimExpr>(op);
}
// Mutates the condition, then the true branch, then the false branch; rebuilds the Select
// only if any of the three changed.
PrimExpr ExprMutator::VisitExpr_(const SelectNode* op) {
  PrimExpr new_cond = this->VisitExpr(op->condition);
  PrimExpr new_then = this->VisitExpr(op->true_value);
  PrimExpr new_else = this->VisitExpr(op->false_value);
  const bool unchanged = new_cond.same_as(op->condition) &&
                         new_then.same_as(op->true_value) &&
                         new_else.same_as(op->false_value);
  if (unchanged) {
    return GetRef<PrimExpr>(op);
  }
  return Select(new_cond, new_then, new_else);
}
// Mutates base, then stride; rebuilds the Ramp (keeping op->lanes) only if either changed.
PrimExpr ExprMutator::VisitExpr_(const RampNode* op) {
  PrimExpr new_base = this->VisitExpr(op->base);
  PrimExpr new_stride = this->VisitExpr(op->stride);
  if (!new_base.same_as(op->base) || !new_stride.same_as(op->stride)) {
    return Ramp(new_base, new_stride, op->lanes);
  }
  return GetRef<PrimExpr>(op);
}
// Mutates the broadcast scalar; rebuilds the Broadcast (keeping op->lanes) only if it changed.
PrimExpr ExprMutator::VisitExpr_(const BroadcastNode* op) {
  PrimExpr new_value = this->VisitExpr(op->value);
  if (!new_value.same_as(op->value)) {
    return Broadcast(new_value, op->lanes);
  }
  return GetRef<PrimExpr>(op);
}
// Mutates a Shuffle's input vectors and then its index expressions. The original node is
// returned when nothing changed; otherwise a new Shuffle is built from the mutated parts.
//
// Both arrays are mutated so that the mutator traverses the same sub-expressions as
// ExprVisitor::VisitExpr_(const ShuffleNode*), which visits `indices` as well as `vectors`.
// Previously `indices` was skipped here, so any rewrite of an index expression (e.g. a
// substitution or simplification pass) was silently dropped.
PrimExpr ExprMutator::VisitExpr_(const ShuffleNode* op) {
  auto fexpr = [this](const PrimExpr& e) { return this->VisitExpr(e); };
  // Vectors are mutated first, keeping the relative order the previous version used for them.
  Array<PrimExpr> vectors = MutateArray(op->vectors, fexpr);
  Array<PrimExpr> indices = MutateArray(op->indices, fexpr);
  if (vectors.same_as(op->vectors) && indices.same_as(op->indices)) {
    return GetRef<PrimExpr>(op);
  }
  return Shuffle(vectors, indices);
}
} // namespace tir
} // namespace tvm