blob: 5629d4f7f216c87ff44e050db1b52150aa90991f [file]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file spirv_support.cc
*
* \brief Utility for determining which SPIR-V capabilities a TVM
* target supports.
*/
#include "spirv_support.h"
#include <spirv.hpp>
namespace tvm {
namespace codegen {

/*!
 * \brief Populate the SPIR-V capability description from a TVM target.
 *
 * Each recognised target attribute is read only when it is present on
 * \p target. When an attribute is absent, the corresponding member keeps the
 * value it already held when this constructor body started (its default is
 * declared elsewhere, presumably in spirv_support.h).
 *
 * \param target The target to query. Its device type must be Vulkan, OpenCL,
 *        or WebGPU; any other device type fails the TVM_FFI_ICHECK below.
 */
SPIRVSupport::SPIRVSupport(tvm::Target target) {
  auto device_type = target->GetTargetDeviceType();
  TVM_FFI_ICHECK(device_type == kDLVulkan || device_type == kDLOpenCL || device_type == kDLWebGPU)
      << "Unsupported device type for SPIRV codegen: " << device_type;

  // Copy the integer attribute `key` into `*dst` if the target defines it.
  // The key is looked up once and spelled once, so the presence check and
  // the read can never refer to different attributes.
  auto read_int_attr = [&target](const char* key, auto* dst) {
    if (auto attr = target->GetAttr<Integer>(key)) {
      *dst = attr.value().IntValue();
    }
  };
  // Copy the boolean attribute `key` into `*dst` if the target defines it.
  auto read_bool_attr = [&target](const char* key, auto* dst) {
    if (auto attr = target->GetAttr<Bool>(key)) {
      *dst = attr.value();
    }
  };

  read_int_attr("vulkan_api_version", &vulkan_api_version);
  read_int_attr("supported_subgroup_operations", &supported_subgroup_operations);
  read_int_attr("max_push_constants_size", &max_push_constants_size);
  read_int_attr("max_uniform_buffer_range", &max_uniform_buffer_range);
  read_int_attr("max_storage_buffer_range", &max_storage_buffer_range);
  read_int_attr("max_shared_memory_per_block", &max_shared_memory_per_block);
  // NOTE: the attribute key is singular while the member name is plural.
  read_int_attr("max_per_stage_descriptor_storage_buffer",
                &max_per_stage_descriptor_storage_buffers);

  read_bool_attr("supports_storage_buffer_storage_class", &supports_storage_buffer_storage_class);
  read_bool_attr("supports_8bit_buffer", &supports_storage_buffer_8bit_access);
  read_bool_attr("supports_16bit_buffer", &supports_storage_buffer_16bit_access);
  read_bool_attr("supports_float16", &supports_float16);
  read_bool_attr("supports_float64", &supports_float64);
  read_bool_attr("supports_int8", &supports_int8);
  read_bool_attr("supports_int16", &supports_int16);
  read_bool_attr("supports_int64", &supports_int64);

  // Integer dot product can be enabled explicitly through the target attribute...
  read_bool_attr("supports_integer_dot_product", &supports_integer_dot_product);
  // ...or by listing "+dotprod" in mattr. This check runs after the attribute
  // read, so "+dotprod" turns the feature on even if the attribute says false.
  if (auto mattr = target->GetAttr<ffi::Array<ffi::String>>("mattr")) {
    for (const ffi::String& feature : mattr.value()) {
      if (feature.compare("+dotprod") == 0) {
        supports_integer_dot_product = true;
        break;
      }
    }
  }

  // Cooperative matrix support is taken from the target attribute only.
  read_bool_attr("supports_cooperative_matrix", &supports_cooperative_matrix);
}

}  // namespace codegen
}  // namespace tvm