blob: c8df122d40b5c1e8b6d685b1cecf06f0384bf536 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Lower TVM related builtin intrinsics such as packed call.
* \file tir/transforms/lower_tvm_buildin.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_set>
#include "ir_util.h"
namespace tvm {
namespace tir {
/*!
 * \brief Make an Int(32) constant expression from a size_t index.
 * \param index The index value; must fit into a signed 32-bit integer.
 * \return A 32-bit integer immediate holding the index.
 */
inline PrimExpr ConstInt32(size_t index) {
  CHECK_LE(index, std::numeric_limits<int>::max());
  const int value = static_cast<int>(index);
  return make_const(DataType::Int(32), value);
}
/*!
 * \brief Build a tvm_stack_alloca intrinsic call that reserves scratch space.
 * \param type The kind of stack storage ("shape", "array", "arg_value", "arg_tcode").
 * \param num Number of entries to reserve.
 * \return Handle-typed call expression producing the stack buffer.
 */
inline PrimExpr StackAlloca(std::string type, size_t num) {
  return Call(DataType::Handle(), builtin::tvm_stack_alloca(),
              {StringImm(type), ConstInt32(num)});
}
// Calculate the statistics of packed function.
// These information are needed during codegen.
class BuiltinLower : public StmtExprMutator {
 public:
  /*!
   * \brief Entry point: lower all builtin intrinsics in stmt.
   *
   * Visits the body first (recording the peak usage of each scratch stack),
   * then wraps the result with LetStmt bindings that allocate the shared
   * shape / array / value / tcode stacks sized to those maxima.
   */
  Stmt Build(Stmt stmt) {
    stack_shape_ = Var("stack_shape", DataType::Handle());
    stack_array_ = Var("stack_array", DataType::Handle());
    stack_value_ = Var("stack_value", DataType::Handle());
    stack_tcode_ = Var("stack_tcode", DataType::Handle());
    stmt = this->VisitStmt(stmt);
    // create a shape var if any shape is made (including scalar shapes,
    // which is why the sentinel is -1 rather than 0: a scalar shape bumps
    // the stack from -1 to 0 without storing any entries)
    if (max_shape_stack_ != -1) {
      stmt = LetStmt(stack_shape_, StackAlloca("shape", max_shape_stack_), stmt);
    }
    if (max_array_stack_ != 0) {
      stmt = LetStmt(stack_array_, StackAlloca("array", max_array_stack_), stmt);
    }
    if (max_arg_stack_ != 0) {
      // value and tcode stacks always travel together: one slot of each
      // per packed-call argument.
      stmt = LetStmt(stack_value_, StackAlloca("arg_value", max_arg_stack_), stmt);
      stmt = LetStmt(stack_tcode_, StackAlloca("arg_tcode", max_arg_stack_), stmt);
    }
    return stmt;
  }
  /*!
   * \brief Visit a statement and emit any preparation stores collected
   *        while rewriting the expressions inside it, ahead of the statement.
   */
  Stmt VisitStmt(const Stmt& s) final {
    auto stmt = StmtExprMutator::VisitStmt(s);
    // The running stacks must be fully unwound at every statement boundary:
    // MakeCallPacked/MakeCallTracePacked restore them after consuming args.
    CHECK_EQ(run_shape_stack_, -1);
    CHECK_EQ(run_array_stack_, 0);
    if (prep_seq_.size() != 0) {
      // Prepend the buffered stack/struct stores so they execute before
      // the lowered call that reads them.
      Stmt ret = SeqStmt::Flatten(prep_seq_, stmt);
      prep_seq_.clear();
      return ret;
    } else {
      return stmt;
    }
  }
  /*!
   * \brief Lower Allocate into TVMBackendAllocWorkspace/FreeWorkspace calls.
   *
   * Small constant-size CPU allocations are left untouched so codegen can
   * keep them as cheap stack allocas.
   */
  Stmt VisitStmt_(const AllocateNode* op) {
    // Lower allocate to device allocate when needed.
    Stmt stmt = StmtExprMutator::VisitStmt_(op);
    op = stmt.as<AllocateNode>();
    // Get constant allocation bound.
    int64_t nbytes = GetVectorBytes(op->dtype);
    if (device_type_.defined()) {
      if (const auto* dev_type = device_type_.as<IntImmNode>()) {
        if (dev_type->value == kDLCPU) {
          int32_t constant_size = op->constant_allocation_size();
          // Keep small, statically-sized CPU buffers as-is (stack alloca).
          if (constant_size > 0 && constant_size * nbytes < runtime::kMaxStackAlloca) {
            return stmt;
          }
        }
      }
    }
    // total_bytes = nbytes * prod(extents).
    // NOTE(review): assumes op->extents is non-empty — confirm this invariant
    // holds for every Allocate reaching this pass.
    PrimExpr total_bytes = make_const(op->extents[0].dtype(), nbytes);
    for (size_t i = 0; i < op->extents.size(); ++i) {
      total_bytes = total_bytes * op->extents[i];
    }
    CHECK(device_type_.defined()) << "Unknown device type in current IR";
    CHECK(device_id_.defined()) << "Unknown device id in current IR";
    Stmt throw_last_error = Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {}));
    // Raise the backend's last error if the allocator returned a null pointer,
    // before the original body ever dereferences the buffer.
    Stmt body = SeqStmt({IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {op->buffer_var}),
                                    throw_last_error),
                         op->body});
    // Bind buffer_var to the workspace pointer returned by the backend.
    Stmt alloca = LetStmt(
        op->buffer_var,
        Call(op->buffer_var.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"),
             {cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_),
              cast(DataType::UInt(64), total_bytes), IntImm(DataType::Int(32), op->dtype.code()),
              IntImm(DataType::Int(32), op->dtype.bits())}),
        body);
    // Free the workspace after the body; a non-zero return code is an error.
    PrimExpr free_op = Call(DataType::Int(32), Op::Get("tir.TVMBackendFreeWorkspace"),
                            {cast(DataType::Int(32), device_type_),
                             cast(DataType::Int(32), device_id_), op->buffer_var});
    Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error);
    body = SeqStmt({alloca, free_stmt});
    // Record the alignment contract the backend allocator guarantees.
    body = AttrStmt(op->buffer_var, attr::storage_alignment,
                    make_const(DataType::Int(32), runtime::kTempAllocaAlignment), body);
    return body;
  }
  /*!
   * \brief Capture device context attributes (id/type) and strip them,
   *        making the values available to Allocate/array lowering below.
   */
  Stmt VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == attr::device_context_id) {
      CHECK(!device_id_.defined());
      device_id_ = op->value;
      return this->VisitStmt(op->body);
    } else if (op->attr_key == attr::device_context_type) {
      CHECK(!device_type_.defined());
      device_type_ = op->value;
      return this->VisitStmt(op->body);
    } else {
      return StmtExprMutator::VisitStmt_(op);
    }
  }
  // Dispatch each builtin intrinsic call to its lowering routine.
  PrimExpr VisitExpr_(const CallNode* op) final {
    if (op->op.same_as(builtin::tvm_call_packed())) {
      return MakeCallPacked(op);
    } else if (op->op.same_as(builtin::tvm_call_trace_packed())) {
      return MakeCallTracePacked(op);
    } else if (op->op.same_as(builtin::tvm_stack_make_shape())) {
      return MakeShape(op);
    } else if (op->op.same_as(builtin::tvm_stack_make_array())) {
      return MakeArray(op);
    } else if (op->op.same_as(builtin::tvm_context_id())) {
      // The context id is always lowered to constant zero here.
      return make_zero(op->dtype);
    } else {
      return StmtExprMutator::VisitExpr_(op);
    }
  }
  // call shape
  /*!
   * \brief Lower tvm_stack_make_shape: store each dimension as int64 into
   *        the shared shape stack and return the address of the first slot.
   */
  PrimExpr MakeShape(const CallNode* op) {
    // if args.size() == 0, it represents a scalar shape ()
    if (run_shape_stack_ == -1) {
      run_shape_stack_ = 0;
    }
    int64_t stack_begin = run_shape_stack_;
    run_shape_stack_ += op->args.size();
    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
    op = expr.as<CallNode>();
    // no need to perform any store for a scalar shape
    for (size_t i = 0; i < op->args.size(); ++i) {
      prep_seq_.emplace_back(Store(stack_shape_, cast(DataType::Int(64), op->args[i]),
                                   ConstInt32(stack_begin + i), const_true(1)));
    }
    return AddressOffset(stack_shape_, DataType::Int(64), stack_begin);
  }
  // make array
  /*!
   * \brief Lower tvm_stack_make_array: populate one DLTensor slot on the
   *        array stack from the call's args
   *        (data, shape, strides, ndim, dtype-carrier, byte_offset)
   *        and return the address of that slot.
   */
  PrimExpr MakeArray(const CallNode* op) {
    size_t idx = run_array_stack_;
    run_array_stack_ += 1;
    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
    op = expr.as<CallNode>();
    prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrData, op->args[0]));
    prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrShape, op->args[1]));
    PrimExpr strides = op->args[2];
    if (!strides.defined() || is_zero(strides)) {
      // A zero/undefined strides argument means "compact": store NULL.
      strides = make_zero(DataType::Handle());
    }
    prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrStrides, strides));
    prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrNDim, op->args[3]));
    // args[4] only carries the dtype; its value is never stored.
    DataType dtype = op->args[4].dtype();
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, builtin::kArrTypeCode,
                     make_const(DataType::UInt(8), static_cast<int>(dtype.code()))));
    prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrTypeBits,
                                        make_const(DataType::UInt(8), dtype.bits())));
    prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrTypeLanes,
                                        make_const(DataType::UInt(16), dtype.lanes())));
    // set byte offset: convert the element offset in args[5] to bytes.
    int data_bytes = GetVectorBytes(dtype);
    PrimExpr byte_offset = op->args[5];
    if (!is_zero(byte_offset)) {
      byte_offset = byte_offset * make_const(byte_offset.dtype(), data_bytes);
    }
    prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrByteOffset,
                                        cast(DataType::UInt(64), byte_offset)));
    CHECK(device_type_.defined()) << "Unknown device type in current IR";
    CHECK(device_id_.defined()) << "Unknown device id in current IR";
    prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrDeviceId,
                                        cast(DataType::Int(32), device_id_)));
    prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrDeviceType,
                                        cast(DataType::Int(32), device_type_)));
    return TVMStructGet(DataType::Handle(), stack_array_, idx, builtin::kArrAddr);
  }
  // call packed.
  /*!
   * \brief Lower tvm_call_packed: spill args[1..] into the value/tcode
   *        stacks and emit tvm_call_packed_lowered referencing that span.
   *        args[0] is the callee name and is passed through unchanged.
   */
  PrimExpr MakeCallPacked(const CallNode* op) {
    // Snapshot running stacks so nested calls in the arguments can reuse
    // the space above us once this call's slots are filled.
    int64_t restore_shape_stack = run_shape_stack_;
    size_t restore_array_stack = run_array_stack_;
    size_t arg_stack_begin = run_arg_stack_;
    run_arg_stack_ += op->args.size();
    // Specially handle the buffer packed intrinsic
    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
    op = expr.as<CallNode>();
    for (size_t i = 1; i < op->args.size(); ++i) {
      PrimExpr stack_index = ConstInt32(arg_stack_begin + i - 1);
      PrimExpr arg = op->args[i];
      DataType t = arg.dtype();
      DataType api_type = APIType(t);
      if (t != api_type) {
        // Widen to the ABI type expected by the packed-call convention.
        arg = Cast(api_type, arg);
      }
      prep_seq_.emplace_back(TVMStructSet(stack_value_, static_cast<int>(arg_stack_begin + i - 1),
                                          builtin::kTVMValueContent, arg));
      int arg_tcode = api_type.code();
      if (api_type.is_handle() && arg.as<StringImmNode>()) {
        arg_tcode = kTVMStr;
      }
      if (IsArrayHandle(arg)) arg_tcode = kTVMDLTensorHandle;
      prep_seq_.emplace_back(
          Store(stack_tcode_, ConstInt32(arg_tcode), stack_index, const_true(1)));
    }
    // UPDATE stack value: record peaks, then unwind to the snapshots so
    // sibling calls at this statement level reuse the same slots.
    max_arg_stack_ = std::max(run_arg_stack_, max_arg_stack_);
    max_shape_stack_ = std::max(run_shape_stack_, max_shape_stack_);
    max_array_stack_ = std::max(run_array_stack_, max_array_stack_);
    run_shape_stack_ = restore_shape_stack;
    run_array_stack_ = restore_array_stack;
    run_arg_stack_ = arg_stack_begin;
    // Lowered signature: (name, value stack, tcode stack, begin, end).
    Array<PrimExpr> packed_args = {op->args[0], stack_value_, stack_tcode_,
                                   ConstInt32(arg_stack_begin),
                                   ConstInt32(arg_stack_begin + op->args.size() - 1)};
    return Call(DataType::Int(32), builtin::tvm_call_packed_lowered(), packed_args);
  }
  /*!
   * \brief Lower tvm_call_trace_packed, like MakeCallPacked but the last
   *        argument is the traced value and is forwarded to the lowered call;
   *        its stack slot is kept live so later calls do not clobber it.
   */
  PrimExpr MakeCallTracePacked(const CallNode* op) {
    int64_t restore_shape_stack = run_shape_stack_;
    size_t restore_array_stack = run_array_stack_;
    size_t arg_stack_begin = run_arg_stack_;
    run_arg_stack_ += op->args.size();
    size_t args_size = op->args.size();
    CHECK_GT(args_size, 0);
    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
    op = expr.as<CallNode>();
    for (size_t i = 1; i < op->args.size(); ++i) {
      PrimExpr stack_index = ConstInt32(arg_stack_begin + i - 1);
      PrimExpr arg = op->args[i];
      DataType t = arg.dtype();
      DataType api_type = APIType(t);
      if (t != api_type) {
        arg = Cast(api_type, arg);
      }
      prep_seq_.emplace_back(TVMStructSet(stack_value_, static_cast<int>(arg_stack_begin + i - 1),
                                          builtin::kTVMValueContent, arg));
      int arg_tcode = api_type.code();
      CHECK(!IsArrayHandle(arg)) << "Trace does not support Buffers";
      prep_seq_.emplace_back(
          Store(stack_tcode_, ConstInt32(arg_tcode), stack_index, const_true(1)));
    }
    // UPDATE stack value
    max_arg_stack_ = std::max(run_arg_stack_, max_arg_stack_);
    max_shape_stack_ = std::max(run_shape_stack_, max_shape_stack_);
    max_array_stack_ = std::max(run_array_stack_, max_array_stack_);
    run_shape_stack_ = restore_shape_stack;
    run_array_stack_ = restore_array_stack;
    // Update the top of the stack, so we can use more than one
    // packed function's arguments with the one stack.
    run_arg_stack_ = arg_stack_begin + args_size - 1;
    Array<PrimExpr> packed_args = {op->args[0], stack_value_, stack_tcode_,
                                   ConstInt32(arg_stack_begin),
                                   ConstInt32(arg_stack_begin + op->args.size() - 1),
                                   // Pass traced value.
                                   op->args[args_size - 1]};
    return Call(op->dtype, builtin::tvm_call_trace_packed_lowered(), packed_args);
  }

 private:
  // Returns true when arg is the address slot of a DLTensor on the array
  // stack, i.e. a tvm_struct_get(..., kArrAddr) produced by MakeArray.
  bool IsArrayHandle(const PrimExpr& arg) {
    // specially set array handle.
    if (const CallNode* buf = arg.as<CallNode>()) {
      if (buf->op.same_as(builtin::tvm_struct_get()) &&
          buf->args[2].as<IntImmNode>()->value == builtin::kArrAddr) {
        return true;
      }
    }
    return false;
  }
  // The preparation sequence to be emitted before the current statement.
  std::vector<Stmt> prep_seq_;
  // Device context captured from AttrStmt; may be undefined if absent.
  PrimExpr device_type_;
  PrimExpr device_id_;
  // Var handle for each stack.
  Var stack_shape_;
  Var stack_array_;
  Var stack_tcode_;
  Var stack_value_;
  // The running statistics (current top of each stack while visiting).
  // -1 means "no shape made yet"; 0 is a scalar shape with no entries.
  int64_t run_shape_stack_{-1};
  uint64_t run_array_stack_{0};
  uint64_t run_arg_stack_{0};
  // statistics of stacks (peak usage, used to size the final allocations)
  int64_t max_shape_stack_{-1};
  uint64_t max_array_stack_{0};
  uint64_t max_arg_stack_{0};
};
namespace transform {

/*!
 * \brief Create the TIR pass that lowers TVM builtin intrinsics
 *        (packed calls, stack make-shape/array, device workspace allocation).
 * \return The pass.
 */
Pass LowerTVMBuiltin() {
  auto pass_func = [](PrimFunc func, IRModule mod, PassContext pass_ctx) {
    auto* fptr = func.CopyOnWrite();
    fptr->body = BuiltinLower().Build(fptr->body);
    return func;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.LowerTVMBuiltin", {});
}

TVM_REGISTER_GLOBAL("tir.transform.LowerTVMBuiltin").set_body_typed(LowerTVMBuiltin);

}  // namespace transform
} // namespace tir
} // namespace tvm