blob: f912e482761ca1184f247384543f3b24b6874833 [file]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file ir_builder.cc
* \brief IRBuilder for SPIRV block
*/
#include "ir_builder.h"

#include <cstring>
#include <spirv.hpp>
namespace tvm {
namespace codegen {
namespace spirv {
// implementations
/*!
 * \brief Construct a builder bound to a SPIR-V feature description.
 * \param support Feature description kept in spirv_support_ and consulted
 *        whenever capabilities or extensions are declared.
 */
IRBuilder::IRBuilder(const SPIRVSupport& support)
    : spirv_support_(support) {}
/*!
 * \brief Write the module header words, register the default capabilities
 *        and extensions, emit the memory model, then declare the predefined
 *        types.
 *
 * The header must still be empty when this is called.
 */
void IRBuilder::InitHeader() {
  TVM_FFI_ICHECK_EQ(header_.size(), 0U);

  // Word 0: SPIR-V magic number.
  header_.push_back(spv::MagicNumber);
  // Word 1: target SPIR-V version 1.0. Additional functionality is
  // enabled through extensions.
  header_.push_back(0x10000);
  // Word 2: generator id, 0 means unknown.
  header_.push_back(0U);
  // Word 3: id bound, patched in Finalize.
  header_.push_back(0U);
  // Word 4: schema, reserved.
  header_.push_back(0U);

  // CapabilityShader is always declared. Every other capability is
  // derived from the types that get declared later on.
  capabilities_used_.insert(spv::CapabilityShader);

#ifdef TVM_SPIRV_KHR_INTEGER_DOT_PRODUCT
  if (spirv_support_.supports_integer_dot_product) {
    extensions_used_.insert("SPV_KHR_integer_dot_product");
    capabilities_used_.insert(spv::CapabilityDotProductKHR);
    capabilities_used_.insert(spv::CapabilityDotProductInput4x8BitPackedKHR);
  }
#endif

  if (spirv_support_.supports_cooperative_matrix) {
    extensions_used_.insert("SPV_NV_cooperative_matrix");
    capabilities_used_.insert(spv::CapabilityCooperativeMatrixNV);
  }

  // Memory model: logical addressing with the GLSL450 memory model.
  ib_.Begin(spv::OpMemoryModel);
  ib_.AddSeq(spv::AddressingModelLogical, spv::MemoryModelGLSL450);
  ib_.Commit(&entry_);

  this->InitPreDefs();
}
void IRBuilder::InitPreDefs() {
ext_glsl450_ = ExtInstImport("GLSL.std.450");
t_int32_ = DeclareType(DataType::Int(32));
t_uint32_ = DeclareType(DataType::UInt(32));
t_bool_ = DeclareType(DataType::Bool());
t_fp32_ = DeclareType(DataType::Float(32));
const_i32_zero_ = IntImm(t_int32_, 0);
// declare void, and void functions
t_void_.id = id_counter_++;
ib_.Begin(spv::OpTypeVoid).Add(t_void_).Commit(&global_);
t_void_func_.id = id_counter_++;
ib_.Begin(spv::OpTypeFunction).AddSeq(t_void_func_, t_void_).Commit(&global_);
}
/*!
 * \brief Assemble the complete module as a flat word stream.
 *
 * Patches the id bound into the header, then concatenates the header,
 * capability and extension declarations, and every recorded section in a
 * fixed order.
 *
 * \return The assembled module words.
 */
std::vector<uint32_t> IRBuilder::Finalize() {
  std::vector<uint32_t> data;
  const auto append_section = [&data](const auto& section) {
    data.insert(data.end(), section.begin(), section.end());
  };

  // Header word 3 holds the upper bound of the id numbers.
  const int kBoundLoc = 3;
  header_[kBoundLoc] = id_counter_;
  append_section(header_);

  for (const auto& capability : capabilities_used_) {
    ib_.Begin(spv::OpCapability);
    ib_.Add(capability);
    ib_.Commit(&data);
  }
  for (const auto& ext_name : extensions_used_) {
    ib_.Begin(spv::OpExtension);
    ib_.Add(ext_name);
    ib_.Commit(&data);
  }

  append_section(extended_instruction_section_);
  append_section(entry_);
  append_section(exec_mode_);
  append_section(debug_);
  append_section(decorate_);
  append_section(global_);
  append_section(func_header_);
  append_section(function_scope_vars_);
  append_section(function_);
  return data;
}
/*!
 * \brief Look up, or declare on first use, the SPIR-V type for a data type.
 *
 * \param dtype The data type.
 * \param row Row count; must be zero together with \p col for a plain
 *        scalar/vector type.
 * \param col Column count; see \p row.
 * \return The cached or newly declared type.
 */
SType IRBuilder::GetSType(const DataType& dtype, uint32_t row, uint32_t col) {
  // The four predefined scalar types live in dedicated members.
  if (dtype == DataType::Int(32)) return t_int32_;
  if (dtype == DataType::Bool()) return t_bool_;
  if (dtype == DataType::Float(32)) return t_fp32_;
  if (dtype == DataType::UInt(32)) return t_uint32_;

  // Pack type code, bit width and either the lane count or the
  // (row, col) shape into a single lookup key.
  uint64_t type_key = static_cast<uint32_t>(dtype.code());
  type_key |= static_cast<uint32_t>(dtype.bits()) << 8U;
  if (row * col != 0) {
    type_key |= static_cast<uint64_t>(row) << 32U;
    type_key |= static_cast<uint64_t>(col) << 40U;
  } else {
    TVM_FFI_ICHECK((row == 0) && (col == 0));
    type_key |= static_cast<uint32_t>(dtype.lanes()) << 16U;
  }

  const auto cached = pod_type_tbl_.find(type_key);
  if (cached != pod_type_tbl_.end()) {
    return cached->second;
  }

  SType declared = DeclareType(dtype, row, col);
  pod_type_tbl_[type_key] = declared;
  return declared;
}
/*!
 * \brief Look up, or declare on first use, a pointer type.
 *
 * \param value_type The pointee type.
 * \param storage_class The storage class; must not be StorageClassMax.
 * \return The pointer type, cached per (pointee id, storage class).
 */
SType IRBuilder::GetPointerType(const SType& value_type, spv::StorageClass storage_class) {
  TVM_FFI_ICHECK_NE(storage_class, spv::StorageClassMax);

  const auto cache_key = std::make_pair(value_type.id, storage_class);
  const auto cached = pointer_type_tbl_.find(cache_key);
  if (cached != pointer_type_tbl_.end()) {
    return cached->second;
  }

  SType ptr_type;
  ptr_type.id = id_counter_++;
  ptr_type.type = DataType::Handle();
  ptr_type.element_type_id = value_type.id;
  ptr_type.storage_class = storage_class;

  ib_.Begin(spv::OpTypePointer);
  ib_.AddSeq(ptr_type, storage_class, value_type);
  ib_.Commit(&global_);

  pointer_type_tbl_[cache_key] = ptr_type;
  return ptr_type;
}
/*!
 * \brief Look up, or declare on first use, a struct wrapping one array.
 *
 * \param value_type Element type of the array.
 * \param num_elems Number of elements; zero declares a runtime array.
 * \param interface_block Whether layout decorations (array stride, member
 *        offset, Block/BufferBlock) are emitted for the type.
 * \return The struct type, cached per (element id, num_elems, interface_block).
 */
SType IRBuilder::GetStructArrayType(const SType& value_type, uint32_t num_elems,
                                    bool interface_block) {
  const auto cache_key = std::make_tuple(value_type.id, num_elems, interface_block);
  const auto cached = struct_array_type_tbl_.find(cache_key);
  if (cached != struct_array_type_tbl_.end()) {
    return cached->second;
  }

  // Inner array type. Its id is reserved before the length constant is
  // requested.
  SType arr_type;
  arr_type.id = id_counter_++;
  arr_type.type = DataType::Handle();
  arr_type.element_type_id = value_type.id;
  if (num_elems == 0) {
    ib_.Begin(spv::OpTypeRuntimeArray).AddSeq(arr_type, value_type).Commit(&global_);
  } else {
    Value length = UIntImm(GetSType(DataType::UInt(32)), num_elems);
    ib_.Begin(spv::OpTypeArray).AddSeq(arr_type, value_type, length).Commit(&global_);
  }

  if (interface_block) {
    // Explicit layout is required for descriptor-backed interface blocks.
    int nbits = value_type.type.bits() * value_type.type.lanes();
    TVM_FFI_ICHECK_EQ(nbits % 8, 0);
    uint32_t nbytes = static_cast<uint32_t>(nbits) / 8;
    this->Decorate(spv::OpDecorate, arr_type, spv::DecorationArrayStride, nbytes);
  }

  // Outer struct holding the array as its only member.
  SType struct_type;
  struct_type.id = id_counter_++;
  struct_type.type = DataType::Handle();
  struct_type.element_type_id = value_type.id;
  ib_.Begin(spv::OpTypeStruct).AddSeq(struct_type, arr_type).Commit(&global_);

  if (interface_block) {
    ib_.Begin(spv::OpMemberDecorate)
        .AddSeq(struct_type, 0, spv::DecorationOffset, 0)
        .Commit(&decorate_);

    // Runtime arrays are always decorated as Block or BufferBlock
    // (shader storage buffer).
    if (spirv_support_.supports_storage_buffer_storage_class) {
      // With SPIR-V 1.3+, or with the extension
      // SPV_KHR_storage_buffer_storage_class, BufferBlock is deprecated.
      extensions_used_.insert("SPV_KHR_storage_buffer_storage_class");
      this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBlock);
    } else if (num_elems == 0) {
      this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBufferBlock);
    }
  }

  struct_array_type_tbl_[cache_key] = struct_type;
  return struct_type;
}
/*!
 * \brief Form a pointer to one element of a struct-wrapped array.
 *
 * \param res_type Result pointer type.
 * \param buffer Base pointer; must carry the kStructArrayPtr flag.
 * \param index Element index.
 * \return Result of OpInBoundsAccessChain with indices (0, index).
 */
Value IRBuilder::StructArrayAccess(const SType& res_type, Value buffer, Value index) {
  TVM_FFI_ICHECK(buffer.flag == kStructArrayPtr);
  // The leading zero index selects the struct's first member.
  const Value& member_index = const_i32_zero_;
  return MakeValue(spv::OpInBoundsAccessChain, res_type, buffer, member_index, index);
}
/*!
 * \brief Get a signed integer constant.
 * \param dtype Type of the constant.
 * \param value The value; its 64-bit pattern is forwarded to GetConst_.
 * \return The constant value.
 */
Value IRBuilder::IntImm(const SType& dtype, int64_t value) {
  // Conversion to uint64_t keeps the two's complement bit pattern.
  uint64_t raw_bits = static_cast<uint64_t>(value);
  return GetConst_(dtype, &raw_bits);
}
/*!
 * \brief Get an unsigned integer constant.
 * \param dtype Type of the constant.
 * \param value The value, forwarded to GetConst_ unchanged.
 * \return The constant value.
 */
Value IRBuilder::UIntImm(const SType& dtype, uint64_t value) {
  const uint64_t raw_bits = value;
  return GetConst_(dtype, &raw_bits);
}
/*!
 * \brief Get a floating-point constant.
 *
 * - 64-bit: the bit pattern of the double is emitted directly.
 * - 32-bit: the value is narrowed to float and its bit pattern is emitted.
 * - 16-bit: positive zero (all-zero float bit pattern) is emitted directly;
 *   any other value is built as a 32-bit constant and passed through Cast.
 *
 * Bit patterns are extracted with std::memcpy instead of pointer casts so
 * that no object is read through an lvalue of an unrelated type.
 *
 * \param dtype Type of the constant; its bit width must be 16, 32 or 64.
 * \param value The value.
 * \return The constant (or, for non-zero 16-bit values, the cast result).
 */
Value IRBuilder::FloatImm(const SType& dtype, double value) {
  if (dtype.type.bits() == 64) {
    static_assert(sizeof(uint64_t) == sizeof(double), "double must be 64 bits wide");
    uint64_t data = 0;
    std::memcpy(&data, &value, sizeof(data));
    return GetConst_(dtype, &data);
  }

  // Both remaining widths start from the float32 bit pattern.
  static_assert(sizeof(uint32_t) == sizeof(float), "float must be 32 bits wide");
  const float fvalue = static_cast<float>(value);
  uint32_t word = 0;
  std::memcpy(&word, &fvalue, sizeof(word));
  uint64_t data = word;

  if (dtype.type.bits() == 32) {
    return GetConst_(dtype, &data);
  }

  TVM_FFI_ICHECK_EQ(dtype.type.bits(), 16);
  if (data == 0) {
    // An all-zero pattern is the same in every width, no conversion needed.
    return GetConst_(dtype, &data);
  }
  return Cast(dtype, FloatImm(GetSType(DataType::Float(32)), value));
}
/*!
 * \brief Declare a buffer argument backed by a runtime array.
 *
 * \param value_type Element type of the buffer.
 * \param descriptor_set Descriptor set decoration.
 * \param binding Binding decoration.
 * \return Pointer value flagged kStructArrayPtr.
 */
Value IRBuilder::BufferArgument(const SType& value_type, uint32_t descriptor_set,
                                uint32_t binding) {
  // With SPIR-V 1.3+, or with the extension
  // SPV_KHR_storage_buffer_storage_class, BufferBlock is deprecated, so the
  // StorageBuffer class is used when it is supported.
  const spv::StorageClass storage_class = spirv_support_.supports_storage_buffer_storage_class
                                              ? spv::StorageClassStorageBuffer
                                              : spv::StorageClassUniform;

  SType sarr_type = GetStructArrayType(value_type, 0, true);
  SType ptr_type = GetPointerType(sarr_type, storage_class);
  Value buffer = NewValue(ptr_type, kStructArrayPtr);

  ib_.Begin(spv::OpVariable);
  ib_.AddSeq(ptr_type, buffer, storage_class);
  ib_.Commit(&global_);

  this->DecorateBufferArgument(buffer, descriptor_set, binding);
  return buffer;
}
/*!
 * \brief Declare a Block-decorated struct variable holding scalar members.
 *
 * \param value_types Member types; every member must be 32 or 64 bits wide.
 * \param storage_class Storage class of the variable.
 * \param kind Flag attached to the returned value.
 * \return The pointer value of the declared variable.
 */
Value IRBuilder::DeclareStorageVariable(const std::vector<SType>& value_types,
                                        spv::StorageClass storage_class, ValueKind kind) {
  SType struct_type;
  struct_type.id = id_counter_++;
  struct_type.type = DataType::Handle();

  // Struct declaration: result id followed by one operand per member.
  ib_.Begin(spv::OpTypeStruct).Add(struct_type);
  for (const SType& member_stype : value_types) {
    ib_.Add(member_stype);
  }
  ib_.Commit(&global_);

  // Member offsets.
  uint32_t offset = 0;
  for (uint32_t i = 0; i < value_types.size(); ++i) {
    ib_.Begin(spv::OpMemberDecorate)
        .AddSeq(struct_type, i, spv::DecorationOffset, offset)
        .Commit(&decorate_);

    DataType t = value_types[i].type;
    uint32_t nbits = t.bits() * t.lanes();
    TVM_FFI_ICHECK_EQ(nbits % 8, 0);
    uint32_t bytes = (nbits / 8);
    if (t.bits() != 32) {
      TVM_FFI_ICHECK_EQ(t.bits(), 64);
      offset += bytes;
    } else {
      // In our Vulkan runtime, each scalar argument always occupies 64 bit.
      offset += bytes * 2;
    }
  }
  this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBlock);

  SType ptr_type = GetPointerType(struct_type, storage_class);
  Value variable = NewValue(ptr_type, kind);
  ib_.Begin(spv::OpVariable);
  ib_.AddSeq(ptr_type, variable, storage_class);
  ib_.Commit(&global_);
  return variable;
}
/*!
 * \brief Declare the push constant block.
 *
 * May be called at most once per builder. The id of the declared variable
 * is recorded in push_const_ so that the guard below rejects a second
 * declaration (previously nothing in this file updated push_const_, so the
 * guard could not trigger).
 *
 * \param value_types Member types of the push constant struct.
 * \return Pointer value flagged kPushConstantPtr.
 */
Value IRBuilder::DeclarePushConstant(const std::vector<SType>& value_types) {
  TVM_FFI_ICHECK_EQ(push_const_.id, 0);
  Value val = DeclareStorageVariable(value_types, spv::StorageClassPushConstant, kPushConstantPtr);
  // Only the id is consulted by the guard above.
  push_const_.id = val.id;
  return val;
}
/*!
 * \brief Load one member of the push constant block.
 *
 * \param ptr_push_const Pointer to the push constant struct.
 * \param v_type Type of the member to load.
 * \param index Member index.
 * \return The loaded value.
 */
Value IRBuilder::GetPushConstant(Value ptr_push_const, const SType& v_type, uint32_t index) {
  SType member_ptr_type = this->GetPointerType(v_type, spv::StorageClassPushConstant);
  Value member_index = IntImm(t_int32_, static_cast<int64_t>(index));
  Value member_ptr =
      this->MakeValue(spv::OpAccessChain, member_ptr_type, ptr_push_const, member_index);
  return this->MakeValue(spv::OpLoad, v_type, member_ptr);
}
/*!
 * \brief Declare a uniform buffer struct and decorate its descriptor slot.
 *
 * \param value_types Member types of the struct.
 * \param descriptor_set Descriptor set decoration.
 * \param binding Binding decoration.
 * \return Pointer value flagged kUniformPtr.
 */
Value IRBuilder::DeclareUniformBuffer(const std::vector<SType>& value_types,
                                      uint32_t descriptor_set, uint32_t binding) {
  Value uniform = DeclareStorageVariable(value_types, spv::StorageClassUniform, kUniformPtr);
  this->DecorateBufferArgument(uniform, descriptor_set, binding);
  return uniform;
}
/*!
 * \brief Attach DescriptorSet and Binding decorations to a value.
 * \param val The value to decorate.
 * \param descriptor_set Descriptor set number.
 * \param binding Binding number.
 */
void IRBuilder::DecorateBufferArgument(Value val, uint32_t descriptor_set, uint32_t binding) {
  // DescriptorSet is emitted first, then Binding.
  Decorate(spv::OpDecorate, val, spv::DecorationDescriptorSet, descriptor_set);
  Decorate(spv::OpDecorate, val, spv::DecorationBinding, binding);
}
/*!
 * \brief Load one member of a uniform buffer struct.
 *
 * \param ptr_push_const Pointer to the uniform struct (the name is kept
 *        parallel to GetPushConstant).
 * \param v_type Type of the member to load.
 * \param index Member index.
 * \return The loaded value.
 */
Value IRBuilder::GetUniform(Value ptr_push_const, const SType& v_type, uint32_t index) {
  SType member_ptr_type = this->GetPointerType(v_type, spv::StorageClassUniform);
  Value member_index = IntImm(t_int32_, static_cast<int64_t>(index));
  Value member_ptr =
      this->MakeValue(spv::OpAccessChain, member_ptr_type, ptr_push_const, member_index);
  return this->MakeValue(spv::OpLoad, v_type, member_ptr);
}
/*!
 * \brief Allocate a value of the void-function type, flagged kFunction.
 * \return The new function value.
 */
Value IRBuilder::NewFunction() {
  Value func = NewValue(t_void_func_, kFunction);
  return func;
}
/*!
 * \brief Emit the OpEntryPoint for a GLCompute kernel.
 *
 * Every variable recorded in built_in_tbl_ is appended as an operand of
 * the entry point.
 *
 * \param func The function value; must carry the kFunction flag.
 * \param name Entry point name.
 */
void IRBuilder::CommitKernelFunction(const Value& func, const std::string& name) {
  TVM_FFI_ICHECK_EQ(func.flag, kFunction);
  ib_.Begin(spv::OpEntryPoint);
  ib_.AddSeq(spv::ExecutionModelGLCompute, func, name);
  for (auto& entry : built_in_tbl_) {
    ib_.Add(entry.second);
  }
  ib_.Commit(&entry_);
}
void IRBuilder::StartFunction(const Value& func) {
TVM_FFI_ICHECK_EQ(func.flag, kFunction);
// add function declaration to the header.
ib_.Begin(spv::OpFunction).AddSeq(t_void_, func, 0, t_void_func_).Commit(&func_header_);
spirv::Label start_label = this->NewLabel();
ib_.Begin(spv::OpLabel).AddSeq(start_label).Commit(&func_header_);
curr_label_ = start_label;
}
/*!
 * \brief Emit the LocalSize execution mode for a kernel.
 * \param func The function value; must carry the kFunction flag.
 * \param local_size The three local size values, passed through in order.
 */
void IRBuilder::SetLocalSize(const Value& func, uint32_t local_size[3]) {
  TVM_FFI_ICHECK_EQ(func.flag, kFunction);
  ib_.Begin(spv::OpExecutionMode);
  ib_.AddSeq(func, spv::ExecutionModeLocalSize, local_size[0], local_size[1], local_size[2]);
  ib_.Commit(&exec_mode_);
}
/*!
 * \brief Declare a fixed-size array variable wrapped in a struct.
 *
 * \param value_type Element type.
 * \param num_elems Number of elements; must be non-zero.
 * \param storage_class Storage class of the variable.
 * \return Pointer value flagged kStructArrayPtr.
 */
Value IRBuilder::Allocate(const SType& value_type, uint32_t num_elems,
                          spv::StorageClass storage_class) {
  TVM_FFI_ICHECK_NE(num_elems, 0U);
  SType sarr_type = GetStructArrayType(value_type, num_elems, false);
  SType ptr_type = GetPointerType(sarr_type, storage_class);
  Value variable = NewValue(ptr_type, kStructArrayPtr);

  // Function-class variables are written to the function header section,
  // everything else to the global section.
  auto* section = (storage_class == spv::StorageClassFunction) ? &func_header_ : &global_;
  ib_.Begin(spv::OpVariable);
  ib_.AddSeq(ptr_type, variable, storage_class);
  ib_.Commit(section);
  return variable;
}
/*!
 * \brief Get one component of the workgroup id built-in.
 * \param dim_index Component index; also picks the name suffix starting
 *        from 'x'.
 * \return The loaded component, named "blockIdx.<suffix>".
 */
Value IRBuilder::GetWorkgroupID(uint32_t dim_index) {
  std::string name("blockIdx.");
  name += std::string(1, 'x' + dim_index);
  return GetBuiltInValue(spv::BuiltInWorkgroupId, dim_index, name);
}
/*!
 * \brief Get one component of the local invocation id built-in.
 * \param dim_index Component index; also picks the name suffix starting
 *        from 'x'.
 * \return The loaded component, named "threadIdx.<suffix>".
 */
Value IRBuilder::GetLocalID(uint32_t dim_index) {
  std::string name("threadIdx.");
  name += std::string(1, 'x' + dim_index);
  return GetBuiltInValue(spv::BuiltInLocalInvocationId, dim_index, name);
}
/*!
 * \brief Get one component of a SPIR-V built-in vector.
 *
 * Results are cached per (built_in, index). The Input-class vector
 * variable for each built-in is declared once and recorded in
 * built_in_tbl_. The access chain and load are written to the
 * function-scope variable section.
 *
 * \param built_in The built-in; only LocalInvocationId and WorkgroupId have
 *        a data type defined here.
 * \param index Component index within the built-in vector.
 * \param name Debug name for the loaded value; skipped when empty.
 * \return The loaded component.
 */
Value IRBuilder::GetBuiltInValue(spv::BuiltIn built_in, uint32_t index, const std::string& name) {
  // Serve from the cache when this component was requested before.
  const auto cached = built_in_values_tbl_.find({built_in, index});
  if (cached != built_in_values_tbl_.end()) {
    return cached->second;
  }

  DataType data_type;
  DataType global_arr_type;
  switch (built_in) {
    case spv::BuiltInLocalInvocationId:
    case spv::BuiltInWorkgroupId:
      data_type = DataType::Int(32);
      global_arr_type = data_type.with_lanes(3);
      break;

    default:
      TVM_FFI_THROW(InternalError) << "No data type defined for SPIR-V Built-In " << built_in;
  }

  // Find the decorated vector variable at global scope, declaring it on
  // first use.
  Value global_array;
  const auto declared = built_in_tbl_.find(built_in);
  if (declared == built_in_tbl_.end()) {
    SType arr_stype = GetSType(global_arr_type);
    SType ptr_arr_type = this->GetPointerType(arr_stype, spv::StorageClassInput);
    global_array = NewValue(ptr_arr_type, kVectorPtr);

    ib_.Begin(spv::OpVariable);
    ib_.AddSeq(ptr_arr_type, global_array, spv::StorageClassInput);
    ib_.Commit(&global_);

    this->Decorate(spv::OpDecorate, global_array, spv::DecorationBuiltIn, built_in);

    if (built_in == spv::BuiltInLocalInvocationId) {
      SetName(global_array, "BuiltInLocalInvocationId");
    } else if (built_in == spv::BuiltInWorkgroupId) {
      SetName(global_array, "BuiltInWorkgroupId");
    }

    built_in_tbl_[built_in] = global_array;
  } else {
    global_array = declared->second;
  }

  // Address the requested component and load it.
  SType data_stype = GetSType(data_type);
  SType ptr_type = this->GetPointerType(data_stype, spv::StorageClassInput);
  Value global_const_index = UIntImm(t_int32_, static_cast<int64_t>(index));

  Value component_ptr = NewValue(ptr_type, kNormal);
  ib_.Begin(spv::OpAccessChain);
  ib_.AddSeq(ptr_type, component_ptr, global_array, global_const_index);
  ib_.Commit(&function_scope_vars_);

  Value output = NewValue(data_stype, kNormal);
  ib_.Begin(spv::OpLoad);
  ib_.AddSeq(data_stype, output, component_ptr);
  ib_.Commit(&function_scope_vars_);

  if (name.size()) {
    SetName(output, name);
  }

  // Remember the result for later requests.
  built_in_values_tbl_[{built_in, index}] = output;
  return output;
}
/*!
 * \brief Get, or declare on first use, a scalar constant.
 *
 * Constants are cached per (type id, 64-bit raw value).
 *
 * \param dtype Type of the constant; at most 64 bits wide.
 * \param pvalue Pointer to the raw 64-bit value.
 * \return The constant value, flagged kConstant.
 */
Value IRBuilder::GetConst_(const SType& dtype, const uint64_t* pvalue) {
  const auto cache_key = std::make_pair(dtype.id, pvalue[0]);
  const auto cached = const_tbl_.find(cache_key);
  if (cached != const_tbl_.end()) {
    return cached->second;
  }

  TVM_FFI_ICHECK_LE(dtype.type.bits(), 64);
  Value constant = NewValue(dtype, kConstant);

  if (dtype.type == DataType::Bool()) {
    // Booleans have dedicated opcodes and carry no literal operand.
    const spv::Op opcode = (*pvalue) ? spv::OpConstantTrue : spv::OpConstantFalse;
    ib_.Begin(opcode).AddSeq(dtype, constant);
  } else {
    // Integral/floating-point types: low word first, then the high word
    // for types wider than 32 bits.
    const uint64_t mask = 0xFFFFFFFFUL;
    ib_.Begin(spv::OpConstant).AddSeq(dtype, constant);
    ib_.Add(static_cast<uint32_t>(pvalue[0] & mask));

    if (dtype.type.bits() > 32) {
      uint32_t high_word;
      if (dtype.type.is_int()) {
        const int64_t sign_mask = 0xFFFFFFFFL;
        const int64_t* sign_ptr = reinterpret_cast<const int64_t*>(pvalue);
        high_word = static_cast<uint32_t>((sign_ptr[0] >> 32L) & sign_mask);
      } else {
        high_word = static_cast<uint32_t>((pvalue[0] >> 32UL) & mask);
      }
      ib_.Add(high_word);
    }
  }

  ib_.Commit(&global_);
  const_tbl_[cache_key] = constant;
  return constant;
}
/*!
 * \brief Declare a new SPIR-V type for a data type.
 *
 * Scalars map to OpTypeBool/OpTypeInt/OpTypeFloat. Multi-lane types map to
 * OpTypeVector when row and col are both zero, otherwise to
 * OpTypeCooperativeMatrixNV with subgroup scope.
 *
 * \param dtype The data type.
 * \param row Row count of a cooperative matrix, or zero.
 * \param col Column count of a cooperative matrix, or zero.
 * \return The newly declared type.
 */
SType IRBuilder::DeclareType(const DataType& dtype, uint32_t row, uint32_t col) {
  AddCapabilityFor(dtype);

  // The result id is reserved before any dependent type is requested.
  SType t;
  t.id = id_counter_++;
  t.type = dtype;

  if (dtype.lanes() == 1) {
    if (dtype.is_bool()) {
      ib_.Begin(spv::OpTypeBool).Add(t).Commit(&global_);
    } else if (dtype.is_int()) {
      ib_.Begin(spv::OpTypeInt).AddSeq(t, dtype.bits(), 1).Commit(&global_);
    } else if (dtype.is_uint()) {
      ib_.Begin(spv::OpTypeInt).AddSeq(t, dtype.bits(), 0).Commit(&global_);
    } else if (dtype.is_float()) {
      ib_.Begin(spv::OpTypeFloat).AddSeq(t, dtype.bits()).Commit(&global_);
    } else {
      TVM_FFI_THROW(InternalError) << "declare type do not support handle";
    }
    return t;
  }

  SType base_type = GetSType(dtype.element_of());
  if (row * col != 0) {
    Value v_row = GetSpecConst(GetSType(DataType::UInt(32)), row);
    Value v_col = GetSpecConst(GetSType(DataType::UInt(32)), col);
    Value scope = UIntImm(GetSType(DataType::UInt(32)), spv::ScopeSubgroup);
    ib_.Begin(spv::OpTypeCooperativeMatrixNV);
    ib_.AddSeq(t, base_type, scope, v_row, v_col);
    ib_.Commit(&global_);
  } else {
    TVM_FFI_ICHECK((row == 0) && (col == 0));
    ib_.Begin(spv::OpTypeVector);
    ib_.AddSeq(t, base_type, dtype.lanes());
    ib_.Commit(&global_);
  }
  return t;
}
/*!
 * \brief Record the capabilities and extensions a data type requires.
 *
 * Fails a check when the target description in spirv_support_ does not
 * advertise the required feature.
 *
 * \param dtype The data type about to be declared.
 */
void IRBuilder::AddCapabilityFor(const DataType& dtype) {
  // Arithmetic capabilities for non-default int/float widths.
  if (dtype.is_int() || dtype.is_uint()) {
    switch (dtype.bits()) {
      case 8:
        TVM_FFI_ICHECK(spirv_support_.supports_int8)
            << "Vulkan target does not support Int8 capability. "
            << "If your device supports 8-bit int operations, "
            << "please either add -supports_int8=1 to the target, "
            << "or query all device parameters by adding -from_device=0.";
        capabilities_used_.insert(spv::CapabilityInt8);
        break;
      case 16:
        TVM_FFI_ICHECK(spirv_support_.supports_int16)
            << "Vulkan target does not support Int16 capability. "
            << "If your device supports 16-bit int operations, "
            << "please either add -supports_int16=1 to the target, "
            << "or query all device parameters by adding -from_device=0.";
        capabilities_used_.insert(spv::CapabilityInt16);
        break;
      case 64:
        TVM_FFI_ICHECK(spirv_support_.supports_int64)
            << "Vulkan target does not support Int64 capability. "
            << "If your device supports 64-bit int operations, "
            << "please either add -supports_int64=1 to the target, "
            << "or query all device parameters by adding -from_device=0.";
        capabilities_used_.insert(spv::CapabilityInt64);
        break;
      default:
        break;
    }
  } else if (dtype.is_float()) {
    switch (dtype.bits()) {
      case 16:
        TVM_FFI_ICHECK(spirv_support_.supports_float16)
            << "Vulkan target does not support Float16 capability. "
            << "If your device supports 16-bit float operations, "
            << "please either add -supports_float16=1 to the target, "
            << "or query all device parameters by adding -from_device=0.";
        capabilities_used_.insert(spv::CapabilityFloat16);
        break;
      case 64:
        TVM_FFI_ICHECK(spirv_support_.supports_float64)
            << "Vulkan target does not support Float64 capability. "
            << "If your device supports 64-bit float operations, "
            << "please either add -supports_float64=1 to the target, "
            << "or query all device parameters by adding -from_device=0.";
        capabilities_used_.insert(spv::CapabilityFloat64);
        break;
      default:
        break;
    }
  }

  // Declare ability to read type to/from storage buffers. Doing so
  // here is a little bit overzealous, should be relaxed in the
  // future. Requiring StorageBuffer8BitAccess in order to declare an
  // Int8 prevents use of an 8-bit loop iterator on a device that
  // supports Int8 but doesn't support 8-bit buffer access.
  const bool is_8bit_non_bool = (dtype.bits() == 8) && !dtype.is_bool();
  if (is_8bit_non_bool) {
    TVM_FFI_ICHECK(spirv_support_.supports_storage_buffer_8bit_access)
        << "Vulkan target does not support StorageBuffer8BitAccess. "
        << "If your device supports 8-bit buffer access, "
        << "please either add -supports_8bit_buffer=1 to the target, "
        << "or query all device parameters by adding -from_device=0.";
    capabilities_used_.insert(spv::CapabilityStorageBuffer8BitAccess);
    extensions_used_.insert("SPV_KHR_8bit_storage");

    TVM_FFI_ICHECK(spirv_support_.supports_storage_buffer_storage_class)
        << "Illegal Vulkan target description. "
        << "Vulkan spec requires extension VK_KHR_storage_buffer_storage_class "
        << "if VK_KHR_8bit_storage is supported. "
        << "Please either add -supports_storage_buffer_storage_class=1 to the target, "
        << "or query all device parameters by adding -from_device=0.";
  } else if (dtype.bits() == 16) {
    TVM_FFI_ICHECK(spirv_support_.supports_storage_buffer_16bit_access)
        << "Vulkan target does not support StorageBuffer16BitAccess. "
        << "If your device supports 16-bit buffer access, "
        << "please either add -supports_16bit_buffer=1 to the target, "
        << "or query all device parameters by adding -from_device=0.";

    extensions_used_.insert("SPV_KHR_16bit_storage");
    // The capability depends on which storage class buffers are placed in.
    if (!spirv_support_.supports_storage_buffer_storage_class) {
      capabilities_used_.insert(spv::CapabilityStorageUniformBufferBlock16);
    } else {
      capabilities_used_.insert(spv::CapabilityStorageBuffer16BitAccess);
    }
  }
}
/*!
 * \brief Emit an OpPhi with placeholder operands.
 *
 * Two zero words are reserved per incoming edge; the committed instruction
 * is stored in the returned PhiValue.
 *
 * \param out_type Result type of the phi.
 * \param num_incoming Number of incoming edges.
 * \return The phi value together with its instruction handle.
 */
PhiValue IRBuilder::MakePhi(const SType& out_type, uint32_t num_incoming) {
  Value result = NewValue(out_type, kNormal);

  ib_.Begin(spv::OpPhi).AddSeq(out_type, result);
  // One (value, parent label) pair of placeholder words per incoming edge.
  for (uint32_t remaining = 2 * num_incoming; remaining > 0; --remaining) {
    ib_.Add(0);
  }

  PhiValue phi;
  phi.instr = ib_.Commit(&function_);
  phi.flag = kNormal;
  phi.stype = out_type;
  phi.id = result.id;
  TVM_FFI_ICHECK_EQ(phi.instr.WordCount(), 2 * num_incoming + 3);
  return phi;
}
/*!
 * \brief Emit a call into the GLSL.std.450 extended instruction set.
 *
 * \param ret_type Result type.
 * \param inst_id Instruction number within the extended set.
 * \param args Call arguments, appended in order.
 * \return The call result.
 */
Value IRBuilder::CallGLSL450(const SType& ret_type, uint32_t inst_id,
                             const std::vector<Value>& args) {
  Value result = NewValue(ret_type, kNormal);
  ib_.Begin(spv::OpExtInst);
  ib_.AddSeq(ret_type, result, ext_glsl450_, inst_id);
  for (const Value& arg : args) {
    ib_.Add(arg);
  }
  ib_.Commit(&function_);
  return result;
}
/*!
 * \brief Emit a saturating accumulate dot product from
 *        SPV_KHR_integer_dot_product.
 *
 * \param ret_type Result type.
 * \param args Exactly three operands, appended in order.
 * \param dtype Selects the signed (int) or unsigned (uint) opcode.
 * \return The result value.
 */
Value IRBuilder::CallKHRIntegerDotProduct(const SType& ret_type, const std::vector<Value>& args,
                                          const DataType& dtype) {
  if (args.size() != 3) {
    TVM_FFI_THROW(InternalError) << "Unresolved arguments in SPIRV_KHR_integer_dot_product";
  }
  Value result = NewValue(ret_type, kNormal);

#ifdef TVM_SPIRV_KHR_INTEGER_DOT_PRODUCT
  TVM_FFI_ICHECK(spirv_support_.supports_integer_dot_product)
      << "Vulkan target does not support integer dot product capability. "
      << "If your device supports integer dot product operations, "
      << "please either add -mattr=+dotprod to the target, "
      << "or query all device parameters by adding -from_device=0.";
  if (!dtype.is_int() && !dtype.is_uint()) {
    TVM_FFI_THROW(InternalError) << "Unsupported type";
  }
  const spv::Op opcode = dtype.is_int() ? spv::OpSDotAccSatKHR : spv::OpUDotAccSatKHR;
  ib_.Begin(opcode).AddSeq(ret_type, result);
#else
  TVM_FFI_THROW(InternalError)
      << "Please turn on USE_SPIRV_KHR_INTEGER_DOT_PRODUCT in config.cmake";
#endif

  for (const Value& arg : args) {
    ib_.Add(arg);
  }
  ib_.Commit(&function_);
  return result;
}
/*!
 * \brief Concatenate scalars and/or vectors into one vector value.
 *
 * All operands must share the same element type. The result has as many
 * lanes as the operands have in total. When every operand is flagged
 * kConstant and there is exactly one operand per result lane, an
 * OpConstantComposite is written to the global section; otherwise an
 * OpCompositeConstruct is written to the current function.
 *
 * Compared with the previous revision, an empty \p vec is rejected with a
 * check instead of reading vec[0] out of bounds, and element types are
 * compared on both sides so the first operand may itself be a vector.
 *
 * \param vec The operands, in lane order; must not be empty.
 * \return The combined vector value (flagged kNormal in both cases).
 */
Value IRBuilder::Concat(const std::vector<Value>& vec) {
  TVM_FFI_ICHECK(!vec.empty()) << "Cannot concat an empty list of values";

  const DataType etype = vec[0].stype.type.element_of();
  bool is_const = true;
  int lanes = 0;
  for (const Value& v : vec) {
    TVM_FFI_ICHECK_EQ(etype, v.stype.type.element_of())
        << "Cannot concat vector of different element type";
    lanes += v.stype.type.lanes();
    is_const = is_const && (v.flag == kConstant);
  }

  Value ret = NewValue(GetSType(etype.with_lanes(lanes)), kNormal);

  // The constant form is only used with one operand per lane, i.e. when
  // every operand is a scalar constant.
  const bool as_constant = is_const && vec.size() == static_cast<size_t>(lanes);
  ib_.Begin(as_constant ? spv::OpConstantComposite : spv::OpCompositeConstruct);
  ib_.AddSeq(ret.stype, ret);
  for (const Value& v : vec) {
    ib_.Add(v);
  }
  if (as_constant) {
    ib_.Commit(&global_);
  } else {
    ib_.Commit(&function_);
  }
  return ret;
}
/*!
 * \brief Convert a value to another type.
 *
 * Source and destination must have the same lane count. A value whose type
 * id already equals the destination id is returned unchanged.
 *
 * - bool -> int/uint: OpSelect between the constants 1 and 0.
 * - bool -> float: select a uint32 1/0, then OpConvertUToF.
 * - int/uint/float -> bool: compare against zero with NE.
 * - int <-> uint: widen/narrow in the source signedness when the bit
 *   widths differ, then OpBitcast.
 * - remaining numeric pairs: the matching Op*Convert* instruction.
 *
 * The cast-to-bool branch now tests the *source* type for unsignedness
 * (it used to test the destination type) and handles float sources
 * explicitly.
 *
 * \param dst_type Destination type.
 * \param value Value to convert; its type id must be non-zero.
 * \return The converted value.
 */
Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) {
  TVM_FFI_ICHECK_NE(value.stype.id, 0U);
  if (value.stype.id == dst_type.id) return value;
  const tvm::DataType& from = value.stype.type;
  const tvm::DataType& to = dst_type.type;
  TVM_FFI_ICHECK_EQ(from.lanes(), to.lanes());
  if (from == DataType::Bool()) {
    if (to.is_int()) {
      return Select(value, IntImm(dst_type, 1), IntImm(dst_type, 0));
    } else if (to.is_uint()) {
      return Select(value, UIntImm(dst_type, 1), UIntImm(dst_type, 0));
    } else if (to.is_float()) {
      return MakeValue(spv::OpConvertUToF, dst_type,
                       Select(value, UIntImm(t_uint32_, 1), UIntImm(t_uint32_, 0)));
    } else {
      TVM_FFI_THROW(InternalError) << "cannot cast from " << from << " to " << to;
      return Value();
    }
  } else if (to == DataType::Bool()) {
    // The zero constant must have the source type, so dispatch on `from`.
    if (from.is_int()) {
      return NE(value, IntImm(value.stype, 0));
    } else if (from.is_uint()) {
      return NE(value, UIntImm(value.stype, 0));
    } else if (from.is_float()) {
      return NE(value, FloatImm(value.stype, 0));
    } else {
      TVM_FFI_THROW(InternalError) << "cannot cast from " << from << " to " << to;
      return Value();
    }
  } else if (from.is_int() && to.is_int()) {
    return MakeValue(spv::OpSConvert, dst_type, value);
  } else if (from.is_uint() && to.is_uint()) {
    return MakeValue(spv::OpUConvert, dst_type, value);
  } else if (from.is_uint() && to.is_int()) {
    if (from.bits() != to.bits()) {
      // Resize as unsigned first so the bitcast operates on equal widths.
      value = MakeValue(spv::OpUConvert, GetSType(from.with_bits(to.bits())), value);
    }
    return MakeValue(spv::OpBitcast, dst_type, value);
  } else if (from.is_int() && to.is_uint()) {
    if (from.bits() != to.bits()) {
      // Resize as signed first so the bitcast operates on equal widths.
      value = MakeValue(spv::OpSConvert, GetSType(from.with_bits(to.bits())), value);
    }
    return MakeValue(spv::OpBitcast, dst_type, value);
  } else if (from.is_float() && to.is_int()) {
    return MakeValue(spv::OpConvertFToS, dst_type, value);
  } else if (from.is_float() && to.is_uint()) {
    return MakeValue(spv::OpConvertFToU, dst_type, value);
  } else if (from.is_int() && to.is_float()) {
    return MakeValue(spv::OpConvertSToF, dst_type, value);
  } else if (from.is_uint() && to.is_float()) {
    return MakeValue(spv::OpConvertUToF, dst_type, value);
  } else if (from.is_float() && to.is_float()) {
    return MakeValue(spv::OpFConvert, dst_type, value);
  } else {
    TVM_FFI_THROW(InternalError) << "do not support type cast from " << from << " to " << to;
    return Value();
  }
}
/*!
 * \brief Get, or declare on first use, a composite constant built from a
 *        single floating-point constituent.
 *
 * \param ele_stype Type of the constituent.
 * \param composite_stype Type of the composite.
 * \param dval Value of the constituent.
 * \return The composite value, cached per (composite type id, dval).
 */
Value IRBuilder::GetCompositeConst(const SType& ele_stype, const SType& composite_stype,
                                   const double dval) {
  const auto cache_key = std::make_pair(composite_stype.id, dval);
  const auto cached = composite_const_tbl_.find(cache_key);
  if (cached != composite_const_tbl_.end()) {
    return cached->second;
  }

  spirv::Value constituent = FloatImm(ele_stype, dval);
  Value composite = NewValue(composite_stype, kNormal);
  ib_.Begin(spv::OpConstantComposite);
  ib_.AddSeq(composite_stype, composite, constituent);
  ib_.Commit(&global_);

  composite_const_tbl_[cache_key] = composite;
  return composite;
}
/*!
 * \brief Declare a specialization constant.
 *
 * A new constant is declared on every call; there is no cache.
 *
 * \param dtype Type of the constant; at most 32 bits wide.
 * \param value Default value; only the low 32 bits are emitted.
 * \return The value, flagged kSpecConst.
 */
Value IRBuilder::GetSpecConst(const SType& dtype, uint64_t value) {
  TVM_FFI_ICHECK_LE(dtype.type.bits(), 32);
  Value spec_const = NewValue(dtype, kSpecConst);
  const uint32_t literal = static_cast<uint32_t>(value);
  ib_.Begin(spv::OpSpecConstant);
  ib_.AddSeq(dtype, spec_const);
  ib_.Add(literal);
  ib_.Commit(&global_);
  return spec_const;
}
/*!
 * \brief Define a binary IRBuilder method whose integer opcode does not
 *        depend on signedness.
 *
 * Integer operands (signed or unsigned) use spv::OpI<_Op>; everything else
 * must be floating point and uses spv::OpF<_Op>. Both operands must have
 * the same type id.
 */
#define DEFINE_BUILDER_BINARY_USIGN_OP(_OpName, _Op)                  \
  Value IRBuilder::_OpName(Value a, Value b) {                        \
    TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id);                        \
    const bool is_integer = a.stype.type.is_int() || a.stype.type.is_uint(); \
    if (!is_integer) {                                                \
      TVM_FFI_ICHECK(a.stype.type.is_float());                        \
      return MakeValue(spv::OpF##_Op, a.stype, a, b);                 \
    }                                                                 \
    return MakeValue(spv::OpI##_Op, a.stype, a, b);                   \
  }
/*!
 * \brief Define a binary IRBuilder method with separate signed and
 *        unsigned integer opcodes.
 *
 * Signed operands use spv::OpS<_Op>, unsigned operands spv::OpU<_Op>;
 * everything else must be floating point and uses spv::OpF<_Op>. Both
 * operands must have the same type id.
 */
#define DEFINE_BUILDER_BINARY_SIGN_OP(_OpName, _Op)   \
  Value IRBuilder::_OpName(Value a, Value b) {        \
    TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id);        \
    if (a.stype.type.is_int()) {                      \
      return MakeValue(spv::OpS##_Op, a.stype, a, b); \
    }                                                 \
    if (a.stype.type.is_uint()) {                     \
      return MakeValue(spv::OpU##_Op, a.stype, a, b); \
    }                                                 \
    TVM_FFI_ICHECK(a.stype.type.is_float());          \
    return MakeValue(spv::OpF##_Op, a.stype, a, b);   \
  }
// Add/Sub/Mul use one integer opcode for both signed and unsigned operands
// (OpIAdd/OpISub/OpIMul); Div needs distinct opcodes (OpSDiv/OpUDiv).
// Floating-point operands map to OpFAdd/OpFSub/OpFMul/OpFDiv.
DEFINE_BUILDER_BINARY_USIGN_OP(Add, Add);
DEFINE_BUILDER_BINARY_USIGN_OP(Sub, Sub);
DEFINE_BUILDER_BINARY_USIGN_OP(Mul, Mul);
DEFINE_BUILDER_BINARY_SIGN_OP(Div, Div);
/*!
 * \brief Emit a remainder operation.
 *
 * Signed integers use OpSRem, unsigned integers OpUMod, and floating-point
 * operands OpFRem.
 *
 * \param a Left operand.
 * \param b Right operand; must have the same type id as \p a.
 * \return The result, typed like \p a.
 */
Value IRBuilder::Mod(Value a, Value b) {
  TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id);
  if (a.stype.type.is_int()) {
    return MakeValue(spv::OpSRem, a.stype, a, b);
  }
  if (a.stype.type.is_uint()) {
    return MakeValue(spv::OpUMod, a.stype, a, b);
  }
  TVM_FFI_ICHECK(a.stype.type.is_float());
  return MakeValue(spv::OpFRem, a.stype, a, b);
}
/*!
 * \brief Define an ordering comparison method on IRBuilder.
 *
 * The result is a bool type with the operands' lane count. Signed integers
 * use spv::OpS<_Op>, unsigned integers spv::OpU<_Op>, and floating-point
 * operands spv::OpFOrd<_Op>. Both operands must have the same type id.
 */
#define DEFINE_BUILDER_CMP_OP(_OpName, _Op)                                          \
  Value IRBuilder::_OpName(Value a, Value b) {                                       \
    TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id);                                       \
    TVM_FFI_ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes());                   \
    const DataType result_dtype = DataType::Bool().with_lanes(a.stype.type.lanes()); \
    const SType result_stype = this->GetSType(result_dtype);                         \
    if (a.stype.type.is_int()) {                                                     \
      return MakeValue(spv::OpS##_Op, result_stype, a, b);                           \
    }                                                                                \
    if (a.stype.type.is_uint()) {                                                    \
      return MakeValue(spv::OpU##_Op, result_stype, a, b);                           \
    }                                                                                \
    TVM_FFI_ICHECK(a.stype.type.is_float());                                         \
    return MakeValue(spv::OpFOrd##_Op, result_stype, a, b);                          \
  }
// Ordering comparisons. Each expands to signed (OpS*), unsigned (OpU*) and
// ordered floating-point (OpFOrd*) variants of the named opcode.
DEFINE_BUILDER_CMP_OP(LT, LessThan);
DEFINE_BUILDER_CMP_OP(LE, LessThanEqual);
DEFINE_BUILDER_CMP_OP(GT, GreaterThan);
DEFINE_BUILDER_CMP_OP(GE, GreaterThanEqual);
/*!
 * \brief Define an equality-style comparison method on IRBuilder.
 *
 * The result is a bool type with the operands' lane count. Integer
 * operands of either signedness use spv::OpI<_Op>; everything else must be
 * floating point and uses spv::OpFOrd<_Op>. Both operands must have the
 * same type id.
 */
#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op)                                         \
  Value IRBuilder::_OpName(Value a, Value b) {                                       \
    TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id);                                       \
    TVM_FFI_ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes());                   \
    const DataType result_dtype = DataType::Bool().with_lanes(a.stype.type.lanes()); \
    const SType result_stype = this->GetSType(result_dtype);                         \
    const bool is_integer = a.stype.type.is_int() || a.stype.type.is_uint();         \
    if (!is_integer) {                                                               \
      TVM_FFI_ICHECK(a.stype.type.is_float());                                       \
      return MakeValue(spv::OpFOrd##_Op, result_stype, a, b);                        \
    }                                                                                \
    return MakeValue(spv::OpI##_Op, result_stype, a, b);                             \
  }
// Equality comparisons. Integers use OpIEqual/OpINotEqual regardless of
// signedness; floating-point operands use OpFOrdEqual/OpFOrdNotEqual.
DEFINE_BUILDER_CMP_UOP(EQ, Equal);
DEFINE_BUILDER_CMP_UOP(NE, NotEqual);
/*!
 * \brief Emit an OpSelect.
 *
 * \param cond Condition; its element type must be bool.
 * \param a Value chosen when the condition is true.
 * \param b Value chosen otherwise; must have the same type id as \p a.
 * \return The selected value, typed like \p a.
 */
Value IRBuilder::Select(Value cond, Value a, Value b) {
  TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id);
  TVM_FFI_ICHECK_EQ(cond.stype.type.element_of(), DataType::Bool());
  const SType& result_type = a.stype;
  return MakeValue(spv::OpSelect, result_type, cond, a, b);
}
} // namespace spirv
} // namespace codegen
} // namespace tvm