blob: 5298added83be121a8102922ff62bd4c5852a2a5 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file split_host_device.cc
* \brief Split device function from host.
*/
#include <tvm/ir/transform.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/target.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_map>
namespace tvm {
namespace tir {
// Use/def analysis over TIR, which also deletes unreferenced lets.
//
// While mutating a Stmt or PrimExpr, this class registers every variable
// definition it visits (thread_extent attributes, LetStmt/Let, For, Allocate
// and Reduce axes) and counts every use (variable references and the buffer
// variables of Load/Store). A variable that is used without having been
// registered first is appended to `undefined_`.
class VarUseDefAnalysis : public StmtExprMutator {
public:
// thread_extent attributes may introduce the thread IterVar's variable;
// every other attribute is forwarded to the default mutator.
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
CHECK_NE(iv->thread_tag.length(), 0U);
// thread_extent can appear multiple times
// use the first appearance as def.
// "First appearance" means the variable has no use_count_ entry yet; in
// that case the axis and its extent are recorded side by side.
if (!use_count_.count(iv->var.get())) {
this->HandleDef(iv->var.get());
thread_axis_.push_back(iv);
thread_extent_.push_back(op->value);
}
PrimExpr value = op->value;
// The extent expression is only traversed when the caller asked for it.
if (visit_thread_extent_) {
value = this->VisitExpr(value);
}
Stmt body = this->VisitStmt(op->body);
// Reuse the original node when nothing changed.
if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
}
return AttrStmt(op->node, op->attr_key, value, body);
} else {
return StmtExprMutator::VisitStmt_(op);
}
}
// Defines the bound variable, then visits the body before the value so that
// the use count of the variable is known when deciding whether to drop the let.
Stmt VisitStmt_(const LetStmtNode* op) final {
this->HandleDef(op->var.get());
Stmt body = this->VisitStmt(op->body);
// eliminate unreferenced let
// (only when simplify_let_ is set and the value's side effect is at most
// kReadState; in that case the value is never visited)
if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState &&
simplify_let_) {
return body;
} else {
PrimExpr value = this->VisitExpr(op->value);
if (body.same_as(op->body) && value.same_as(op->value)) {
return GetRef<Stmt>(op);
} else {
return LetStmt(op->var, value, body);
}
}
}
// The loop variable is defined before the loop is traversed.
Stmt VisitStmt_(const ForNode* op) final {
this->HandleDef(op->loop_var.get());
return StmtExprMutator::VisitStmt_(op);
}
// The allocated buffer variable is defined before the allocation is traversed.
Stmt VisitStmt_(const AllocateNode* op) final {
this->HandleDef(op->buffer_var.get());
return StmtExprMutator::VisitStmt_(op);
}
// A store counts as a use of its buffer variable.
Stmt VisitStmt_(const StoreNode* op) final {
this->HandleUse(op->buffer_var);
return StmtExprMutator::VisitStmt_(op);
}
// Let expressions: see the weaker SSA rule described below.
PrimExpr VisitExpr_(const LetNode* op) final {
// Weaker SSA condition
// A single var can be bound in multiple lets
// but they have to bind to the same value.
// This is used to allow cases when we reuse a single let
// expression to construct a nested expr.
// (let x = 1 in x + 1) * (let x = 1 in x + 1)
auto it = let_binding_.find(op->var);
// The value is visited before the lookup result is inspected, so its uses
// are counted for repeated bindings as well.
PrimExpr value = this->VisitExpr(op->value);
if (it != let_binding_.end()) {
CHECK(deep_equal_(it->second->value, value))
<< "Let cannot bind the same var to two different values";
// Repeated binding: return the Let node recorded at the first binding.
return GetRef<PrimExpr>(it->second);
} else {
this->HandleDef(op->var.get());
let_binding_[op->var] = op;
}
PrimExpr body = this->VisitExpr(op->body);
// eliminate unreferenced let
// (same conditions as for LetStmt)
if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState &&
simplify_let_) {
return body;
} else {
if (body.same_as(op->body) && value.same_as(op->value)) {
return GetRef<PrimExpr>(op);
} else {
return Let(op->var, value, body);
}
}
}
// Every variable reference counts as a use.
PrimExpr VisitExpr_(const VarNode* op) final {
this->HandleUse(GetRef<PrimExpr>(op));
return StmtExprMutator::VisitExpr_(op);
}
// The variables of the reduction axes are defined before the node is traversed.
PrimExpr VisitExpr_(const ReduceNode* op) final {
for (const auto& iv : op->axis) {
this->HandleDef(iv->var.get());
}
return StmtExprMutator::VisitExpr_(op);
}
// A load counts as a use of its buffer variable.
PrimExpr VisitExpr_(const LoadNode* op) final {
this->HandleUse(op->buffer_var);
return StmtExprMutator::VisitExpr_(op);
}
// Register `v` as defined. CHECK-fails if `v` was already defined or already
// has a use_count_ entry; otherwise its use count starts at 0.
void HandleDef(const VarNode* v) {
CHECK(!def_count_.count(v)) << "variable " << v->name_hint
<< " has already been defined, the Stmt is not SSA";
CHECK(!use_count_.count(v)) << "variable " << v->name_hint
<< " has been used before definition!";
use_count_[v] = 0;
def_count_[v] = 1;
}
// Record one use of `v`, which must be a Var (CHECK-fails otherwise).
void HandleUse(const PrimExpr& v) {
CHECK(v.as<VarNode>());
Var var = Downcast<Var>(v);
auto it = use_count_.find(var.get());
if (it != use_count_.end()) {
// Known variable: count the use unless it carries the -1 "undefined" marker.
if (it->second >= 0) {
++it->second;
}
} else {
// First use of an unknown variable: report it once and mark it with -1 so
// later uses neither report it again nor increment the count.
undefined_.push_back(var);
use_count_[var.get()] = -1;
}
}
// The fields below are public so that callers can configure the analysis
// and read back its results.
// When false, the extent expression of a thread_extent attribute is not visited.
bool visit_thread_extent_{true};
// When true, unreferenced lets whose value has at most a kReadState side
// effect are replaced by their body.
bool simplify_let_{true};
// Variables used without a prior definition, in order of first use.
Array<Var> undefined_;
// Thread axes recorded from thread_extent attributes, and (index-aligned)
// the extent expression of each recorded axis.
Array<IterVar> thread_axis_;
Array<PrimExpr> thread_extent_;
// Use count per variable: 0 at definition, incremented per use, -1 for a
// variable that was used without being defined.
std::unordered_map<const VarNode*, int> use_count_;
// Set to 1 for every variable registered through HandleDef.
std::unordered_map<const VarNode*, int> def_count_;
private:
// Structural comparison used to validate repeated Let bindings.
ExprDeepEqual deep_equal_;
// First Let expression seen for each let-bound variable.
std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;
};
// Collect the variables that `stmt` uses without defining them, in order of
// first use. Variables listed in `args` are treated as already known and are
// therefore never reported.
Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& args) {
  VarUseDefAnalysis analyzer;
  // This is a pure query: do not strip unreferenced lets.
  analyzer.simplify_let_ = false;
  // Seeding a zero use count makes later uses of the argument count as
  // uses of a known variable.
  for (Var known : args) {
    analyzer.use_count_[known.get()] = 0;
  }
  analyzer(stmt);
  return analyzer.undefined_;
}
// Collect the variables that `expr` uses without defining them, in order of
// first use.
Array<Var> UndefinedVars(const PrimExpr& expr) {
  VarUseDefAnalysis analyzer;
  // This is a pure query: do not strip unreferenced lets.
  analyzer.simplify_let_ = false;
  analyzer(expr);
  return analyzer.undefined_;
}
/*!
 * \brief Statement mutator that moves device regions into their own functions.
 *
 * Every statement rooted at a thread_extent, pipeline_exec_scope or
 * device_scope attribute is turned into a new PrimFunc that is added to the
 * device module, and is replaced in the visited body by a packed call to
 * that function.
 */
class HostDeviceSplitter : public StmtMutator {
 public:
  /*!
   * \param device_mod Module that receives the generated device functions.
   * \param device_target Target attached to every generated device function.
   * \param name_prefix Prefix of the generated kernel symbols.
   */
  explicit HostDeviceSplitter(IRModule* device_mod, Target device_target, std::string name_prefix)
      : device_mod_(device_mod), device_target_(device_target), name_prefix_(name_prefix) {}

  Stmt VisitStmt_(const AllocateNode* op) final {
    // Remember the element type of the allocation so that a device function
    // using this buffer can declare a typed pointer parameter for it.
    handle_data_type_[op->buffer_var.get()] = make_const(op->dtype, 0);
    return StmtMutator::VisitStmt_(op);
  }

  Stmt VisitStmt_(const AttrStmtNode* op) final {
    const bool opens_device_region = op->attr_key == attr::thread_extent ||
                                     op->attr_key == attr::pipeline_exec_scope ||
                                     op->attr_key == attr::device_scope;
    if (!opens_device_region) {
      return StmtMutator::VisitStmt_(op);
    }
    return SplitDeviceFunc(GetRef<Stmt>(op));
  }

 private:
  /*!
   * \brief Outline `body` into a device function and return the call that launches it.
   */
  Stmt SplitDeviceFunc(Stmt body) {
    // Kernel symbols are "<prefix>_kernel<N>", numbered in creation order.
    std::ostringstream symbol_stream;
    symbol_stream << name_prefix_ << "_kernel" << device_func_counter_++;
    std::string kernel_symbol = symbol_stream.str();

    // Isolate the device function: its free variables become the parameters.
    VarUseDefAnalysis analysis;
    analysis.visit_thread_extent_ = false;
    body = analysis(std::move(body));

    Array<Var> params;
    Array<PrimExpr> arguments;
    Map<tir::Var, PrimExpr> remap_vars;

    // Argument order is strict: handle (pointer) variables come first ...
    for (Var var : analysis.undefined_) {
      if (!var.dtype().is_handle()) {
        continue;
      }
      auto known_type = handle_data_type_.find(var.get());
      if (known_type == handle_data_type_.end()) {
        params.push_back(var);
      } else {
        // The element type is known: use a fresh, pointer-typed variable
        // inside the device function.
        tir::Var typed_var(var->name_hint, PointerType(PrimType(known_type->second->dtype)));
        params.push_back(typed_var);
        remap_vars.Set(var, typed_var);
      }
      arguments.push_back(var);
    }
    // ... followed by the positional (non-handle) arguments.
    for (Var var : analysis.undefined_) {
      if (var.dtype().is_handle()) {
        continue;
      }
      params.push_back(var);
      arguments.push_back(var);
    }

    PrimFunc device_func(params, Substitute(body, remap_vars));
    device_func =
        WithAttr(std::move(device_func), tir::attr::kDeviceThreadAxis, analysis.thread_axis_);
    device_func = WithAttr(std::move(device_func), tvm::attr::kCallingConv,
                           Integer(CallingConv::kDeviceKernelLaunch));
    device_func =
        WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol, runtime::String(kernel_symbol));
    device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1));
    device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_);
    (*device_mod_)->Add(GlobalVar(kernel_symbol), device_func);

    // Generate the call to the device function:
    // symbol name, then the arguments, then the thread extents.
    Array<PrimExpr> launch_args;
    launch_args.push_back(StringImm(kernel_symbol));
    for (PrimExpr argument : arguments) {
      launch_args.push_back(argument);
    }
    for (PrimExpr extent : analysis.thread_extent_) {
      launch_args.push_back(extent);
    }
    PrimExpr launch = Call(DataType::Int(32), builtin::tvm_call_packed(), launch_args);
    return Evaluate(launch);
  }

  // Module that collects the generated device functions.
  IRModule* device_mod_;
  // Target of the generated device functions.
  Target device_target_;
  // Prefix used when naming the generated kernels.
  std::string name_prefix_;
  // Number of device functions generated so far.
  int device_func_counter_{0};
  // Buffer variable of each visited Allocate -> zero constant of its dtype
  // (only the dtype is read back).
  std::unordered_map<const VarNode*, PrimExpr> handle_data_type_;
};
/*!
 * \brief Split the device regions out of `func`.
 *
 * The function must carry both the target and the global_symbol attribute.
 * Generated device functions are added to `device_mod`; the returned host
 * function has its target attribute reset to an undefined Target.
 */
PrimFunc SplitHostDevice(PrimFunc&& func, IRModule* device_mod) {
  auto device_target = func->GetAttr<Target>(tvm::attr::kTarget);
  CHECK(device_target.defined()) << "SplitHostDevice: Require the target attribute";

  auto host_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
  CHECK(host_symbol.defined())
      << "SplitHostDevice: Expect PrimFunc to have the global_symbol attribute";

  // Kernels are named after the host function's global symbol.
  std::string kernel_prefix = static_cast<std::string>(host_symbol.value());
  HostDeviceSplitter splitter(device_mod, device_target.value(), kernel_prefix);

  auto* node = func.CopyOnWrite();
  node->body = splitter(std::move(node->body));

  // Set the host target to None.
  return WithAttr(std::move(func), tvm::attr::kTarget, Target(nullptr));
}
namespace transform {
Pass SplitHostDevice() {
auto pass_func = [](IRModule mod, PassContext ctx) {
IRModuleNode* mod_ptr = mod.CopyOnWrite();
auto* func_dict = mod_ptr->functions.CopyOnWrite();
IRModule device_mod = IRModule();
for (auto& kv : *func_dict) {
if (kv.second->IsInstance<PrimFuncNode>()) {
PrimFunc func = Downcast<PrimFunc>(std::move(kv.second));
kv.second = SplitHostDevice(std::move(func), &device_mod);
}
}
mod->Update(device_mod);
return mod;
};
return tvm::transform::CreateModulePass(pass_func, 0, "tir.SplitHostDevice", {});
}
TVM_REGISTER_GLOBAL("tir.transform.SplitHostDevice").set_body_typed(SplitHostDevice);
} // namespace transform
} // namespace tir
} // namespace tvm