blob: c7184e07a036f63c2a01bcdaae257ec2261e47bf [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file remap_thread_axis.cc
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_map>
namespace tvm {
namespace tir {
// Mutator that remaps thread axes: rewrites thread_extent IterVars and the uses of their variables.
/*!
 * \brief Rewrites `thread_extent` attributes whose IterVar tag is found in a
 *        tag -> IterVar map, and substitutes the old thread variable with the
 *        mapped IterVar's variable everywhere it is used afterwards.
 */
class ThreadAxisRewriter : private StmtExprMutator {
 public:
  /*!
   * \param tmap Thread tag -> replacement IterVar. Only a reference is kept,
   *        so the map must outlive this rewriter.
   */
  explicit ThreadAxisRewriter(const std::unordered_map<std::string, IterVar>& tmap) : tmap_(tmap) {}

  /*! \brief Run the remapping over \p stmt and return the rewritten statement. */
  Stmt Rewrite(Stmt stmt) { return operator()(std::move(stmt)); }

 private:
  Stmt VisitStmt_(const AttrStmtNode* op) final {
    // Anything other than a thread extent is handled by the default mutator.
    if (op->attr_key != attr::thread_extent) {
      return StmtExprMutator::VisitStmt_(op);
    }
    IterVar old_axis = Downcast<IterVar>(op->node);
    ICHECK_NE(old_axis->thread_tag.length(), 0U);
    auto found = tmap_.find(old_axis->thread_tag);
    if (found == tmap_.end()) {
      // Tag not listed in the map: leave this axis untouched.
      return StmtExprMutator::VisitStmt_(op);
    }
    const IterVar& replacement = found->second;
    // Record old var -> new var. If the old var was already recorded, it must
    // have been mapped to this very same new var.
    auto insertion = vmap_.emplace(old_axis->var.get(), replacement->var);
    if (!insertion.second) {
      ICHECK(insertion.first->second.same_as(replacement->var));
    }
    Stmt new_body = this->VisitStmt(op->body);
    return AttrStmt(replacement, op->attr_key, op->value, new_body);
  }

  PrimExpr VisitExpr_(const VarNode* op) final {
    auto hit = vmap_.find(op);
    if (hit == vmap_.end()) {
      return StmtExprMutator::VisitExpr_(op);
    }
    return hit->second;
  }

  // Thread tag -> replacement IterVar (borrowed from the caller).
  const std::unordered_map<std::string, IterVar>& tmap_;
  // Original thread variable -> replacement variable, filled while visiting.
  std::unordered_map<const VarNode*, Var> vmap_;
};
/*!
 * \brief Remap the thread axes of \p func according to \p thread_map.
 *
 * Every `thread_extent` attribute whose IterVar `thread_tag` is a key of
 * \p thread_map is rebuilt around the mapped IterVar, and uses of the old
 * thread variable inside it are replaced by the mapped IterVar's variable.
 *
 * \param func The function to rewrite.
 * \param thread_map Thread tag -> replacement IterVar.
 * \return The rewritten function. If \p func carries the
 *         `tir::attr::kKernelLaunchParams` attribute, entries whose thread tag
 *         is in \p thread_map are replaced as well. The attribute is optional:
 *         when it is absent, only the body is rewritten.
 */
PrimFunc RemapThreadAxis(PrimFunc func, ffi::Map<ffi::String, IterVar> thread_map) {
  std::unordered_map<std::string, IterVar> tmap;
  for (const auto& kv : thread_map) {
    tmap[kv.first] = kv.second;
  }
  // The launch-parameter attribute is optional, so it is only rewritten when
  // present. Inside this branch `opt` is known to hold a value, so no further
  // presence check is needed.
  if (auto opt = func->GetAttr<ffi::Array<IterVar>>(tir::attr::kKernelLaunchParams)) {
    auto launch_params = opt.value();
    // Replace each launch parameter whose thread tag has a mapping.
    for (size_t i = 0; i < launch_params.size(); ++i) {
      auto it = tmap.find(launch_params[i]->thread_tag);
      if (it != tmap.end()) {
        launch_params.Set(i, it->second);
      }
    }
    func = WithAttr(std::move(func), tir::attr::kKernelLaunchParams, launch_params);
  }
  auto* n = func.CopyOnWrite();
  n->body = ThreadAxisRewriter(tmap).Rewrite(std::move(n->body));
  return func;
}
namespace transform {
/*!
 * \brief Create a PrimFunc pass that applies tir::RemapThreadAxis with a fixed map.
 * \param thread_map Thread tag -> replacement IterVar. It is captured by value in the pass.
 * \return A pass named "tir.RemapThreadAxis" with opt level 0 and no required passes.
 */
Pass RemapThreadAxis(ffi::Map<ffi::String, IterVar> thread_map) {
  // The module and pass context are not consulted; each PrimFunc is rewritten on its own.
  auto remap_one = [thread_map](PrimFunc func, IRModule /*mod*/, PassContext /*ctx*/) {
    return tir::RemapThreadAxis(std::move(func), thread_map);
  };
  return CreatePrimFuncPass(remap_one, /*opt_level=*/0, "tir.RemapThreadAxis", /*required=*/{});
}
TVM_FFI_STATIC_INIT_BLOCK() {
  // Publish the pass constructor in the FFI global function registry.
  tvm::ffi::reflection::GlobalDef().def("tir.transform.RemapThreadAxis", RemapThreadAxis);
}
} // namespace transform
} // namespace tir
} // namespace tvm