blob: 1a1279f0640ad00881067ab161a3c5713aca60de [file] [log] [blame]
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
/*!
 * \brief Lift the specified AttrStmt scope to an outer scope when every
 *        part of the body carries that same scope.
 * \file
 */
#include <tvm/runtime/registry.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "ir_util.h"
namespace tvm {
namespace tir {
// NOTE: this optimization can only be applied
// to a few specified attr keys
class AttrScopeLifter : public StmtMutator {
explicit AttrScopeLifter(std::string attr_key) : attr_key_(attr_key) {}
Stmt Lift(Stmt stmt) {
stmt = operator()(std::move(stmt));
if (attr_node_.defined()) {
stmt = AttrStmt(attr_node_, attr_key_, attr_value_, stmt);
return stmt;
// do not go beyond
Stmt VisitStmt_(const AllocateNode* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
op =<AllocateNode>();
if (attr_node_.defined()) {
Stmt body = AttrStmt(attr_node_, attr_key_, attr_value_, op->body);
// undefine them
attr_node_ = ObjectRef();
attr_value_ = PrimExpr();
return Allocate(op->buffer_var, op->dtype, op->extents, op->condition, body);
} else {
return stmt;
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr_key_) {
attr_node_ = op->node;
attr_value_ = op->value;
return op->body;
} else {
return StmtMutator::VisitStmt_(op);
Stmt VisitStmt_(const SeqStmtNode* op) final {
// remember the decorations.
std::vector<ObjectRef> attr_node;
std::vector<PrimExpr> attr_value;
auto fmutate = [&](const Stmt& s) {
attr_node_ = ObjectRef();
attr_value_ = PrimExpr();
Stmt ret = this->VisitStmt(s);
return ret;
Stmt ret = StmtMutator::VisitSeqStmt_(op, true, fmutate);
if (attr_node.size() == 0) return ret;
op =<SeqStmtNode>();
CHECK(op != nullptr);
Array<Stmt> reorg;
// check if all decorations are common.
for (size_t begin = 0; begin < attr_node.size();) {
size_t end = begin + 1;
while (end < attr_node.size() && attr_node[end].same_as(attr_node[begin]) &&
ValueSame(attr_value[end], attr_value[begin])) {
// covers everything
// lift attr to parent.
if (begin == 0 && end == attr_node.size()) {
attr_node_ = attr_node[0];
attr_value_ = attr_value[0];
return ret;
// construct subsegments.
Array<Stmt> seq;
for (size_t i = begin; i < end; ++i) {
Stmt stmt = SeqStmt::Flatten(seq);
if (attr_node[begin].defined()) {
stmt = AttrStmt(attr_node[begin], attr_key_, attr_value[begin], stmt);
begin = end;
attr_node_ = ObjectRef();
attr_value_ = PrimExpr();
return SeqStmt::Flatten(reorg);
Stmt VisitStmt_(const IfThenElseNode* op) final {
if (!op->else_case.defined()) {
return StmtMutator::VisitStmt_(op);
Stmt then_case = this->VisitStmt(op->then_case);
ObjectRef first_node;
PrimExpr first_value;
std::swap(first_node, attr_node_);
std::swap(first_value, attr_value_);
Stmt else_case = this->VisitStmt(op->else_case);
if (attr_node_.defined() && attr_value_.defined() && first_node.defined() &&
first_value.defined() && attr_node_.same_as(first_node) &&
ValueSame(attr_value_, first_value)) {
if (then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) {
return GetRef<Stmt>(op);
} else {
return IfThenElse(op->condition, then_case, else_case);
} else {
if (first_node.defined()) {
then_case = AttrStmt(first_node, attr_key_, first_value, then_case);
if (attr_node_.defined()) {
else_case = AttrStmt(attr_node_, attr_key_, attr_value_, else_case);
// undefine them
attr_node_ = ObjectRef();
attr_value_ = PrimExpr();
if (then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) {
return GetRef<Stmt>(op);
} else {
return IfThenElse(op->condition, then_case, else_case);
// value comparison that also compares content of int constant
static bool ValueSame(const PrimExpr& a, const PrimExpr& b) {
if (a.same_as(b)) return true;
if (!a.defined() || !b.defined()) return false;
if (a->type_index() != b->type_index()) return false;
if (a.dtype() != b.dtype()) return false;
if (const IntImmNode* op =<IntImmNode>()) {
return op->value ==<IntImmNode>()->value;
return false;
std::string attr_key_;
ObjectRef attr_node_;
PrimExpr attr_value_;
/*!
 * \brief Convenience wrapper: run AttrScopeLifter for \p attr_key over \p stmt.
 * \param stmt The statement to transform.
 * \param attr_key The attribute key whose AttrStmt scopes should be lifted.
 * \return The transformed statement.
 */
Stmt LiftAttrScope(Stmt stmt, std::string attr_key) {
  return AttrScopeLifter(std::move(attr_key)).Lift(std::move(stmt));
}
namespace transform {

/*!
 * \brief Create a PrimFunc pass that runs AttrScopeLifter with \p attr_key
 *        over the body of each PrimFunc it is applied to.
 * \param attr_key The attribute key whose AttrStmt scopes should be lifted.
 * \return The pass, named "tir.LiftAttrScope".
 */
// NOTE(review): <tvm/runtime/registry.h> is included, but no global
// registration of this pass is visible here; confirm it is registered elsewhere
// if frontends need it.
Pass LiftAttrScope(String attr_key) {
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    auto* n = f.CopyOnWrite();
    n->body = AttrScopeLifter(attr_key).Lift(std::move(n->body));
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.LiftAttrScope", {});
}

} // namespace transform
} // namespace tir
} // namespace tvm