blob: e8a1b564a43be4662e0432915ddcaae917e8eb17 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file make_packed_api.cc Lower PrimFunc to use the packed function API.
*/
#include <tvm/ffi/extra/module.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/module.h>
#include <tvm/target/target.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_set>
#include <utility>
#include <vector>
#include "arg_binder.h"
#include "ir_utils.h"
namespace tvm {
namespace tir {
namespace {
class ReturnRewriter : public StmtMutator {
public:
explicit ReturnRewriter(Var ret_var) : ret_var_(ret_var) {}
Stmt VisitStmt_(const ForNode* node) override {
if (node->kind == ForKind::kParallel) in_parallel_ += 1;
Stmt ret = StmtMutator::VisitStmt_(node);
if (node->kind == ForKind::kParallel) in_parallel_ -= 1;
return ret;
}
Stmt VisitStmt_(const EvaluateNode* node) override {
Stmt ret = StmtMutator::VisitStmt_(node);
const EvaluateNode* eval = ret.as<EvaluateNode>();
ICHECK(eval);
if (const CallNode* call = eval->value.as<CallNode>()) {
if (call->op.same_as(builtin::ret())) {
ICHECK_EQ(in_parallel_, 0) << "tir.ret cannot be used in parallel scope.";
ICHECK_EQ(call->args.size(), 1) << "tir.ret expect a single argument.";
ret = WriteToOut(call->args[0]);
}
}
return ret;
}
private:
struct ConvertedInfo {
int type_index{-1};
PrimExpr expr;
};
ConvertedInfo ConvertForFFI(PrimExpr val) {
ConvertedInfo info;
// convert val's data type to FFI data type, return type code
DataType dtype = val.dtype();
if (dtype.is_bool()) {
info.type_index = ffi::TypeIndex::kTVMFFIBool;
info.expr = Cast(DataType::Int(64), val);
} else if (dtype.is_int() || dtype.is_uint()) {
info.type_index = ffi::TypeIndex::kTVMFFIInt;
info.expr = Cast(DataType::Int(64), val);
} else if (dtype.is_float()) {
info.type_index = ffi::TypeIndex::kTVMFFIFloat;
info.expr = Cast(DataType::Float(64), val);
} else if (dtype.is_void()) {
info.type_index = ffi::TypeIndex::kTVMFFINone;
info.expr = val;
} else {
LOG(FATAL) << "data type " << dtype << " not supported yet";
}
return info;
}
Stmt WriteToOut(PrimExpr val) {
auto info = ConvertForFFI(val);
Stmt store_tindex =
tir::Evaluate(tir::Call(DataType::Int(32), tir::builtin::tvm_struct_set(),
{ret_var_, IntImm(DataType::Int(32), 0),
IntImm(DataType::Int(32), tir::builtin::kTVMFFIAnyTypeIndex),
IntImm(DataType::Int(32), info.type_index)}));
Stmt store_zero_padding =
tir::Evaluate(tir::Call(DataType::Int(32), tir::builtin::tvm_struct_set(),
{ret_var_, IntImm(DataType::Int(32), 0),
IntImm(DataType::Int(32), tir::builtin::kTVMFFIAnyZeroPadding),
IntImm(DataType::Int(32), 0)}));
Stmt store_val = tir::Evaluate(
tir::Call(DataType::Int(32), tir::builtin::tvm_struct_set(),
{ret_var_, IntImm(DataType::Int(32), 0),
IntImm(DataType::Int(32), tir::builtin::kTVMFFIAnyUnionValue), info.expr}));
Stmt ret_zero = Evaluate(tvm::ret(0));
return SeqStmt({store_tindex, store_zero_padding, store_val, ret_zero});
}
Var ret_var_;
int in_parallel_{0};
};
/*!
 * \brief Redirect calls to internally-defined GlobalVars through their
 *        exported packed-function symbols (tvm_call_cpacked).
 */
class SubroutineCallRewriter : public StmtExprMutator {
 public:
  /*!
   * \brief Rewrite all subroutine calls inside \p stmt.
   * \param packed_func_methods Map from callee GlobalVar to its exported symbol.
   * \param stmt The statement to rewrite.
   * \return The rewritten statement, or std::nullopt when nothing changed.
   */
  static ffi::Optional<Stmt> Apply(const ffi::Map<GlobalVar, ffi::String>& packed_func_methods,
                                   Stmt stmt) {
    SubroutineCallRewriter rewriter(packed_func_methods);
    Stmt rewritten = rewriter.VisitStmt(std::move(stmt));
    if (!rewriter.made_change_) {
      return std::nullopt;
    }
    return rewritten;
  }

 private:
  explicit SubroutineCallRewriter(const ffi::Map<GlobalVar, ffi::String>& packed_func_methods)
      : packed_func_methods(packed_func_methods) {}

  PrimExpr VisitExpr_(const CallNode* op) override {
    Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
    const auto* gvar_node = call->op.as<GlobalVarNode>();
    if (gvar_node == nullptr) {
      return call;
    }
    auto symbol = packed_func_methods.Get(ffi::GetRef<GlobalVar>(gvar_node));
    if (!symbol.has_value()) {
      return call;
    }
    // cpacked convention: (symbol, original args..., resource handle).
    ffi::Array<PrimExpr> cpacked_args;
    cpacked_args.push_back(tir::StringImm(symbol.value()));
    for (const PrimExpr& arg : call->args) {
      cpacked_args.push_back(arg);
    }
    // push an empty handle to be compatible with current cpacked convention
    cpacked_args.push_back(tir::make_zero(DataType::Handle()));
    made_change_ = true;
    return tir::Call(call->dtype, tir::builtin::tvm_call_cpacked(), cpacked_args);
  }

  const ffi::Map<GlobalVar, ffi::String>& packed_func_methods;
  bool made_change_{false};
};
} // namespace
/*! \brief Build an assertion statement that `lhs == rhs`, failing with \p msg. */
inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
  PrimExpr condition = (lhs == rhs);
  return AssertStmt(condition, tvm::tir::StringImm(msg), Evaluate(0));
}
/*! \brief Build an assertion statement that \p ptr is non-null, failing with \p msg. */
inline Stmt MakeAssertNotNull(PrimExpr ptr, std::string msg) {
  PrimExpr is_null = Call(DataType::Bool(), builtin::isnullptr(), {ptr});
  return AssertStmt(!is_null, tvm::tir::StringImm(msg), Evaluate(0));
}
/*! \brief Return the global_symbol of the function, if it should be updated
*
* \param func The function to be inspected
*
* \returns The global_symbol to be used for the function at call
* sites, or std::nullopt if the function is to remain unchanged.
*/
ffi::Optional<ffi::String> RequiresPackedAPI(const PrimFunc& func) {
  // A function whose calling convention has already been fixed was lowered
  // previously and must not be processed a second time.
  if (auto calling_conv = func->GetAttr<Integer>(tvm::attr::kCallingConv)) {
    if (CallingConv(calling_conv.value()->value) != CallingConv::kDefault) {
      return std::nullopt;
    }
  }
  // Only externally visible functions (those carrying a global symbol) need
  // the ffi::Function API; internal calls stay on the default convention.
  // Returns nullopt automatically when the attribute is absent.
  return func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
}
/*!
 * \brief Lower a PrimFunc to the packed-function FFI calling convention.
 *
 * Rewrites the signature to (self_handle, TVMFFIAny* args, int num_args,
 * TVMFFIAny* result), generates argument unpacking/validation code, binds
 * DLTensor buffers, rewrites tir.ret into result-slot stores, and returns
 * int32 0 on success.
 *
 * \param func The function to lower. Returned unchanged when it does not
 *             require the packed API or has no host target.
 * \return The lowered function.
 */
PrimFunc MakePackedAPI(PrimFunc func) {
  // Skip internal functions and those with an explicit calling convention.
  auto global_symbol = RequiresPackedAPI(func);
  if (!global_symbol.has_value()) {
    return func;
  }
  std::string name_hint = global_symbol.value();

  // The target attribute is mandatory: it determines the device type checked
  // against tensor arguments and the host that runs the packed wrapper.
  Target target = [&]() {
    auto opt = func->GetAttr<Target>(tvm::attr::kTarget);
    ICHECK(opt) << "MakePackedAPI required the function to be annotated with tvm::attr::kTarget ("
                << tvm::attr::kTarget << "), but the function only has attributes " << func->attrs;
    return opt.value();
  }();
  int target_device_type = target->GetTargetDeviceType();

  // A function without a host target has already been lowered.
  Target target_host;
  if (auto opt = target->GetHost()) {
    target_host = opt.value();
  } else {
    return func;
  }

  auto* func_ptr = func.CopyOnWrite();
  const Stmt nop = Evaluate(0);
  int num_args = static_cast<int>(func_ptr->params.size());

  // Data field definitions
  // The packed fields replacing the original parameter list.
  Var v_self_handle("self_handle", DataType::Handle());
  Var v_packed_args("args", DataType::Handle());
  Var v_num_packed_args("num_args", DataType::Int(32));
  Var v_result("result", PointerType(PrimType(DataType::Void())));
  // The device context
  Var device_id("dev_id");
  Integer device_type(target_device_type);
  // seq_init gives sequence of initialization
  // seq_check gives sequence of later checks after init
  std::vector<Stmt> seq_init, seq_check, arg_buffer_declarations;
  std::unordered_map<const VarNode*, PrimExpr> vmap;
  ArgBinder binder(&vmap);

  // ---------------------------
  // local function definitions
  // Load the i-th argument's union value from the packed args array as
  // type `arg_type` (loaded at the 64-bit API width, then narrowed).
  auto f_load_arg_value = [&](DataType arg_type, int i) {
    ffi::Array<PrimExpr> call_args{v_packed_args, IntImm(DataType::Int(32), i),
                                   IntImm(DataType::Int(32), builtin::kTVMFFIAnyUnionValue)};
    // load 64 bit version
    DataType api_type = APIType(arg_type);
    PrimExpr res = Call(api_type, builtin::tvm_struct_get(), call_args);
    // cast to the target version.
    if (api_type != arg_type) {
      res = Cast(arg_type, res);
    }
    return res;
  };

  // Assert correct type codes for each argument. This must be done
  // *before* any initialization steps produced by
  // `binder.BindDLTensor()`. The validity of those initialization
  // steps depends on the correct types being present, and must not
  // occur before the type codes are actually checked.
  seq_init.push_back(MakeAssertEQ(v_num_packed_args, num_args, [&]() -> std::string {
    std::ostringstream error_message;
    error_message << name_hint << ": num_args should be " << num_args;
    return error_message.str();
  }()));

  if (num_args > 0) {
    seq_init.push_back(MakeAssertNotNull(v_packed_args, name_hint + ": args pointer is NULL"));
  }

  // Need to delay binding of the buffers, in case some arguments also
  // appear in the buffer.
  std::vector<std::pair<PrimExpr, Var>> var_def;
  std::vector<std::pair<Var, Buffer>> buffer_def;

  for (int i = 0; i < static_cast<int>(func_ptr->params.size()); ++i) {
    Var param = func_ptr->params[i];
    PrimExpr arg_value;

    // Bind the argument's FFI type index to a let-var so the checks and
    // selects below can reference it.
    Var type_index(param->name_hint + ".type_index", DataType::Int(32));
    seq_init.push_back(LetStmt(type_index,
                               tir::Call(DataType::Int(32), builtin::tvm_struct_get(),
                                         {v_packed_args, IntImm(DataType::Int(32), i),
                                          IntImm(DataType::Int(32), builtin::kTVMFFIAnyTypeIndex)}),
                               nop));
    DataType dtype = param.dtype();
    if (dtype.is_handle()) {
      std::ostringstream msg;
      msg << name_hint << ": Expect arg[" << i << "] to be pointer";
      seq_init.emplace_back(AssertStmt(type_index == ffi::TypeIndex::kTVMFFINone ||
                                           type_index == ffi::TypeIndex::kTVMFFIOpaquePtr ||
                                           type_index == ffi::TypeIndex::kTVMFFIDLTensorPtr ||
                                           type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin,
                                       tvm::tir::StringImm(msg.str()), nop));
      // If the argument is a managed Tensor object, the DLTensor payload sits
      // immediately after the TVMFFIObject header, so offset the handle by
      // sizeof(TVMFFIObject) (statically checked to be 24 bytes below) to
      // ensure T.handle always surfaces as a DLTensor*.
      const int64_t object_cell_offset = sizeof(TVMFFIObject);
      static_assert(object_cell_offset == 24);
      arg_value = f_load_arg_value(param.dtype(), i);
      PrimExpr handle_from_tensor =
          Call(DataType::Handle(), tir::builtin::handle_add_byte_offset(),
               {arg_value, IntImm(DataType::Int(32), object_cell_offset)});
      arg_value =
          Select(type_index == ffi::TypeIndex::kTVMFFITensor, handle_from_tensor, arg_value);
    } else if (dtype.is_bool()) {
      std::ostringstream msg;
      msg << name_hint << ": Expect arg[" << i << "] to be boolean";
      // Int is also accepted and narrowed to bool.
      seq_init.emplace_back(AssertStmt(
          type_index == ffi::TypeIndex::kTVMFFIBool || type_index == ffi::TypeIndex::kTVMFFIInt,
          tvm::tir::StringImm(msg.str()), nop));
      arg_value = Cast(DataType::Bool(), f_load_arg_value(DataType::Int(64), i));
    } else if (dtype.is_int() || dtype.is_uint()) {
      std::ostringstream msg;
      msg << name_hint << ": Expect arg[" << i << "] to be int";
      // Bool is also accepted; its union value is already an integer.
      seq_init.emplace_back(AssertStmt(
          type_index == ffi::TypeIndex::kTVMFFIInt || type_index == ffi::TypeIndex::kTVMFFIBool,
          tvm::tir::StringImm(msg.str()), nop));
      arg_value = f_load_arg_value(param.dtype(), i);
    } else {
      ICHECK(dtype.is_float());
      std::ostringstream msg;
      msg << name_hint << ": Expect arg[" << i << "] to be float";
      seq_init.emplace_back(AssertStmt(type_index == ffi::TypeIndex::kTVMFFIFloat ||
                                           type_index == ffi::TypeIndex::kTVMFFIInt ||
                                           type_index == ffi::TypeIndex::kTVMFFIBool,
                                       tvm::tir::StringImm(msg.str()), nop));
      // use select so we can also handle int conversion to bool
      arg_value = tir::Select(
          type_index == ffi::TypeIndex::kTVMFFIFloat,
          /* true_value = */ f_load_arg_value(param.dtype(), i),
          /* false_value = */ Cast(param.dtype(), f_load_arg_value(DataType::Int(64), i)));
    }
    var_def.emplace_back(arg_value, param);
    if (func_ptr->buffer_map.count(param)) {
      // buffer binding now depends on type index
      // if the index is Tensor handle, we need to offset to get the DLTensor*
      buffer_def.emplace_back(param, func_ptr->buffer_map[param]);
    }
  }

  // signature: (void* handle, TVMFFIAny* packed_args, int num_args, TVMFFIAny* v_result)
  ffi::Array<Var> args{v_self_handle, v_packed_args, v_num_packed_args, v_result};

  // Arg definitions are defined before buffer binding to avoid the use before
  // def errors.
  //
  // For example, for auto broadcasting, checks are required to guarantee that
  // either 0 or the original stride will be correctly used. Checks here have
  // to use the args that may have no let binding yet. Therefore, hoisting let
  // binding for args before buffer declaration is needed.
  for (const auto& [expr, param] : var_def) {
    binder.Bind(param, expr, name_hint + "." + param->name_hint, true);
  }

  for (const auto& [var, buffer] : buffer_def) {
    binder.BindDLTensor(buffer, device_type, device_id, var, name_hint + "." + var->name_hint);
    arg_buffer_declarations.push_back(DeclBuffer(buffer, nop));
  }

  // Mark as already lowered, record the host target, and prefix the exported
  // symbol. NOTE(review): func_ptr is assumed to stay valid across this call
  // since the node is uniquely owned after CopyOnWrite above — confirm if
  // WithAttrs semantics change.
  func = WithAttrs(
      std::move(func),
      {{tvm::attr::kCallingConv, static_cast<int>(CallingConv::kCPackedFunc)},
       {tvm::attr::kTarget, target_host},
       {tvm::attr::kGlobalSymbol, ffi::symbol::tvm_ffi_symbol_prefix + global_symbol.value()}});

  // Rewrite tir.ret into writes to the result slot, then wrap the body in a
  // compute_scope attribute.
  Stmt body = ReturnRewriter(v_result)(func_ptr->body);
  body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope,
                  StringImm(name_hint + "_compute_"), body);

  // Set device context — only needed when the body actually uses device_id.
  if (vmap.count(device_id.get())) {
    ffi::Any node = ffi::String("default");
    seq_check.push_back(AttrStmt(node, attr::device_id, device_id, nop));
    seq_check.push_back(AttrStmt(node, attr::device_type, device_type, nop));

    if (runtime::DeviceAPI::NeedSetDevice(target_device_type)) {
      Stmt set_device =
          Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(),
                        {StringImm(runtime::symbol::tvm_set_device), device_type, device_id}));
      body = SeqStmt({set_device, body});
    }
  }

  // Return error code of zero on success
  body = SeqStmt({body, Evaluate(ret(Integer(0)))});
  // Nesting order: arg checks, arg let-bindings, device setup, buffer
  // asserts, then buffer declarations around the compute body.
  body = MergeNest(
      {seq_init, binder.init_nest(), seq_check, binder.asserts(), arg_buffer_declarations}, body);
  func_ptr->body = body;
  func_ptr->params = args;

  // Every variable the body uses must now come from the packed arguments.
  ffi::Array<Var> undefined = UndefinedVars(func_ptr->body, func_ptr->params);
  ICHECK_EQ(undefined.size(), 0) << "In PrimFunc " << name_hint << " variables " << undefined
                                 << " are used, but are not passed in as API arguments";

  func_ptr->buffer_map = ffi::Map<Var, Buffer>();
  func_ptr->ret_type = PrimType(DataType::Int(32));

  // return the function.
  return func;
}
namespace transform {
/*!
 * \brief Module pass lowering every externally-exposed PrimFunc to the
 *        packed-function FFI calling convention.
 */
Pass MakePackedAPI() {
  auto pass_func = [](IRModule mod, PassContext ctx) {
    // First pass: collect the exported symbol of every PrimFunc that will
    // receive the packed API so call sites can be redirected to it.
    ffi::Map<GlobalVar, ffi::String> packed_func_methods;
    for (const auto& [gvar, base_func] : mod->functions) {
      if (auto prim_func = base_func.as<PrimFunc>()) {
        if (auto symbol = RequiresPackedAPI(prim_func.value())) {
          packed_func_methods.Set(gvar, symbol.value());
        }
      }
    }

    // Second pass: rewrite subroutine call sites and lower the functions
    // themselves, gathering modified functions into `updates`.
    IRModuleNode* mptr = mod.CopyOnWrite();
    IRModule updates;
    for (const auto& [gvar, base_func] : mptr->functions) {
      auto opt_prim = base_func.as<PrimFunc>();
      if (!opt_prim) {
        continue;
      }
      PrimFunc func = opt_prim.value();
      PrimFunc before = func;
      if (auto new_body = SubroutineCallRewriter::Apply(packed_func_methods, func->body)) {
        func.CopyOnWrite()->body = new_body.value();
      }
      func = MakePackedAPI(std::move(func));
      if (!func.same_as(before)) {
        updates->Add(gvar, func);
      }
    }

    if (updates->functions.size()) {
      mod.CopyOnWrite()->Update(updates);
    }
    return mod;
  };

  return tvm::transform::CreateModulePass(pass_func, 0, "tir.MakePackedAPI", {});
}
// Register the pass constructor with the global FFI registry so it can be
// invoked from the frontend as "tir.transform.MakePackedAPI".
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tir.transform.MakePackedAPI", []() { return MakePackedAPI(); });
}
} // namespace transform
} // namespace tir
} // namespace tvm