blob: 62c790fab3ab6925b22609d568730779cc08ff8d [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tir/ir/transform.cc
* \brief TIR specific transformation passes.
*/
#include <tvm/node/repr_printer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/transform.h>
namespace tvm {
namespace tir {
namespace transform {
/*!
 * \brief Function level pass that applies transformations to all
 *        TIR functions (tir::PrimFunc) within the module.
 */
class PrimFuncPassNode : public PassNode {
 public:
  /*! \brief The pass meta data. */
  PassInfo pass_info;
  /*!
   * \brief The pass function called on each tir::PrimFunc of the module.
   *
   * It receives the function, the module being transformed and the active
   * pass context, and returns the transformed function. Returning an
   * undefined PrimFunc removes that function from the module.
   */
  runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func;
  /*!
   * \brief Attribute visitor hook.
   * \note Only pass_info is visited; pass_func is not exposed to the visitor.
   */
  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); }
  /*!
   * \brief Run the pass: apply pass_func to every tir::PrimFunc in \p mod.
   *
   * \param mod The module that an optimization pass is applied on.
   * \param pass_ctx The context that an optimization pass executes on.
   *
   * \return Return the updated module, with functions for which pass_func
   *         returned an undefined PrimFunc removed.
   */
  IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
  /*!
   * \brief Get the pass information/meta data.
   * \return A copy of pass_info.
   */
  PassInfo Info() const override { return pass_info; }
  static constexpr const char* _type_key = "tir.PrimFuncPass";
  TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncPassNode, PassNode);
};
/*!
 * \brief Managed reference to PrimFuncPassNode.
 * \sa PrimFuncPassNode
 */
class PrimFuncPass : public Pass {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(PrimFuncPass, Pass, PrimFuncPassNode);

  /*!
   * \brief Build a PrimFuncPass from a per-function transformation.
   * \param pass_func The packed function invoked on each PrimFunc of a module.
   * \param pass_info The meta data describing the pass.
   */
  TVM_DLL PrimFuncPass(runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func,
                       PassInfo pass_info);
};
/*!
 * \brief Allocate the backing PrimFuncPassNode and move both arguments into it.
 */
PrimFuncPass::PrimFuncPass(
    runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func,
    PassInfo pass_info) {
  auto node = make_object<PrimFuncPassNode>();
  // Both arguments were received by value, so they can be handed over without copying.
  node->pass_info = std::move(pass_info);
  node->pass_func = std::move(pass_func);
  data_ = std::move(node);
}
// Perform Module -> Module optimizations at the PrimFunc level.
/*!
 * \brief Apply pass_func to every tir::PrimFunc in \p mod.
 *
 * Entries that are not PrimFuncs (or are undefined) are left untouched.
 * A function for which pass_func returns an undefined PrimFunc is removed
 * from the module.
 *
 * \param mod The module to transform. It is copied on write if shared.
 * \param pass_ctx The pass context; its Trace hook is called before and after.
 * \return The transformed module.
 */
IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const {
  const PassInfo& info = Info();
  CHECK(mod.defined());
  pass_ctx.Trace(mod, info, true);
  // Keys whose functions pass_func asked to delete (it returned an undefined PrimFunc).
  std::vector<ObjectRef> deleted_list;
  IRModuleNode* mod_ptr = mod.CopyOnWrite();
  auto* func_dict = mod_ptr->functions.CopyOnWrite();
  // Directly loop over the underlying dict so entries can be updated in place.
  for (auto& kv : *func_dict) {
    // Only pick up tir::PrimFunc. as<>() is null-safe, unlike operator->, so an
    // undefined entry is skipped instead of being dereferenced.
    if (kv.second.as<PrimFuncNode>() == nullptr) continue;
    // Move out the function so that it is the only copy. Note: while pass_func
    // runs, this function's entry in `mod` is empty.
    PrimFunc func = Downcast<PrimFunc>(std::move(kv.second));
    func = pass_func(std::move(func), mod, pass_ctx);
    kv.second = std::move(func);
    if (!kv.second.defined()) {
      deleted_list.push_back(kv.first);
    }
  }
  // Automatic removal of None, deferred until after the loop so the dict is not
  // modified while it is being iterated.
  for (const auto& gv : deleted_list) {
    func_dict->erase(gv);
  }
  pass_ctx.Trace(mod, info, false);
  return mod;
}
/*!
 * \brief Create a PrimFuncPass from a per-function transformation.
 * \param pass_func The transformation applied to every PrimFunc of a module.
 * \param opt_level The optimization level recorded in the pass info.
 * \param name The name of the pass.
 * \param required Forwarded to PassInfo as the pass's list of required passes.
 * \return The created pass.
 */
Pass CreatePrimFuncPass(
    const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
    int opt_level, String name, tvm::Array<String> required) {
  // name and required are owned by-value parameters: hand them over instead of
  // copying them again (avoids redundant reference-count traffic).
  PassInfo pass_info(opt_level, std::move(name), std::move(required));
  return PrimFuncPass(pass_func, std::move(pass_info));
}
TVM_REGISTER_NODE_TYPE(PrimFuncPassNode);

// Register the PrimFuncPass constructor as the global function
// "tir.transform.CreatePrimFuncPass".
TVM_REGISTER_GLOBAL("tir.transform.CreatePrimFuncPass")
    .set_body_typed([](runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> func,
                       PassInfo info) {
      PrimFuncPass created(func, info);
      return created;
    });
// Textual form: PrimFuncPass(<name>, opt_level=<level>)
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<PrimFuncPassNode>([](const ObjectRef& ref, ReprPrinter* p) {
      // The dispatch table only routes PrimFuncPassNode here, so the unchecked cast is safe.
      const PassInfo meta = static_cast<const PrimFuncPassNode*>(ref.get())->Info();
      auto& os = p->stream;
      os << "PrimFuncPass(" << meta->name;
      os << ", opt_level=" << meta->opt_level << ")";
    });
} // namespace transform
} // namespace tir
} // namespace tvm