blob: 3fae2bbf40c89c32c52f5fdc0dfb21d66c916d67 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file make_packed_api.cc Lower PrimFunc to use the packed function API.
*/
#include <tvm/runtime/container.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/target.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_set>
#include <utility>
#include <vector>
#include "arg_binder.h"
#include "ir_util.h"
namespace tvm {
namespace tir {
/*!
 * \brief Build an assertion statement requiring two expressions to be equal.
 * \param lhs Left-hand side of the equality.
 * \param rhs Right-hand side of the equality.
 * \param msg Message attached to the assertion.
 * \return An AssertStmt on (lhs == rhs) whose body is the no-op Evaluate(0).
 */
inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
  PrimExpr equality = lhs == rhs;
  tvm::tir::StringImm message(msg);
  return AssertStmt(equality, message, Evaluate(0));
}
/*!
 * \brief Rewrite a PrimFunc so that its leading parameters are read from a packed
 *        argument array.
 *
 * The first (params.size() - num_unpacked_args) parameters are loaded from the
 * packed argument array and their type codes are asserted; the trailing
 * num_unpacked_args parameters remain ordinary positional parameters.
 *
 * When at least one parameter is packed, the resulting parameter list is
 *   (args, arg_type_ids, num_args, <unpacked params...>,
 *    out_ret_value, out_ret_tcode, resource_handle).
 * Otherwise it consists of the unpacked parameters only.
 *
 * The function must carry the global_symbol and target attributes. On return the
 * buffer_map is empty, ret_type is int32, and, when num_unpacked_args == 0, the
 * calling convention attribute is set to CallingConv::kCPackedFunc.
 *
 * \param func The function to rewrite (consumed).
 * \param num_unpacked_args Number of trailing parameters left unpacked; must lie in
 *        the range [0, params.size()].
 * \return The rewritten function.
 */
PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
  auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
  CHECK(global_symbol) << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute";

  auto target = func->GetAttr<Target>(tvm::attr::kTarget);
  CHECK(target.defined()) << "MakePackedAPI: Require the target attribute";
  int target_device_type = target.value()->kind->device_type;

  std::string name_hint = global_symbol.value();

  auto* func_ptr = func.CopyOnWrite();
  const Stmt nop = Evaluate(0);
  int num_args = static_cast<int>(func_ptr->params.size());
  // Reject a negative count up front: it would otherwise inflate num_packed_args
  // and only surface later as an unrelated-looking argument-count mismatch.
  CHECK_GE(num_unpacked_args, 0)
      << "MakePackedAPI: num_unpacked_args must be non-negative, but got " << num_unpacked_args;
  CHECK_LE(num_unpacked_args, num_args)
      << "MakePackedAPI: num_unpacked_args (" << num_unpacked_args
      << ") exceeds the number of parameters (" << num_args << ")";
  int num_packed_args = num_args - num_unpacked_args;

  // Data field definitions
  // The packed fields
  Var v_packed_args("args", DataType::Handle());
  Var v_packed_arg_type_ids("arg_type_ids", DataType::Handle());
  Var v_num_packed_args("num_args", DataType::Int(32));
  Var v_out_ret_value("out_ret_value", DataType::Handle());
  Var v_out_ret_tcode("out_ret_tcode", DataType::Handle());
  Var v_resource_handle("resource_handle", DataType::Handle());
  // The arguments of the function.
  Array<Var> args;
  // The device context
  Var device_id("dev_id");
  Integer device_type(target_device_type);
  // seq_init gives sequence of initialization
  // seq_check gives sequence of later checks after init
  std::vector<Stmt> seq_init, seq_check;
  std::unordered_map<const VarNode*, PrimExpr> vmap;
  ArgBinder binder(&vmap);

  // ---------------------------
  // local function definitions
  // Load the i-th packed argument as type t: read the value-content field of
  // entry i with the API type for t, then cast if that differs from t.
  auto f_arg_value = [&](DataType t, int i) {
    Array<PrimExpr> call_args{v_packed_args, IntImm(DataType::Int(32), i),
                              IntImm(DataType::Int(32), builtin::kTVMValueContent)};
    // load 64 bit version
    DataType api_type = APIType(t);
    PrimExpr res = Call(api_type, builtin::tvm_struct_get(), call_args);
    // cast to the target version.
    if (api_type != t) {
      res = Cast(t, res);
    }
    return res;
  };

  // ---------------------------
  // start of logics
  // Add the signature for packed arguments and assert the runtime argument count.
  if (num_packed_args != 0) {
    args.push_back(v_packed_args);
    args.push_back(v_packed_arg_type_ids);
    args.push_back(v_num_packed_args);
    std::ostringstream os;

    os << name_hint << ": num_args should be " << num_packed_args;
    seq_init.emplace_back(MakeAssertEQ(v_num_packed_args, num_packed_args, os.str()));
  }

  // Need to re-declare vars, in case some arguments also appear in the buffer.
  std::vector<std::pair<Var, Var> > var_def;
  std::vector<std::pair<Var, Buffer> > buffer_def;

  for (int i = 0; i < num_args; ++i) {
    Var param = func_ptr->params[i];
    Var v_arg = Var("arg" + std::to_string(i), param->dtype);

    // Parameters present in the buffer map are bound as tensors later on;
    // all other parameters are bound as plain variables.
    auto it = func_ptr->buffer_map.find(param);
    if (it != func_ptr->buffer_map.end()) {
      buffer_def.emplace_back(v_arg, (*it).second);
    } else {
      var_def.emplace_back(v_arg, param);
    }

    if (i < num_packed_args) {
      // Value loads
      seq_init.emplace_back(LetStmt(v_arg, f_arg_value(v_arg.dtype(), i), nop));
      // type code checks
      Var tcode(v_arg->name_hint + ".code", DataType::Int(32));
      seq_init.emplace_back(LetStmt(tcode,
                                    Load(DataType::Int(32), v_packed_arg_type_ids,
                                         IntImm(DataType::Int(32), i), const_true(1)),
                                    nop));
      DataType t = v_arg.dtype();
      if (t.is_handle()) {
        std::ostringstream msg;
        msg << name_hint << ": Expect arg[" << i << "] to be pointer";
        seq_check.emplace_back(AssertStmt(tcode == kTVMOpaqueHandle || tcode == kTVMNDArrayHandle ||
                                              tcode == kTVMDLTensorHandle || tcode == kTVMNullptr,
                                          tvm::tir::StringImm(msg.str()), nop));
      } else if (t.is_int() || t.is_uint()) {
        std::ostringstream msg;
        msg << name_hint << ": Expect arg[" << i << "] to be int";
        seq_check.emplace_back(AssertStmt(tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop));
      } else {
        // Only handle, integer and float parameters are supported.
        CHECK(t.is_float());
        std::ostringstream msg;
        msg << name_hint << ": Expect arg[" << i << "] to be float";
        seq_check.emplace_back(AssertStmt(tcode == kDLFloat, tvm::tir::StringImm(msg.str()), nop));
      }
    } else {
      // Unpacked parameters are passed through as positional arguments.
      args.push_back(v_arg);
    }
  }

  // allow return value if the function is packed.
  if (num_packed_args != 0) {
    args.push_back(v_out_ret_value);
    args.push_back(v_out_ret_tcode);
    args.push_back(v_resource_handle);
  }

  // Sanity check: the packed convention contributes exactly six parameters.
  size_t expected_nargs = num_unpacked_args + (num_packed_args != 0 ? 6 : 0);
  CHECK_EQ(args.size(), expected_nargs);

  // Arg definitions are defined before buffer binding to avoid the use before
  // def errors.
  //
  // For example, for auto broadcasting, checks are required to guarantee that
  // either 0 or the original stride will be correctly used. Checks here have
  // to use the args that may have no let binding yet. Therefore, hoisting let
  // binding for args before buffer declaration is needed.
  for (const auto& kv : var_def) {
    binder.Bind(kv.second, kv.first, kv.first->name_hint, true);
  }

  for (const auto& kv : buffer_def) {
    binder.BindDLTensor(kv.second, device_type, device_id, kv.first, kv.first->name_hint);
  }

  if (num_unpacked_args == 0) {
    // NOTE(review): func_ptr is still used after this reassignment; this presumably
    // relies on WithAttr reusing the uniquely-owned node in place -- confirm.
    func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc));
  }

  Stmt body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope,
                       StringImm(name_hint + "_compute_"), func_ptr->body);
  // Set device context (only when the binder produced a definition for device_id).
  if (vmap.count(device_id.get())) {
    PrimExpr node = StringImm("default");
    seq_check.push_back(AttrStmt(node, attr::device_context_id, device_id, nop));
    seq_check.push_back(AttrStmt(node, attr::device_context_type, device_type, nop));

    if (runtime::DeviceAPI::NeedSetDeviceContext(target_device_type)) {
      Stmt set_device =
          Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(),
                        {StringImm(runtime::symbol::tvm_set_device), device_type, device_id}));
      body = SeqStmt({set_device, body});
    }
  }
  // Nest order, outermost first: argument loads, the binder's init nest,
  // type-code checks and device attributes, the binder's assertions.
  func_ptr->body = MergeNest({seq_init, binder.init_nest(), seq_check, binder.asserts()}, body);
  func_ptr->params = args;

  // Every variable used by the new body must be reachable from the new parameters.
  Array<Var> undefined = UndefinedVars(func_ptr->body, func_ptr->params);
  if (undefined.size() != 0) {
    std::ostringstream os;
    for (const Var& v : undefined) {
      os << " \'" << v->name_hint << "\' ";
    }
    os << " is not bound to any variables";
    LOG(FATAL) << "Not all Vars are passed in api_args: " << os.str();
  }

  func_ptr->buffer_map = Map<Var, Buffer>();
  func_ptr->checked_type_ = func_ptr->func_type_annotation();
  func_ptr->ret_type = PrimType(DataType::Int(32));

  // func is a named rvalue-reference parameter, hence the explicit std::move.
  return std::move(func);
}
namespace transform {

/*!
 * \brief Create a module pass that applies tir::MakePackedAPI to PrimFuncs.
 *
 * Only functions that are PrimFuncs and whose calling convention attribute is
 * CallingConv::kDefault (also the value assumed when the attribute is absent)
 * are rewritten; every other function in the module is left as is.
 *
 * \param num_unpacked_args Forwarded to tir::MakePackedAPI for each function.
 * \return The module pass, named "tir.MakePackedAPI".
 */
Pass MakePackedAPI(int num_unpacked_args) {
  auto pass_func = [num_unpacked_args](IRModule m, PassContext ctx) {
    IRModuleNode* module_node = m.CopyOnWrite();

    // Collect the rewritten functions first and write them back afterwards,
    // so the function table is not modified while it is being iterated.
    std::vector<std::pair<GlobalVar, PrimFunc> > rewritten;
    for (const auto& entry : module_node->functions) {
      const auto* prim_func_node = entry.second.as<PrimFuncNode>();
      if (prim_func_node == nullptr) {
        continue;
      }
      PrimFunc prim_func = GetRef<PrimFunc>(prim_func_node);
      if (prim_func->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) ==
          CallingConv::kDefault) {
        rewritten.emplace_back(entry.first,
                               MakePackedAPI(std::move(prim_func), num_unpacked_args));
      }
    }

    for (const auto& item : rewritten) {
      module_node->AddUnchecked(item.first, item.second);
    }
    return m;
  };

  return tvm::transform::CreateModulePass(pass_func, 0, "tir.MakePackedAPI", {});
}

TVM_REGISTER_GLOBAL("tir.transform.MakePackedAPI").set_body_typed(MakePackedAPI);

}  // namespace transform
} // namespace tir
} // namespace tvm