blob: 68b494d4114408e6bacc0630083d3c18c93985a8 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tir/ir/transform.cc
* \brief TIR-specific pass infrastructure: the PrimFunc-level pass (PrimFuncPass) and TIR pass-config options.
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/rvalue_ref.h>
#include <tvm/node/repr_printer.h>
#include <tvm/tir/transform.h>
namespace tvm {
namespace tir {
namespace transform {
// Register build pipeline related pass-config options.
// Each entry declares an option key together with the type its value is
// expected to have: Bool, Integer, or (for "tir.add_lower_pass") an Array of
// Arrays of ObjectRef.
// NOTE(review): the meaning of each option is defined by the passes that read
// it, not here.
TVM_REGISTER_PASS_CONFIG_OPTION("tir.noalias", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.detect_global_barrier", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_buffer_level_predication", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_debug", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", ffi::Array<ffi::Array<ObjectRef>>);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_static_smem", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.ptx_ldg32", Bool);
/*!
 * \brief Function level pass that applies transformations to all
 *  TIR functions within the module.
 *
 * Only module entries whose value is a tir::PrimFunc are handed to
 * pass_func; entries holding other function kinds are left unchanged.
 */
class PrimFuncPassNode : public PassNode {
 public:
  /*! \brief The pass meta data. */
  PassInfo pass_info;
  /*!
   * \brief The pass function called on each PrimFunc of the module.
   *
   * It receives the PrimFunc, the module being transformed and the pass
   * context, and returns the transformed PrimFunc. Returning an undefined
   * PrimFunc causes that function to be removed from the module.
   */
  std::function<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func;
  /*!
   * \brief Register the reflection metadata of this node.
   *
   * Only pass_info is exposed (read-only); pass_func is not reflected.
   */
  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<PrimFuncPassNode>().def_ro("pass_info", &PrimFuncPassNode::pass_info);
  }
  /*!
   * \brief Run a function pass on given pass context.
   *
   * \param mod The module that an optimization pass is applied on.
   * \param pass_ctx The context that an optimization pass executes on.
   *
   * \return Return the updated module.
   */
  IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
  /*!
   * \brief Get the pass information/meta data.
   * \return The pass_info of this pass.
   */
  PassInfo Info() const override { return pass_info; }
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.PrimFuncPass", PrimFuncPassNode, PassNode);
};
/*!
 * \brief Managed reference to PrimFuncPassNode.
 * \sa PrimFuncPassNode
 */
class PrimFuncPass : public Pass {
 public:
  /*!
   * \brief The constructor
   * \param pass_func The function which implements the pass; it is applied to
   *        each PrimFunc in the module.
   * \param pass_info The pass info.
   */
  TVM_DLL PrimFuncPass(std::function<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func,
                       PassInfo pass_info);
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PrimFuncPass, Pass, PrimFuncPassNode);
};
/*!
 * \brief Build a PrimFuncPass by allocating its backing PrimFuncPassNode.
 * \param pass_func The function applied to each PrimFunc in the module (moved in).
 * \param pass_info The pass metadata (moved in).
 */
PrimFuncPass::PrimFuncPass(std::function<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func,
                           PassInfo pass_info) {
  auto node = ffi::make_object<PrimFuncPassNode>();
  // Both parameters are sinks: move them straight into the node.
  node->pass_info = std::move(pass_info);
  node->pass_func = std::move(pass_func);
  // Hand the freshly built node over to this reference.
  data_ = std::move(node);
}
/*!
 * \brief Perform Module -> Module optimizations at the PrimFunc level.
 *
 * Applies pass_func to every entry of the module whose value is a
 * tir::PrimFunc and writes the result back into the same entry. Entries for
 * which pass_func returns an undefined PrimFunc are removed from the module
 * after all functions have been visited. Non-PrimFunc entries are untouched.
 *
 * \param mod The module to transform; it is mutated through CopyOnWrite().
 * \param pass_ctx The pass context, forwarded to pass_func.
 * \return The updated module.
 */
IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const {
  ICHECK(mod.defined());
  // GlobalVars whose PrimFunc was turned into None by the pass.
  std::vector<GlobalVar> deleted_list;
  IRModuleNode* mod_ptr = mod.CopyOnWrite();
  auto* func_dict = mod_ptr->functions.CopyOnWrite();
  // Directly loop over the underlying dict so each entry can be updated in
  // place without re-inserting it.
  for (auto& kv : *func_dict) {
    // Only tir::PrimFunc entries are transformed; other functions are skipped.
    if (auto opt_func = kv.second.as<PrimFunc>()) {
      // Reset the map's entry before calling the pass so the map no longer
      // holds a reference to the function. Combined with the moves below,
      // the pass then receives the only reference (assuming nothing else
      // holds the PrimFunc), so a copy-on-write inside the pass need not copy.
      // NOTE: while pass_func runs, this entry of `mod` is temporarily None.
      kv.second.reset();
      PrimFunc func = *std::move(opt_func);
      func = pass_func(std::move(func), mod, pass_ctx);
      kv.second = Any(std::move(func));
      // An undefined result is stored as None; schedule the entry for removal
      // once iteration over func_dict has finished.
      if (kv.second == nullptr) {
        deleted_list.push_back(Downcast<GlobalVar>(kv.first));
      }
    }
  }
  // Automatic removal of None. This uses IRModuleNode::Remove
  // instead of manipulating func_dict directly, to ensure that both
  // the function map and the global_var_map_ are correctly updated.
  for (const auto& gv : deleted_list) {
    mod_ptr->Remove(gv);
  }
  return mod;
}
/*!
 * \brief Create a PrimFunc-level pass from a C++ callable and its metadata.
 * \param pass_func Function applied to each PrimFunc in the module.
 * \param opt_level Optimization level, forwarded to PassInfo.
 * \param name Pass name, forwarded to PassInfo.
 * \param required Required-pass list, forwarded to PassInfo.
 * \param traceable Traceability flag, forwarded to PassInfo.
 * \return The constructed PrimFuncPass.
 */
Pass CreatePrimFuncPass(std::function<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func,
                        int opt_level, ffi::String name, tvm::ffi::Array<ffi::String> required,
                        bool traceable) {
  // Build the metadata inline and hand it, together with the callable, to the pass.
  return PrimFuncPass(std::move(pass_func), PassInfo(opt_level, name, required, traceable));
}
// Register PrimFuncPassNode's reflection metadata (see RegisterReflection)
// from a TVM_FFI_STATIC_INIT_BLOCK of this translation unit.
TVM_FFI_STATIC_INIT_BLOCK() { PrimFuncPassNode::RegisterReflection(); }
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def(
"tir.transform.CreatePrimFuncPass",
[](ffi::TypedFunction<PrimFunc(ffi::RValueRef<PrimFunc>, IRModule, PassContext)> pass_func,
PassInfo pass_info) {
auto wrapped_pass_func = [pass_func](PrimFunc func, IRModule mod, PassContext ctx) {
return pass_func(ffi::RValueRef<PrimFunc>(std::move(func)), mod, ctx);
};
return PrimFuncPass(wrapped_pass_func, pass_info);
});
}
// Textual representation of a PrimFuncPass: "PrimFuncPass(<name>, opt_level=<level>)".
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<PrimFuncPassNode>([](const ObjectRef& ref, ReprPrinter* p) {
      const PassInfo meta = static_cast<const PrimFuncPassNode*>(ref.get())->Info();
      p->stream << "PrimFuncPass(" << meta->name << ", opt_level=" << meta->opt_level
                << ")";
    });
} // namespace transform
} // namespace tir
} // namespace tvm