blob: 114bdbc80629eb23638282a7a68159eaafdefd2f [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file codegen_spirv.cc
* \brief Generate SPIRV block
*/
#include "codegen_spirv.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <string>
#include "../../runtime/pack_args.h"
#include "../../runtime/vulkan/vulkan_common.h"
#include "../../tir/transform/ir_utils.h"
namespace tvm {
namespace codegen {
// Construct a SPIR-V code generator for `target`. The target is only used to
// initialize spirv_support_; all per-function state is set up later by
// InitFuncState().
CodeGenSPIRV::CodeGenSPIRV(Target target) : spirv_support_(target) {}
// Lower the PrimFunc `f` into a SPIR-V shader whose kernel entry is `name`.
//
// Handle parameters become buffer arguments bound to consecutive binding
// slots of descriptor set 0. The remaining (scalar) parameters are passed
// either through push constants or, when they do not fit in
// kMaxPushConstantsBytes, through a uniform buffer placed in the next free
// binding slot (the kUseUBO bit is then set in the returned shader's flag).
//
// Raises an internal error if `f` lacks the no-alias attribute, if a handle
// parameter is not annotated as a pointer to a primitive type, or if the
// shared memory used by the body exceeds the target's limit.
runtime::SPIRVShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::string& name) {
  this->InitFuncState();
  TVM_FFI_ICHECK(f->HasNonzeroAttr(tir::attr::kNoAlias))
      << "SPIRV only takes restricted memory model";

  // Currently, all storage and uniform buffer arguments are passed as
  // a single descriptor set at index 0. If ever non-zero, must
  // ensure it is less than maxBoundDescriptorSets.
  const uint32_t descriptor_set = 0;
  uint32_t next_binding = 0;
  std::vector<Var> scalar_params;

  for (Var param : f->params) {
    DataType param_dtype = param.dtype();
    if (!param_dtype.is_handle()) {
      // Plain-old-data arguments are bound after all buffers are declared.
      scalar_params.push_back(param);
      continue;
    }

    auto* pointer_type = param->type_annotation.as<PointerTypeNode>();
    TVM_FFI_ICHECK(pointer_type)
        << "All handles passed to the Vulkan codegen must have a type_annotation as a "
           "PointerType, "
        << "and must point to a PrimType";
    auto* pointee_type = pointer_type->element_type.as<PrimTypeNode>();
    TVM_FFI_ICHECK(pointee_type)
        << "All handles passed to the Vulkan codegen must have a type_annotation as a "
           "PointerType, "
        << "and must point to a PrimType";

    DataType storage_dtype = pointee_type->dtype;
    if (storage_dtype == DataType::Bool()) {
      // Boolean tensors need a physically addressable buffer type. The
      // loaded value is cast back to bool by the buffer-load visitor.
      storage_dtype = boolean_storage_type_.with_lanes(storage_dtype.lanes());
    }

    spirv::Value buffer_arg = builder_->BufferArgument(builder_->GetSType(storage_dtype),
                                                       descriptor_set, next_binding++);
    builder_->SetName(buffer_arg, param->name_hint);
    storage_info_[param.get()].SetContentType(storage_dtype, param->name_hint);
    var_map_[param.get()] = buffer_arg;
  }

  spirv::Value kernel = builder_->NewFunction();
  builder_->StartFunction(kernel);

  runtime::SPIRVShader shader;
  if (!scalar_params.empty()) {
    std::vector<spirv::SType> scalar_types;
    for (size_t k = 0; k < scalar_params.size(); ++k) {
      scalar_types.push_back(builder_->GetSType(scalar_params[k].dtype()));
    }

    const bool fits_in_push_constants =
        scalar_params.size() * sizeof(runtime::ArgUnion64) <= runtime::vulkan::kMaxPushConstantsBytes;
    if (fits_in_push_constants) {
      spirv::Value block = builder_->DeclarePushConstant(scalar_types);
      for (size_t k = 0; k < scalar_params.size(); ++k) {
        var_map_[scalar_params[k].get()] =
            builder_->GetPushConstant(block, scalar_types[k], static_cast<uint32_t>(k));
      }
    } else {
      // More scalar arguments than push constants can hold: fall back to a UBO.
      shader.flag |= 1 << runtime::vulkan::ShaderMetaDataFlagMask::kUseUBO;
      spirv::Value block = builder_->DeclareUniformBuffer(scalar_types, descriptor_set, next_binding++);
      for (size_t k = 0; k < scalar_params.size(); ++k) {
        var_map_[scalar_params[k].get()] =
            builder_->GetUniform(block, scalar_types[k], static_cast<uint32_t>(k));
      }
    }
  }

  this->VisitStmt(f->body);

  // workgroup_size_ is filled in while visiting the body, so it is applied afterwards.
  builder_->SetLocalSize(kernel, workgroup_size_);
  builder_->MakeInst(spv::OpReturn);
  builder_->MakeInst(spv::OpFunctionEnd);
  builder_->CommitKernelFunction(kernel, name);

  TVM_FFI_ICHECK_LE(shared_memory_bytes_used_, spirv_support_.max_shared_memory_per_block)
      << "Vulkan shader " << name << " uses " << shared_memory_bytes_used_
      << " bytes of shared memory, "
      << "but target supports only " << spirv_support_.max_shared_memory_per_block << " bytes. "
      << "If the device supports this allocation, "
      << "please add -max_shared_memory_per_block=NBYTES to the target, "
      << "or query all device parameters by adding -from_device=0.";

  shader.data = builder_->Finalize();
  return shader;
}
// Reset the per-function code generation state: workgroup size back to
// 1x1x1, empty variable/storage/fragment tables, zero shared-memory usage,
// and a fresh analyzer and IR builder (with its header initialized).
void CodeGenSPIRV::InitFuncState() {
  for (int axis = 0; axis < 3; ++axis) {
    workgroup_size_[axis] = 1;
  }

  var_map_.clear();
  storage_info_.clear();

  analyzer_.reset(new arith::Analyzer());

  // The header must be written into the new builder before anything else.
  builder_.reset(new spirv::IRBuilder(spirv_support_));
  builder_->InitHeader();

  shared_memory_bytes_used_ = 0;
  fragment_info_.clear();
}
// Return the SPIR-V value holding the thread index bound to `iv`, cast to the
// dtype of `iv->var`.
//
// For thread scopes of rank 1 the index is the local invocation id along
// `dim_index`, and `extent` (which must be an integer constant) is recorded
// as the workgroup size along that dimension. For every other rank the index
// is the workgroup id along `dim_index`.
spirv::Value CodeGenSPIRV::GetThreadIndex(const IterVar& iv, const PrimExpr& extent) {
  runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag);
  spirv::Value v;
  if (ts.rank == 1) {
    // Validate the dimension and the extent before they are used, so that a
    // bad thread tag is reported with these messages rather than after
    // instructions for it have already been requested from the builder.
    TVM_FFI_ICHECK_GE(ts.dim_index, 0) << "vthread should have been optimized out by here";
    TVM_FFI_ICHECK_LT(ts.dim_index, 3);
    auto* sizeptr = extent.as<tir::IntImmNode>();
    TVM_FFI_ICHECK(sizeptr) << "SPIRV only allows constant thread group size "
                            << " get " << extent;
    v = builder_->GetLocalID(ts.dim_index);
    workgroup_size_[ts.dim_index] = static_cast<uint32_t>(sizeptr->value);
  } else {
    v = builder_->GetWorkgroupID(ts.dim_index);
  }
  return builder_->Cast(builder_->GetSType(iv->var.dtype()), v);
}
// Emit an OpControlBarrier for a tvm_storage_sync call.
//
// op->args[0] must be a string immediate naming the sync kind:
//  - "warp" on Vulkan >= 1.1: subgroup scope, subgroup + workgroup memory.
//  - "shared", or "warp" on older Vulkan: workgroup scope, workgroup memory.
// Any other kind raises InternalError. The returned value is always a
// default-constructed spirv::Value since the barrier produces no result.
spirv::Value CodeGenSPIRV::CreateStorageSync(const CallNode* op) {
  TVM_FFI_ICHECK_GE(op->args.size(), 1U)
      << "tvm_storage_sync expects the sync kind as its first argument";
  const auto* sync_node = op->args[0].as<StringImmNode>();
  TVM_FFI_ICHECK(sync_node) << "tvm_storage_sync expects a string sync kind, got " << op->args[0];
  const std::string& sync = sync_node->value;

  uint32_t vulkan_api_version = spirv_support_.vulkan_api_version;

  // Both supported branches assign sync_scope; the initializer only guards
  // against reading an indeterminate value.
  int64_t sync_scope = spv::ScopeWorkgroup;
  int64_t memory_semantics = spv::MemorySemanticsSequentiallyConsistentMask;

  if ((sync == "warp") && (vulkan_api_version >= VK_API_VERSION_1_1)) {
    // Synchronize control at the Subgroup level, but memory at the
    // Workgroup level.  This is because different invocations in a
    // subgroup may have each modified memory that exists at the
    // workgroup scope.  This should be changed if/when tir exposes
    // more information as to which memory access needs to be
    // synchronized.
    sync_scope = spv::ScopeSubgroup;
    memory_semantics |=
        spv::MemorySemanticsSubgroupMemoryMask | spv::MemorySemanticsWorkgroupMemoryMask;
  } else if ((sync == "shared") || (sync == "warp")) {
    sync_scope = spv::ScopeWorkgroup;
    memory_semantics |= spv::MemorySemanticsWorkgroupMemoryMask;
  } else {
    TVM_FFI_THROW(InternalError) << "Do not support sync " << sync;
  }

  auto type_int = builder_->GetSType(DataType::Int(32));
  // Operands: execution scope, memory scope, memory semantics.
  builder_->MakeInst(spv::OpControlBarrier, builder_->IntImm(type_int, sync_scope),
                     builder_->IntImm(type_int, sync_scope),
                     builder_->IntImm(type_int, memory_semantics));
  return spirv::Value();
}
// Look up the SPIR-V value previously bound to a variable. It is an internal
// error to reference a variable that has no entry in var_map_.
spirv::Value CodeGenSPIRV::VisitExpr_(const VarNode* op) {
  auto binding = var_map_.find(op);
  const bool is_bound = (binding != var_map_.end());
  TVM_FFI_ICHECK(is_bound) << "cannot find variable " << op->name_hint;
  return binding->second;
}
// Lower an integer immediate to a constant of the matching SPIR-V type.
spirv::Value CodeGenSPIRV::VisitExpr_(const IntImmNode* op) {
  spirv::SType imm_stype = builder_->GetSType(op->dtype);
  return builder_->IntImm(imm_stype, op->value);
}
// Lower a floating point immediate to a constant of the matching SPIR-V type.
spirv::Value CodeGenSPIRV::VisitExpr_(const FloatImmNode* op) {
  spirv::SType imm_stype = builder_->GetSType(op->dtype);
  return builder_->FloatImm(imm_stype, op->value);
}
// String immediates have no lowering in this code generator: visiting one
// always raises InternalError, so no value is ever returned.
spirv::Value CodeGenSPIRV::VisitExpr_(const StringImmNode* op) {
TVM_FFI_THROW(InternalError) << "StringImm is not supported in Device code";
}
// Lower a cast: the operand is lowered with MakeValue and converted by the
// builder to the SPIR-V type corresponding to op->dtype.
spirv::Value CodeGenSPIRV::VisitExpr_(const CastNode* op) {
return builder_->Cast(builder_->GetSType(op->dtype), MakeValue(op->value));
}
// Lower an addition.
spirv::Value CodeGenSPIRV::VisitExpr_(const AddNode* op) {
  // Lower the operands in separate statements: MakeValue emits instructions,
  // and the evaluation order of function arguments is unspecified, so nesting
  // both calls in one argument list would make the emitted instruction order
  // depend on the host compiler.
  spirv::Value a = MakeValue(op->a);
  spirv::Value b = MakeValue(op->b);
  return builder_->Add(a, b);
}
// Lower a subtraction.
spirv::Value CodeGenSPIRV::VisitExpr_(const SubNode* op) {
  // Lower the operands in separate statements: MakeValue emits instructions,
  // and the evaluation order of function arguments is unspecified, so nesting
  // both calls in one argument list would make the emitted instruction order
  // depend on the host compiler.
  spirv::Value a = MakeValue(op->a);
  spirv::Value b = MakeValue(op->b);
  return builder_->Sub(a, b);
}
// Lower a multiplication.
spirv::Value CodeGenSPIRV::VisitExpr_(const MulNode* op) {
  // Lower the operands in separate statements: MakeValue emits instructions,
  // and the evaluation order of function arguments is unspecified, so nesting
  // both calls in one argument list would make the emitted instruction order
  // depend on the host compiler.
  spirv::Value a = MakeValue(op->a);
  spirv::Value b = MakeValue(op->b);
  return builder_->Mul(a, b);
}
// Lower a division.
spirv::Value CodeGenSPIRV::VisitExpr_(const DivNode* op) {
  // Lower the operands in separate statements: MakeValue emits instructions,
  // and the evaluation order of function arguments is unspecified, so nesting
  // both calls in one argument list would make the emitted instruction order
  // depend on the host compiler.
  spirv::Value a = MakeValue(op->a);
  spirv::Value b = MakeValue(op->b);
  return builder_->Div(a, b);
}
// Lower a modulo operation.
spirv::Value CodeGenSPIRV::VisitExpr_(const ModNode* op) {
  // Lower the operands in separate statements: MakeValue emits instructions,
  // and the evaluation order of function arguments is unspecified, so nesting
  // both calls in one argument list would make the emitted instruction order
  // depend on the host compiler.
  spirv::Value a = MakeValue(op->a);
  spirv::Value b = MakeValue(op->b);
  return builder_->Mod(a, b);
}
// Lower min(a, b) as select(a < b, a, b).
spirv::Value CodeGenSPIRV::VisitExpr_(const MinNode* op) {
  spirv::Value lhs = MakeValue(op->a);
  spirv::Value rhs = MakeValue(op->b);
  spirv::Value lhs_is_smaller = builder_->LT(lhs, rhs);
  return builder_->Select(lhs_is_smaller, lhs, rhs);
}
// Lower max(a, b) as select(a > b, a, b).
spirv::Value CodeGenSPIRV::VisitExpr_(const MaxNode* op) {
  spirv::Value lhs = MakeValue(op->a);
  spirv::Value rhs = MakeValue(op->b);
  spirv::Value lhs_is_larger = builder_->GT(lhs, rhs);
  return builder_->Select(lhs_is_larger, lhs, rhs);
}
// Lower a less-than comparison.
spirv::Value CodeGenSPIRV::VisitExpr_(const LTNode* op) {
  // Lower the operands in separate statements: MakeValue emits instructions,
  // and the evaluation order of function arguments is unspecified, so nesting
  // both calls in one argument list would make the emitted instruction order
  // depend on the host compiler.
  spirv::Value a = MakeValue(op->a);
  spirv::Value b = MakeValue(op->b);
  return builder_->LT(a, b);
}
// Lower a less-than-or-equal comparison.
spirv::Value CodeGenSPIRV::VisitExpr_(const LENode* op) {
  // Lower the operands in separate statements: MakeValue emits instructions,
  // and the evaluation order of function arguments is unspecified, so nesting
  // both calls in one argument list would make the emitted instruction order
  // depend on the host compiler.
  spirv::Value a = MakeValue(op->a);
  spirv::Value b = MakeValue(op->b);
  return builder_->LE(a, b);
}
// Lower a greater-than comparison.
spirv::Value CodeGenSPIRV::VisitExpr_(const GTNode* op) {
  // Lower the operands in separate statements: MakeValue emits instructions,
  // and the evaluation order of function arguments is unspecified, so nesting
  // both calls in one argument list would make the emitted instruction order
  // depend on the host compiler.
  spirv::Value a = MakeValue(op->a);
  spirv::Value b = MakeValue(op->b);
  return builder_->GT(a, b);
}
// Lower a greater-than-or-equal comparison.
spirv::Value CodeGenSPIRV::VisitExpr_(const GENode* op) {
  // Lower the operands in separate statements: MakeValue emits instructions,
  // and the evaluation order of function arguments is unspecified, so nesting
  // both calls in one argument list would make the emitted instruction order
  // depend on the host compiler.
  spirv::Value a = MakeValue(op->a);
  spirv::Value b = MakeValue(op->b);
  return builder_->GE(a, b);
}
// Lower an equality comparison.
spirv::Value CodeGenSPIRV::VisitExpr_(const EQNode* op) {
  // Lower the operands in separate statements: MakeValue emits instructions,
  // and the evaluation order of function arguments is unspecified, so nesting
  // both calls in one argument list would make the emitted instruction order
  // depend on the host compiler.
  spirv::Value a = MakeValue(op->a);
  spirv::Value b = MakeValue(op->b);
  return builder_->EQ(a, b);
}
// Lower an inequality comparison.
spirv::Value CodeGenSPIRV::VisitExpr_(const NENode* op) {
  // Lower the operands in separate statements: MakeValue emits instructions,
  // and the evaluation order of function arguments is unspecified, so nesting
  // both calls in one argument list would make the emitted instruction order
  // depend on the host compiler.
  spirv::Value a = MakeValue(op->a);
  spirv::Value b = MakeValue(op->b);
  return builder_->NE(a, b);
}
// Lower a logical AND to OpLogicalAnd; the result uses the type of the
// left-hand operand.
spirv::Value CodeGenSPIRV::VisitExpr_(const AndNode* op) {
  spirv::Value lhs = MakeValue(op->a);
  spirv::Value rhs = MakeValue(op->b);
  const spirv::SType& result_stype = lhs.stype;
  return builder_->MakeValue(spv::OpLogicalAnd, result_stype, lhs, rhs);
}
// Lower a logical OR to OpLogicalOr; the result uses the type of the
// left-hand operand.
spirv::Value CodeGenSPIRV::VisitExpr_(const OrNode* op) {
  spirv::Value lhs = MakeValue(op->a);
  spirv::Value rhs = MakeValue(op->b);
  const spirv::SType& result_stype = lhs.stype;
  return builder_->MakeValue(spv::OpLogicalOr, result_stype, lhs, rhs);
}
// Lower a logical NOT to OpLogicalNot; the result keeps the operand's type.
spirv::Value CodeGenSPIRV::VisitExpr_(const NotNode* op) {
  spirv::Value operand = MakeValue(op->a);
  const spirv::SType& result_stype = operand.stype;
  return builder_->MakeValue(spv::OpLogicalNot, result_stype, operand);
}
// Lower a select expression; both candidate values are always lowered.
spirv::Value CodeGenSPIRV::VisitExpr_(const SelectNode* op) {
  // Lower the three operands in separate statements: MakeValue emits
  // instructions, and the evaluation order of function arguments is
  // unspecified, so nesting the calls in one argument list would make the
  // emitted instruction order depend on the host compiler.
  spirv::Value condition = MakeValue(op->condition);
  spirv::Value true_value = MakeValue(op->true_value);
  spirv::Value false_value = MakeValue(op->false_value);
  return builder_->Select(condition, true_value, false_value);
}
// Lower a let expression: bind op->var to the lowered op->value, inform the
// analyzer of the binding, then lower and return the body.
//
// The same variable may be let-bound more than once only if every binding
// uses a structurally equal value; otherwise an internal error is raised.
spirv::Value CodeGenSPIRV::VisitExpr_(const LetNode* op) {
  auto previous = let_binding_.find(op->var);
  if (previous == let_binding_.end()) {
    // First time this variable is bound: remember the defining node.
    let_binding_[op->var] = op;
  } else {
    const bool same_value = deep_equal_(previous->second->value, op->value);
    TVM_FFI_ICHECK(same_value) << "Let cannot bind the same var to two different values";
  }

  spirv::Value bound_value = MakeValue(op->value);
  var_map_[op->var.get()] = bound_value;
  analyzer_->Bind(op->var, op->value);

  return MakeValue(op->body);
}
// Lower a call to one of the builtin operators handled by this backend.
//
// Handled builtins: GLSL.std.450 extended instructions, bitwise and shift
// operators, reinterpret, large_uint_imm, tvm_storage_sync, if_then_else,
// popcount, the "__dp4a" pure extern, the cooperative-matrix intrinsics
// (tvm_fill_fragment / tvm_load_matrix_sync / tvm_mma_sync /
// tvm_store_matrix_sync), address_of and tvm_thread_invariant.
//
// Intrinsics that only have side effects return a default-constructed
// spirv::Value. Any other call, including every extern call other than
// "__dp4a", raises InternalError.
spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) {
  if (op->op.same_as(builtin::call_spirv_pure_glsl450())) {
    TVM_FFI_ICHECK_GE(op->args.size(), 2U);
    // args[0] is the extended instruction id; the rest are its operands.
    const auto* inst_id_node = op->args[0].as<IntImmNode>();
    TVM_FFI_ICHECK(inst_id_node) << "call_spirv_pure_glsl450 expects a constant instruction id";
    uint32_t inst_id = static_cast<uint32_t>(inst_id_node->value);
    std::vector<spirv::Value> values;
    for (size_t i = 1; i < op->args.size(); ++i) {
      values.push_back(MakeValue(op->args[i]));
    }
    return builder_->CallGLSL450(builder_->GetSType(op->dtype), inst_id, values);
  } else if (op->op.same_as(builtin::bitwise_and())) {
    TVM_FFI_ICHECK_EQ(op->args.size(), 2U);
    spirv::Value a = MakeValue(op->args[0]);
    spirv::Value b = MakeValue(op->args[1]);
    return builder_->MakeValue(spv::OpBitwiseAnd, a.stype, a, b);
  } else if (op->op.same_as(builtin::bitwise_xor())) {
    TVM_FFI_ICHECK_EQ(op->args.size(), 2U);
    spirv::Value a = MakeValue(op->args[0]);
    spirv::Value b = MakeValue(op->args[1]);
    return builder_->MakeValue(spv::OpBitwiseXor, a.stype, a, b);
  } else if (op->op.same_as(builtin::bitwise_or())) {
    TVM_FFI_ICHECK_EQ(op->args.size(), 2U);
    spirv::Value a = MakeValue(op->args[0]);
    spirv::Value b = MakeValue(op->args[1]);
    return builder_->MakeValue(spv::OpBitwiseOr, a.stype, a, b);
  } else if (op->op.same_as(builtin::bitwise_not())) {
    TVM_FFI_ICHECK_EQ(op->args.size(), 1U);
    spirv::Value a = MakeValue(op->args[0]);
    return builder_->MakeValue(spv::OpNot, a.stype, a);
  } else if (op->op.same_as(builtin::shift_left())) {
    TVM_FFI_ICHECK_EQ(op->args.size(), 2U);
    spirv::Value a = MakeValue(op->args[0]);
    spirv::Value b = MakeValue(op->args[1]);
    return builder_->MakeValue(spv::OpShiftLeftLogical, a.stype, a, b);
  } else if (op->op.same_as(builtin::shift_right())) {
    TVM_FFI_ICHECK_EQ(op->args.size(), 2U);
    spirv::Value a = MakeValue(op->args[0]);
    spirv::Value b = MakeValue(op->args[1]);
    // Signed operands use an arithmetic shift, everything else a logical one.
    if (op->args[0].dtype().is_int()) {
      return builder_->MakeValue(spv::OpShiftRightArithmetic, a.stype, a, b);
    } else {
      return builder_->MakeValue(spv::OpShiftRightLogical, a.stype, a, b);
    }
  } else if (op->op.same_as(builtin::reinterpret())) {
    return builder_->MakeValue(spv::OpBitcast, builder_->GetSType(op->dtype),
                               MakeValue(op->args[0]));
  } else if (op->op.same_as(builtin::large_uint_imm())) {
    TVM_FFI_ICHECK_EQ(op->args.size(), 2U);
    // The 64-bit constant is passed as (low 32 bits, high 32 bits).
    uint64_t low = static_cast<uint64_t>(Downcast<IntImm>(op->args[0])->value);
    uint64_t high = static_cast<uint64_t>(Downcast<IntImm>(op->args[1])->value);
    uint64_t val = (high << 32U) | low;
    return builder_->UIntImm(builder_->GetSType(op->dtype), val);
  } else if (op->op.same_as(builtin::tvm_storage_sync())) {
    return this->CreateStorageSync(op);
  } else if (op->op.same_as(builtin::if_then_else())) {
    TVM_FFI_ICHECK_EQ(op->args.size(), 3U);
    spirv::Value cond = MakeValue(op->args[0]);
    spirv::Label then_label = builder_->NewLabel();
    spirv::Label else_label = builder_->NewLabel();
    spirv::Label merge_label = builder_->NewLabel();
    builder_->MakeInst(spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone);
    builder_->MakeInst(spv::OpBranchConditional, cond, then_label, else_label);
    // then block, must get label after we see the value
    builder_->StartLabel(then_label);
    spirv::Value then_value = MakeValue(op->args[1]);
    spirv::Label then_value_label = builder_->CurrentLabel();
    builder_->MakeInst(spv::OpBranch, merge_label);
    // else block
    builder_->StartLabel(else_label);
    spirv::Value else_value = MakeValue(op->args[2]);
    spirv::Label else_value_label = builder_->CurrentLabel();
    builder_->MakeInst(spv::OpBranch, merge_label);
    // merge block: select the value coming from whichever branch was taken
    builder_->StartLabel(merge_label);
    spirv::PhiValue phi = builder_->MakePhi(then_value.stype, 2);
    phi.SetIncoming(0, then_value, then_value_label);
    phi.SetIncoming(1, else_value, else_value_label);
    return phi;
  } else if (op->op.same_as(builtin::popcount())) {
    return builder_->MakeValue(spv::OpBitCount, builder_->GetSType(op->dtype),
                               MakeValue(op->args[0]));
  } else if (op->op.same_as(builtin::call_pure_extern())) {
    TVM_FFI_ICHECK_GE(op->args.size(), 1U);
    const auto* func_name_node = op->args[0].as<StringImmNode>();
    TVM_FFI_ICHECK(func_name_node) << "call_pure_extern expects a string function name";
    const std::string& func_name = func_name_node->value;
    if (func_name == "__dp4a") {
      std::vector<spirv::Value> values;
      for (size_t i = 1; i < op->args.size(); ++i) {
        values.push_back(MakeValue(op->args[i]));
      }
      return builder_->CallKHRIntegerDotProduct(builder_->GetSType(op->dtype), values, op->dtype);
    } else {
      TVM_FFI_THROW(InternalError)
          << "SPIR-V shader cannot make extern calls.  Graph contains extern \""
          << Downcast<StringImm>(op->args[0]) << "\"";
      return spirv::Value();
    }
  } else if (op->op.same_as(builtin::call_extern())) {
    TVM_FFI_ICHECK_GE(op->args.size(), 1U);
    TVM_FFI_THROW(InternalError)
        << "SPIR-V shader cannot make extern calls.  Graph contains extern \""
        << Downcast<StringImm>(op->args[0]) << "\"";
    return spirv::Value();
  } else if (op->op.same_as(builtin::tvm_fill_fragment())) {
    TVM_FFI_ICHECK_EQ(op->args.size(), 6U);
    const VarNode* buffer_node = op->args[0].as<VarNode>();
    TVM_FFI_ICHECK(buffer_node && fragment_info_.count(buffer_node));
    DataType ele_dtype = GetElementDataType(buffer_node);
    TVM_FFI_ICHECK(ele_dtype.is_float()) << "Only floating point fragment accumulator is supported";
    spirv::SType ele_stype = builder_->GetSType(ele_dtype);
    spirv::SType& fragment_type = fragment_info_[buffer_node].stype;
    // Keep the fill value as a double. Routing it through an unsigned integer
    // would truncate fractional values and is undefined for negative ones.
    const double init = Downcast<FloatImm>(op->args[5])->value;
    PrimExpr prim_index = op->args[4];
    spirv::Value init_val = builder_->GetCompositeConst(ele_stype, fragment_type, init);
    spirv::SType ptr_type =
        builder_->GetPointerType(fragment_type, fragment_info_[buffer_node].sclass);
    spirv::Value index = MakeValue(prim_index);
    TVM_FFI_ICHECK(var_map_.count(buffer_node));
    spirv::Value ptr = builder_->StructArrayAccess(ptr_type, var_map_[buffer_node], index);
    builder_->MakeInst(spv::OpStore, ptr, init_val, spv::MemoryAccessMaskNone);
    return spirv::Value();
  } else if (op->op.same_as(builtin::tvm_load_matrix_sync())) {
    TVM_FFI_ICHECK_EQ(op->args.size(), 8U);
    // Arguments used: [0] fragment buffer, [4] fragment index, [5] source
    // pointer, [6] constant stride, [7] layout string.
    const VarNode* buffer_node = op->args[0].as<VarNode>();
    TVM_FFI_ICHECK(buffer_node && fragment_info_.count(buffer_node));
    spirv::SType& fragment_type = fragment_info_[buffer_node].stype;
    PrimExpr dst_index = op->args[4];
    int stride = static_cast<int>(Downcast<IntImm>(op->args[6])->value);
    auto type_int = builder_->GetSType(DataType::Int(32));
    spirv::Value stride_val = builder_->IntImm(type_int, stride);
    std::string layout = (op->args[7].as<StringImmNode>())->value;
    spirv::SType dst_ptr_type =
        builder_->GetPointerType(fragment_type, fragment_info_[buffer_node].sclass);
    spirv::Value dst_ptr =
        builder_->StructArrayAccess(dst_ptr_type, var_map_[buffer_node], MakeValue(dst_index));
    spirv::Value src_ptr = VisitExpr(op->args[5]);
    spirv::SType type_bool = builder_->GetSType(DataType::Bool());
    spirv::Value t_val = builder_->UIntImm(type_bool, 1);
    spirv::Value f_val = builder_->UIntImm(type_bool, 0);
    // The last operand is true for every layout other than "row_major".
    spirv::Value loaded =
        builder_->MakeValue(spv::OpCooperativeMatrixLoadNV, fragment_type, src_ptr, stride_val,
                            (layout != "row_major") ? t_val : f_val);
    builder_->MakeInst(spv::OpStore, dst_ptr, loaded, spv::MemoryAccessMaskNone);
    return spirv::Value();
  } else if (op->op.same_as(builtin::tvm_mma_sync())) {
    // Arguments are (buffer, index) pairs for D, A, B and C; emits D = A * B + C.
    TVM_FFI_ICHECK_EQ(op->args.size(), 8U);
    const VarNode* buffer_d = op->args[0].as<VarNode>();
    const VarNode* buffer_a = op->args[2].as<VarNode>();
    const VarNode* buffer_b = op->args[4].as<VarNode>();
    const VarNode* buffer_c = op->args[6].as<VarNode>();
    // Every operand must be a known fragment; otherwise the lookups below
    // would silently create empty fragment records.
    TVM_FFI_ICHECK(buffer_d && fragment_info_.count(buffer_d));
    TVM_FFI_ICHECK(buffer_a && fragment_info_.count(buffer_a));
    TVM_FFI_ICHECK(buffer_b && fragment_info_.count(buffer_b));
    TVM_FFI_ICHECK(buffer_c && fragment_info_.count(buffer_c));
    PrimExpr index_d = op->args[1];
    PrimExpr index_a = op->args[3];
    PrimExpr index_b = op->args[5];
    tvm::tir::ExprDeepEqual expr_equal;
    PrimExpr index_c = op->args[7];
    // When C and D name the same fragment, reuse D's pointer.
    bool is_equal = ((buffer_d == buffer_c) && expr_equal(index_d, index_c));
    spirv::SType& fragment_type_d = fragment_info_[buffer_d].stype;
    spirv::SType& fragment_type_a = fragment_info_[buffer_a].stype;
    spirv::SType& fragment_type_b = fragment_info_[buffer_b].stype;
    spirv::SType& fragment_type_c = fragment_info_[buffer_c].stype;
    spv::StorageClass storage = fragment_info_[buffer_d].sclass;
    spirv::SType ptr_type_d = builder_->GetPointerType(fragment_type_d, storage);
    spirv::SType ptr_type_a = builder_->GetPointerType(fragment_type_a, storage);
    spirv::SType ptr_type_b = builder_->GetPointerType(fragment_type_b, storage);
    spirv::SType ptr_type_c = builder_->GetPointerType(fragment_type_c, storage);
    spirv::Value ptr_d =
        builder_->StructArrayAccess(ptr_type_d, var_map_[buffer_d], MakeValue(index_d));
    spirv::Value ptr_a =
        builder_->StructArrayAccess(ptr_type_a, var_map_[buffer_a], MakeValue(index_a));
    spirv::Value ptr_b =
        builder_->StructArrayAccess(ptr_type_b, var_map_[buffer_b], MakeValue(index_b));
    spirv::Value ptr_c =
        is_equal ? ptr_d
                 : builder_->StructArrayAccess(ptr_type_c, var_map_[buffer_c], MakeValue(index_c));
    uint32_t mask = spv::MemoryAccessMaskNone;
    spirv::Value loaded_a = builder_->MakeValue(spv::OpLoad, fragment_type_a, ptr_a, mask);
    spirv::Value loaded_b = builder_->MakeValue(spv::OpLoad, fragment_type_b, ptr_b, mask);
    spirv::Value loaded_c = builder_->MakeValue(spv::OpLoad, fragment_type_c, ptr_c, mask);
    spirv::Value result = builder_->MakeValue(spv::OpCooperativeMatrixMulAddNV, fragment_type_d,
                                              loaded_a, loaded_b, loaded_c);
    builder_->MakeInst(spv::OpStore, ptr_d, result, spv::MemoryAccessMaskNone);
    return spirv::Value();
  } else if (op->op.same_as(builtin::tvm_store_matrix_sync())) {
    TVM_FFI_ICHECK_EQ(op->args.size(), 8U);
    // Arguments used: [0] fragment buffer, [4] fragment index, [5]
    // destination pointer, [6] constant stride, [7] layout string.
    const VarNode* buffer_node = op->args[0].as<VarNode>();
    TVM_FFI_ICHECK(buffer_node && fragment_info_.count(buffer_node));
    PrimExpr index = op->args[4];
    int stride = static_cast<int>(Downcast<IntImm>(op->args[6])->value);
    auto type_int = builder_->GetSType(DataType::Int(32));
    spirv::Value stride_val = builder_->IntImm(type_int, stride);
    std::string layout = (op->args[7].as<StringImmNode>())->value;
    spirv::Value dst_ptr = VisitExpr(op->args[5]);
    spirv::SType& fragment_type = fragment_info_[buffer_node].stype;
    spv::StorageClass storage = fragment_info_[buffer_node].sclass;
    spirv::SType ptr_type = builder_->GetPointerType(fragment_type, storage);
    spirv::Value ptr =
        builder_->StructArrayAccess(ptr_type, var_map_[buffer_node], MakeValue(index));
    uint32_t mask = spv::MemoryAccessMaskNone;
    spirv::Value loaded = builder_->MakeValue(spv::OpLoad, fragment_type, ptr, mask);
    spirv::SType type_bool = builder_->GetSType(DataType::Bool());
    spirv::Value t_val = builder_->UIntImm(type_bool, 1);
    spirv::Value f_val = builder_->UIntImm(type_bool, 0);
    // The last operand is true for every layout other than "row_major".
    builder_->MakeInst(spv::OpCooperativeMatrixStoreNV, dst_ptr, loaded, stride_val,
                       (layout != "row_major") ? t_val : f_val);
    return spirv::Value();
  } else if (op->op.same_as(builtin::address_of())) {
    // Produce a pointer to the element addressed by a flat BufferLoad.
    TVM_FFI_ICHECK_EQ(op->args.size(), 1U);
    const BufferLoadNode* load = op->args[0].as<BufferLoadNode>();
    TVM_FFI_ICHECK(load) << "address_of expects a BufferLoad argument, got " << op->args[0];
    TVM_FFI_ICHECK_EQ(load->indices.size(), 1) << "SPIR-V codegen expects flat memory buffers";
    Var buffer_var = load->buffer->data;
    const VarNode* buffer_node = buffer_var.get();
    PrimExpr index = load->indices[0];
    DataType ele_dtype = GetElementDataType(buffer_node);
    spirv::SType ele_stype = builder_->GetSType(ele_dtype);
    spirv::Value buffer_val = MakeValue(buffer_var);
    spirv::SType ptr_type = builder_->GetPointerType(ele_stype, buffer_val.stype.storage_class);
    TVM_FFI_ICHECK(var_map_.count(buffer_node));
    return builder_->StructArrayAccess(ptr_type, var_map_[buffer_node], MakeValue(index));
  } else if (op->op.same_as(builtin::tvm_thread_invariant())) {
    // The annotation carries no code of its own; lower the wrapped expression.
    return MakeValue(op->args[0]);
  } else {
    TVM_FFI_THROW(InternalError) << "Unresolved call  " << op->op;
  }
}
// Lower a ramp into a vector whose lane k holds base + k * stride.
// Lane 0 reuses the lowered base directly; every later lane lowers its own
// offset expression and adds it to the base.
spirv::Value CodeGenSPIRV::VisitExpr_(const RampNode* op) {
  std::vector<spirv::Value> lane_values;
  spirv::Value base_value = MakeValue(op->base);
  const int lane_count = op->dtype.lanes();
  for (int lane = 0; lane < lane_count; ++lane) {
    if (lane == 0) {
      lane_values.push_back(base_value);
      continue;
    }
    PrimExpr lane_offset = make_const(op->stride.dtype(), lane) * op->stride;
    spirv::Value offset_value = MakeValue(lane_offset);
    lane_values.push_back(builder_->Add(base_value, offset_value));
  }
  return builder_->Concat(lane_values);
}
// Lower a broadcast: the scalar is lowered once and repeated in every lane
// of the resulting vector.
spirv::Value CodeGenSPIRV::VisitExpr_(const BroadcastNode* op) {
  std::vector<spirv::Value> lane_values;
  spirv::Value scalar = MakeValue(op->value);
  const int lane_count = op->dtype.lanes();
  int remaining = lane_count;
  while (remaining > 0) {
    lane_values.push_back(scalar);
    --remaining;
  }
  return builder_->Concat(lane_values);
}
// Lower a load from a flat (single-index), unpredicated buffer.
//
// Boolean reads are performed with boolean_storage_type_ and, for the
// single-value path, cast back to bool afterwards. Two access shapes are
// supported:
//  - the requested type equals the buffer's element type: one OpLoad;
//  - the requested type is a vector of the buffer's element type: one OpLoad
//    per lane, with the results concatenated.
// Any other combination raises InternalError.
spirv::Value CodeGenSPIRV::VisitExpr_(const BufferLoadNode* op) {
  TVM_FFI_ICHECK_EQ(op->indices.size(), 1) << "SPIR-V codegen expects flat memory buffers";
  TVM_FFI_ICHECK(!op->predicate.defined()) << "Predicated buffer load is not supported.";

  Var data_var = op->buffer->data;
  PrimExpr index_expr = op->indices[0];

  DataType read_dtype = op->dtype;
  if (read_dtype == DataType::Bool()) {
    read_dtype = boolean_storage_type_.with_lanes(read_dtype.lanes());
  }

  auto info_entry = storage_info_.find(data_var.get());
  TVM_FFI_ICHECK(info_entry != storage_info_.end());
  StorageInfo& info = info_entry->second;
  info.CheckContentType(read_dtype, index_expr.dtype().lanes());

  spirv::SType elem_stype = builder_->GetSType(info.element_type);
  spirv::Value buffer = MakeValue(data_var);
  spirv::SType elem_ptr_stype = builder_->GetPointerType(elem_stype, buffer.stype.storage_class);

  uint32_t access_mask = spv::MemoryAccessMaskNone;
  if (info.is_volatile) {
    access_mask |= spv::MemoryAccessVolatileMask;
  }

  if (read_dtype == info.element_type) {
    // A single array element is requested. Whether that is a scalar or a
    // vector load depends on the array's element type.
    PrimExpr simplified_index = analyzer_->Simplify(index_expr);
    spirv::Value index = MakeValue(simplified_index);
    spirv::Value elem_ptr = builder_->StructArrayAccess(elem_ptr_stype, buffer, index);
    spirv::Value result = builder_->MakeValue(spv::OpLoad, elem_stype, elem_ptr, access_mask);
    // OpTypeBool has no physical storage, so convert from the storage type.
    if (op->dtype == DataType::Bool()) {
      auto bool_stype = builder_->GetSType(DataType::Bool());
      result = builder_->Cast(bool_stype, result);
    }
    return result;
  }

  if (read_dtype.element_of() == info.element_type) {
    // Several elements are requested as one vector: load lane by lane.
    std::vector<spirv::Value> lane_values;
    auto load_lane = [&](int /*lane*/, spirv::Value lane_index) {
      spirv::Value elem_ptr = builder_->StructArrayAccess(elem_ptr_stype, buffer, lane_index);
      lane_values.emplace_back(builder_->MakeValue(spv::OpLoad, elem_stype, elem_ptr, access_mask));
    };
    this->Scalarize(index_expr, load_lane);
    return builder_->Concat(lane_values);
  }

  TVM_FFI_THROW(InternalError) << "Cannot perform buffer access of buffer variable '"
                               << data_var->name_hint << "' with element type "
                               << info.element_type << " using index of type "
                               << index_expr->dtype << " to produce output of type " << op->dtype;
  return spirv::Value();
}
void CodeGenSPIRV::Scalarize(const PrimExpr& e, std::function<void(int i, spirv::Value v)> f) {
if (const RampNode* ramp = e.as<RampNode>()) {
for (int i = 0; i < ramp->dtype.lanes(); ++i) {
PrimExpr offset = ramp->base + ramp->stride * i;
f(i, MakeValue(offset));
}
} else {
spirv::SType etype = builder_->GetSType(e.dtype().element_of());
spirv::Value value = MakeValue(e);
for (int i = 0; i < e.dtype().lanes(); ++i) {
f(i, builder_->MakeValue(spv::OpCompositeExtract, etype, value, i));
}
}
}
// Lower a shuffle. Only the form that extracts one element from one vector
// is supported; it becomes a single OpCompositeExtract.
spirv::Value CodeGenSPIRV::VisitExpr_(const ShuffleNode* op) {
  TVM_FFI_ICHECK(op->vectors.size() == 1 && op->indices.size() == 1)
      << "SPIR-V codegen only supports shuffle "
      << "of one vector with one index";
  spirv::Value source = MakeValue(op->vectors[0]);
  int lane = Downcast<Integer>(op->indices[0])->value;
  spirv::SType result_stype = builder_->GetSType(op->dtype);
  return builder_->MakeValue(spv::OpCompositeExtract, result_stype, source, lane);
}
// Lower a store into a flat (single-index), unpredicated buffer.
//
// Two access shapes are supported:
//  - the stored value's type equals the buffer's element type: one OpStore;
//  - the stored value is a vector of the buffer's element type: each lane is
//    extracted and stored to its own index.
// Any other combination raises InternalError.
void CodeGenSPIRV::VisitStmt_(const BufferStoreNode* op) {
  TVM_FFI_ICHECK_EQ(op->indices.size(), 1) << "SPIR-V codegen expects flat memory buffers";
  TVM_FFI_ICHECK(!op->predicate.defined()) << "Predicated buffer store is not supported.";
  Var buffer_var = op->buffer->data;
  PrimExpr prim_index = op->indices[0];

  auto it = storage_info_.find(buffer_var.get());
  TVM_FFI_ICHECK(it != storage_info_.end());
  StorageInfo& info = it->second;
  info.CheckContentType(op->value.dtype(), prim_index.dtype().lanes());

  spirv::SType content_type = builder_->GetSType(info.element_type);
  spirv::Value buffer = MakeValue(buffer_var);
  spirv::Value value = MakeValue(op->value);
  spirv::SType ptr_type = builder_->GetPointerType(content_type, buffer.stype.storage_class);

  uint32_t mask = spv::MemoryAccessMaskNone;
  if (info.is_volatile) {
    mask |= spv::MemoryAccessVolatileMask;
  }

  if (op->value.dtype() == info.element_type) {
    // Requested store of a single value.  This may be a scalar store
    // or a vectorized store, based on the array element type. The branch
    // condition already guarantees the types match, so no further type check
    // is needed here.
    spirv::Value index = MakeValue(prim_index);
    spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index);
    builder_->MakeInst(spv::OpStore, ptr, value, mask);
  } else if (op->value.dtype().element_of() == info.element_type) {
    // Requested store of several arbitrarily located values.  Extract
    // each value from the composite, then assign to the buffer.
    auto f = [&](int i, spirv::Value index) {
      spirv::Value elem = builder_->MakeValue(spv::OpCompositeExtract, content_type, value, i);
      spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index);
      builder_->MakeInst(spv::OpStore, ptr, elem, mask);
    };
    this->Scalarize(prim_index, f);
  } else {
    TVM_FFI_THROW(InternalError) << "Cannot store value of type " << op->value.dtype()
                                 << " into buffer variable '" << buffer_var->name_hint
                                 << "' with element type " << info.element_type
                                 << " using index of type " << prim_index->dtype;
  }
}
// Lower a for loop into a SPIR-V structured loop made of four blocks:
//   head     - phi for the loop variable, bound test, OpLoopMerge + branch
//   body     - the lowered loop body
//   continue - increments the loop variable and jumps back to head
//   merge    - first block after the loop
// The loop runs while loop_var < min + extent. The step is the constant 1
// when the loop has a trivial step, otherwise op->step cast to the dtype of
// the end expression. Unrolled loops get the Unroll loop-control hint.
void CodeGenSPIRV::VisitStmt_(const ForNode* op) {
// Let the analyzer know the loop variable's range for later simplification.
analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
spirv::Value init_value = MakeValue(op->min);
// Exclusive upper bound; skip the addition when the loop starts at zero.
PrimExpr end = is_zero(op->min) ? op->extent : analyzer_->Simplify(op->min + op->extent);
spirv::Value end_value = MakeValue(end);
// loop step
spirv::Value step;
if (op->HasTrivialStep()) {
// Constant 1, signed or unsigned to match the loop variable.
step = op->loop_var.dtype().is_int() ? builder_->IntImm(init_value.stype, 1)
: builder_->UIntImm(init_value.stype, 1);
} else {
step = MakeValue(tvm::cast(end->dtype, *op->step));
}
// Must get init label after making value(to make sure they are correct)
spirv::Label init_label = builder_->CurrentLabel();
spirv::Label head_label = builder_->NewLabel();
builder_->SetName(head_label, "for_loop_head");
spirv::Label body_label = builder_->NewLabel();
builder_->SetName(body_label, "for_loop_body");
spirv::Label continue_label = builder_->NewLabel();
builder_->SetName(continue_label, "for_loop_continue");
spirv::Label merge_label = builder_->NewLabel();
builder_->SetName(merge_label, "for_loop_merge");
builder_->MakeInst(spv::OpBranch, head_label);
// Loop head - Phi must be created AFTER StartLabel so it's in the head block
builder_->StartLabel(head_label);
// Incoming 0 is the initial value from the block preceding the loop;
// incoming 1 (the incremented value) is filled in from the continue block.
spirv::PhiValue loop_var = builder_->MakePhi(init_value.stype, 2);
loop_var.SetIncoming(0, init_value, init_label);
spirv::Value loop_cond = builder_->LT(loop_var, end_value);
uint32_t control =
(op->kind == ForKind::kUnrolled ? spv::LoopControlUnrollMask : spv::LoopControlMaskNone);
builder_->MakeInst(spv::OpLoopMerge, merge_label, continue_label, control);
// Branch weights: staying in the loop is weighted weight_likely_branch_ to 1.
builder_->MakeInst(spv::OpBranchConditional, loop_cond, body_label, merge_label,
weight_likely_branch_, 1);
// loop body
builder_->StartLabel(body_label);
// Expose the phi as the value of the loop variable while lowering the body.
var_map_[op->loop_var.get()] = spirv::Value(loop_var);
this->VisitStmt(op->body);
builder_->MakeInst(spv::OpBranch, continue_label);
// loop continue
builder_->StartLabel(continue_label);
spirv::Value next_value = builder_->Add(loop_var, step);
loop_var.SetIncoming(1, next_value, builder_->CurrentLabel());
builder_->MakeInst(spv::OpBranch, head_label);
// loop merge
builder_->StartLabel(merge_label);
}
void CodeGenSPIRV::VisitStmt_(const WhileNode* op) {
spirv::Label head_label = builder_->NewLabel();
spirv::Label condition_label = builder_->NewLabel();
spirv::Label body_label = builder_->NewLabel();
spirv::Label continue_label = builder_->NewLabel();
spirv::Label merge_label = builder_->NewLabel();
builder_->MakeInst(spv::OpBranch, head_label);
// Loop head
builder_->StartLabel(head_label);
uint32_t control = spv::LoopControlMaskNone;
builder_->MakeInst(spv::OpLoopMerge, merge_label, continue_label, control);
builder_->MakeInst(spv::OpBranch, condition_label);
// Loop condition evaluation. The condition could contain if/else
// blocks that introduce additional labels, so the condition cannot
// be in the loop head's block.
builder_->StartLabel(condition_label);
spirv::Value loop_cond = MakeValue(op->condition);
builder_->MakeInst(spv::OpBranchConditional, loop_cond, body_label, merge_label,
weight_likely_branch_, 1);
// loop body
builder_->StartLabel(body_label);
this->VisitStmt(op->body);
builder_->MakeInst(spv::OpBranch, continue_label);
// loop continue
builder_->StartLabel(continue_label);
builder_->MakeInst(spv::OpBranch, head_label);
// loop merge
builder_->StartLabel(merge_label);
}
void CodeGenSPIRV::VisitStmt_(const IfThenElseNode* op) {
spirv::Value cond = MakeValue(op->condition);
spirv::Label then_label = builder_->NewLabel();
spirv::Label merge_label = builder_->NewLabel();
if (op->else_case) {
spirv::Label else_label = builder_->NewLabel();
builder_->MakeInst(spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone);
builder_->MakeInst(spv::OpBranchConditional, cond, then_label, else_label);
// then block
builder_->StartLabel(then_label);
this->VisitStmt(op->then_case);
builder_->MakeInst(spv::OpBranch, merge_label);
// else block
builder_->StartLabel(else_label);
this->VisitStmt(op->else_case.value());
builder_->MakeInst(spv::OpBranch, merge_label);
} else {
builder_->MakeInst(spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone);
builder_->MakeInst(spv::OpBranchConditional, cond, then_label, merge_label,
weight_likely_branch_, 1);
// then block
builder_->StartLabel(then_label);
this->VisitStmt(op->then_case);
builder_->MakeInst(spv::OpBranch, merge_label);
}
// start merge label;
builder_->StartLabel(merge_label);
}
void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) {
TVM_FFI_ICHECK(!is_zero(op->condition));
TVM_FFI_ICHECK(!op->dtype.is_handle());
size_t constant_size = op->ConstantAllocationSize();
TVM_FFI_ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU";
spirv::Value buf;
const std::string scope = GetPtrStorageScope(op->buffer_var);
auto storage_scope = runtime::StorageScope::Create(scope);
spirv::SType etype = builder_->GetSType(op->dtype);
runtime::StorageRank rank = storage_scope.rank;
spv::StorageClass storage_class;
const VarNode* var_node = (op->buffer_var).get();
switch (rank) {
case runtime::StorageRank::kWMMAMatrixA:
case runtime::StorageRank::kWMMAMatrixB:
case runtime::StorageRank::kWMMAAccumulator: {
TVM_FFI_ICHECK(fragment_info_.count(var_node));
fragment_info_[var_node].scope = scope;
etype = GetFragmentSType(var_node, op->dtype);
storage_class = spv::StorageClassFunction;
fragment_info_[var_node].sclass = storage_class;
TVM_FFI_ICHECK(fragment_info_.count(var_node));
const std::string& scope = fragment_info_[var_node].scope;
const std::string& shape_str = fragment_info_.at(var_node).shape;
std::pair<int32_t, int32_t> dim = GetWmmaFragmentDimSize(shape_str, scope);
int64_t size = dim.first * dim.second;
buf = builder_->Allocate(etype, static_cast<uint32_t>(constant_size) / size, storage_class);
} break;
case runtime::StorageRank::kLocal: {
storage_class = spv::StorageClassFunction;
buf = builder_->Allocate(etype, static_cast<uint32_t>(constant_size), storage_class);
} break;
case runtime::StorageRank::kShared: {
storage_class = spv::StorageClassWorkgroup;
// Shared memory
// Aligned on 4-byte boundary
int32_t aligned_constant_size = ((constant_size + 3) & ~0x3);
buf = builder_->Allocate(etype, static_cast<uint32_t>(aligned_constant_size), storage_class);
size_t num_bytes =
op->dtype.bytes() * op->dtype.lanes() * static_cast<uint32_t>(aligned_constant_size);
shared_memory_bytes_used_ += num_bytes;
} break;
default:
TVM_FFI_THROW(InternalError) << "Can only allocate shared or local memory inside kernel";
}
builder_->SetName(buf, op->buffer_var->name_hint);
StorageInfo& info = storage_info_[op->buffer_var.get()];
TVM_FFI_ICHECK(!info.element_type_known);
info.SetContentType(op->dtype, op->buffer_var->name_hint);
TVM_FFI_ICHECK(!var_map_.count(op->buffer_var.get()));
var_map_[op->buffer_var.get()] = buf;
this->VisitStmt(op->body);
}
/*!
 * \brief Visit a DeclBuffer statement.
 *
 * Nothing is emitted for the declaration itself; only the body is lowered.
 *
 * \param op The buffer declaration node.
 */
void CodeGenSPIRV::VisitStmt_(const DeclBufferNode* op) {
  // Pass straight through to the wrapped statement.
  this->VisitStmt(op->body);
}
/*!
 * \brief Process an attribute statement, then visit its body.
 *
 * Handled attribute keys:
 *  - thread_extent: for an IterVar with a non-empty thread tag, bind its
 *    variable to the range [0, extent) in the analyzer and, if the variable
 *    has no value yet, map it to the value returned by GetThreadIndex.
 *  - volatile_scope: mark the annotated variable's storage as volatile.
 *  - buffer_bind_scope: only checks that the annotated node is a variable.
 *  - fragment_shape: record the shape string for the annotated buffer
 *    variable in fragment_info_.
 * Any other key is ignored. The body is visited in every case.
 *
 * \param op The attribute statement node.
 */
void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) {
  if (op->attr_key == tir::attr::thread_extent) {
    IterVar iv = Downcast<IterVar>(op->node);
    if (iv->thread_tag.length() != 0) {
      // Will throw error if rebinding same local variable to a different extent.
      analyzer_->Bind(iv->var, Range::FromMinExtent(0, op->value));
      if (!var_map_.count(iv->var.get())) {
        var_map_[iv->var.get()] = GetThreadIndex(iv, op->value);
      }
    }
  } else if (op->attr_key == tir::attr::volatile_scope) {
    const VarNode* v = op->node.as<VarNode>();
    TVM_FFI_ICHECK(v);
    storage_info_[v].is_volatile = true;
  } else if (op->attr_key == tir::attr::buffer_bind_scope) {
    const VarNode* v = op->node.as<VarNode>();
    TVM_FFI_ICHECK(v);
  } else if (op->attr_key == tir::attr::fragment_shape) {
    // Validate both downcasts like the sibling branches do; a failed cast
    // would otherwise dereference a null pointer or register a null key.
    const VarNode* buffer = op->node.as<VarNode>();
    TVM_FFI_ICHECK(buffer) << "fragment_shape must annotate a buffer variable";
    const StringImmNode* shape_str = op->value.as<StringImmNode>();
    TVM_FFI_ICHECK(shape_str) << "fragment_shape value must be a string literal";
    fragment_info_[buffer] = {shape_str->value};
  }
  this->VisitStmt(op->body);
}
/*!
 * \brief Visit an AssertStmt. Intentionally emits nothing.
 *
 * \param op The assert node (unused).
 */
void CodeGenSPIRV::VisitStmt_(const AssertStmtNode* op) {
  // AssertStmt is a leaf: there is no body to recurse into and no code is generated.
  (void)op;
}
void CodeGenSPIRV::VisitStmt_(const LetStmtNode* op) {
TVM_FFI_ICHECK(!var_map_.count(op->var.get()));
TVM_FFI_ICHECK(!op->var.dtype().is_handle());
var_map_[op->var.get()] = MakeValue(op->value);
analyzer_->Bind(op->var, op->value);
this->VisitStmt(op->body);
}
void CodeGenSPIRV::VisitStmt_(const SeqStmtNode* op) {
for (Stmt stmt : op->seq) {
this->VisitStmt(stmt);
}
}
/*!
 * \brief Lower an Evaluate statement.
 *
 * The wrapped expression is lowered via MakeValue and the resulting value
 * is discarded.
 *
 * \param op The evaluate node.
 */
void CodeGenSPIRV::VisitStmt_(const EvaluateNode* op) {
  // The result is not needed here.
  MakeValue(op->value);
}
/*!
 * \brief Build the SPIR-V type of a fragment buffer and cache it.
 *
 * The fragment's recorded shape string and scope are turned into a pair of
 * dimensions via GetWmmaFragmentDimSize. The resulting type is requested from
 * the builder using the element dtype with (first * second) lanes together
 * with both dimensions, and stored in the buffer's fragment_info_ entry.
 *
 * \param buffer The fragment buffer variable; must already be in fragment_info_.
 * \param dtype The element data type of the fragment.
 * \return The SPIR-V type created for the fragment.
 */
spirv::SType CodeGenSPIRV::GetFragmentSType(const VarNode* buffer, const DataType& dtype) {
  TVM_FFI_ICHECK(fragment_info_.count(buffer));
  // The entry is guaranteed to exist by the check above.
  auto& fragment = fragment_info_.at(buffer);
  std::pair<int32_t, int32_t> extents = GetWmmaFragmentDimSize(fragment.shape, fragment.scope);
  int64_t num_lanes = extents.first * extents.second;
  spirv::SType fragment_type =
      builder_->GetSType(dtype.with_lanes(num_lanes), extents.first, extents.second);
  fragment.stype = fragment_type;
  return fragment_type;
}
/*!
 * \brief Look up the element data type recorded for a buffer variable.
 *
 * \param buffer The buffer variable; it must have an entry in storage_info_.
 * \return The element_type stored in that entry.
 */
DataType CodeGenSPIRV::GetElementDataType(const VarNode* buffer) {
  const auto entry = storage_info_.find(buffer);
  // An unknown buffer is an internal error.
  TVM_FFI_ICHECK(entry != storage_info_.end());
  const StorageInfo& info = entry->second;
  return info.element_type;
}
} // namespace codegen
} // namespace tvm