blob: 1a1279f0640ad00881067ab161a3c5713aca60de [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
*
* \brief Lift specified AttrStmt scope to outer if
* the body contains the same scope.
* \file lift_attr_scope.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "ir_util.h"
namespace tvm {
namespace tir {
// NOTE: this optimization can only be applied
// to a few specified attr keys
class AttrScopeLifter : public StmtMutator {
public:
explicit AttrScopeLifter(std::string attr_key) : attr_key_(attr_key) {}
Stmt Lift(Stmt stmt) {
stmt = operator()(std::move(stmt));
if (attr_node_.defined()) {
stmt = AttrStmt(attr_node_, attr_key_, attr_value_, stmt);
}
return stmt;
}
// do not go beyond
Stmt VisitStmt_(const AllocateNode* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<AllocateNode>();
if (attr_node_.defined()) {
Stmt body = AttrStmt(attr_node_, attr_key_, attr_value_, op->body);
// undefine them
attr_node_ = ObjectRef();
attr_value_ = PrimExpr();
return Allocate(op->buffer_var, op->dtype, op->extents, op->condition, body);
} else {
return stmt;
}
}
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr_key_) {
attr_node_ = op->node;
attr_value_ = op->value;
return op->body;
} else {
return StmtMutator::VisitStmt_(op);
}
}
Stmt VisitStmt_(const SeqStmtNode* op) final {
// remember the decorations.
std::vector<ObjectRef> attr_node;
std::vector<PrimExpr> attr_value;
auto fmutate = [&](const Stmt& s) {
attr_node_ = ObjectRef();
attr_value_ = PrimExpr();
Stmt ret = this->VisitStmt(s);
attr_node.push_back(attr_node_);
attr_value.push_back(attr_value_);
return ret;
};
Stmt ret = StmtMutator::VisitSeqStmt_(op, true, fmutate);
if (attr_node.size() == 0) return ret;
op = ret.as<SeqStmtNode>();
CHECK(op != nullptr);
Array<Stmt> reorg;
// check if all decorations are common.
for (size_t begin = 0; begin < attr_node.size();) {
size_t end = begin + 1;
while (end < attr_node.size() && attr_node[end].same_as(attr_node[begin]) &&
ValueSame(attr_value[end], attr_value[begin])) {
++end;
}
// covers everything
// lift attr to parent.
if (begin == 0 && end == attr_node.size()) {
attr_node_ = attr_node[0];
attr_value_ = attr_value[0];
return ret;
}
// construct subsegments.
Array<Stmt> seq;
for (size_t i = begin; i < end; ++i) {
seq.push_back(op->seq[i]);
}
Stmt stmt = SeqStmt::Flatten(seq);
if (attr_node[begin].defined()) {
stmt = AttrStmt(attr_node[begin], attr_key_, attr_value[begin], stmt);
}
reorg.push_back(stmt);
begin = end;
}
attr_node_ = ObjectRef();
attr_value_ = PrimExpr();
return SeqStmt::Flatten(reorg);
}
Stmt VisitStmt_(const IfThenElseNode* op) final {
if (!op->else_case.defined()) {
return StmtMutator::VisitStmt_(op);
}
Stmt then_case = this->VisitStmt(op->then_case);
ObjectRef first_node;
PrimExpr first_value;
std::swap(first_node, attr_node_);
std::swap(first_value, attr_value_);
Stmt else_case = this->VisitStmt(op->else_case);
if (attr_node_.defined() && attr_value_.defined() && first_node.defined() &&
first_value.defined() && attr_node_.same_as(first_node) &&
ValueSame(attr_value_, first_value)) {
if (then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) {
return GetRef<Stmt>(op);
} else {
return IfThenElse(op->condition, then_case, else_case);
}
} else {
if (first_node.defined()) {
then_case = AttrStmt(first_node, attr_key_, first_value, then_case);
}
if (attr_node_.defined()) {
else_case = AttrStmt(attr_node_, attr_key_, attr_value_, else_case);
// undefine them
attr_node_ = ObjectRef();
attr_value_ = PrimExpr();
}
if (then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) {
return GetRef<Stmt>(op);
} else {
return IfThenElse(op->condition, then_case, else_case);
}
}
}
private:
// value comparison that also compares content of int constant
static bool ValueSame(const PrimExpr& a, const PrimExpr& b) {
if (a.same_as(b)) return true;
if (!a.defined() || !b.defined()) return false;
if (a->type_index() != b->type_index()) return false;
if (a.dtype() != b.dtype()) return false;
if (const IntImmNode* op = a.as<IntImmNode>()) {
return op->value == b.as<IntImmNode>()->value;
}
return false;
}
std::string attr_key_;
ObjectRef attr_node_;
PrimExpr attr_value_;
};
/*!
 * \brief Run AttrScopeLifter over a statement for a single attribute key.
 * \param stmt The statement to rewrite.
 * \param attr_key The attribute key handed to the lifter.
 * \return The statement produced by AttrScopeLifter::Lift.
 */
Stmt LiftAttrScope(Stmt stmt, std::string attr_key) {
  AttrScopeLifter lifter(attr_key);
  return lifter.Lift(std::move(stmt));
}
namespace transform {
/*!
 * \brief Build a PrimFunc pass, named "tir.LiftAttrScope", that replaces each
 *        function body with the result of AttrScopeLifter::Lift.
 * \param attr_key The attribute key handed to the lifter.
 * \return The pass created by CreatePrimFuncPass.
 */
Pass LiftAttrScope(String attr_key) {
  // attr_key is captured by value, so the callback keeps its own copy.
  auto lift_body = [=](PrimFunc func, IRModule mod, PassContext pass_ctx) {
    auto* func_node = func.CopyOnWrite();
    AttrScopeLifter lifter(attr_key);
    func_node->body = lifter.Lift(std::move(func_node->body));
    return func;
  };
  return CreatePrimFuncPass(lift_body, 0, "tir.LiftAttrScope", {});
}
// Register the pass factory of this namespace as a global function under the
// name "tir.transform.LiftAttrScope".
TVM_REGISTER_GLOBAL("tir.transform.LiftAttrScope").set_body_typed(LiftAttrScope);
} // namespace transform
} // namespace tir
} // namespace tvm