/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "vulkan_device.h"
#include <algorithm>
#include <mutex>
#include <unordered_map>
#include <utility>
#include "../../support/utils.h"
#include "vulkan_common.h"
#include "vulkan_device.h"
#include "vulkan_device_api.h"
#include "vulkan_instance.h"
namespace tvm {
namespace runtime {
namespace vulkan {
VulkanDeviceProperties::VulkanDeviceProperties(const VulkanInstance& instance,
const VulkanDevice& device) {
///////////////////////////////////////////////////////////////
// Query properties from Vulkan API //
///////////////////////////////////////////////////////////////
// Declare output locations for properties
VkPhysicalDeviceProperties2 properties = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2};
VkPhysicalDeviceDriverProperties driver = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_DRIVER_PROPERTIES};
VkPhysicalDeviceSubgroupProperties subgroup = {
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES};
  // An initial query is needed to check apiVersion, which determines
  // whether the subgroup struct may be chained in below.
vkGetPhysicalDeviceProperties(device, &properties.properties);
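  // Vulkan queries are extended through each structure's pNext
  // pointer: the caller chains extension structs together and the
  // driver fills in every struct in the chain whose sType it
  // recognizes.  Here, pp_next always holds the address of the current
  // tail's pNext field, so each new struct is appended in O(1).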
// Set up linked list for property query
{
void** pp_next = &properties.pNext;
if (device.HasExtension("VK_KHR_driver_properties")) {
*pp_next = &driver;
pp_next = &driver.pNext;
}
if (properties.properties.apiVersion >= VK_API_VERSION_1_1) {
*pp_next = &subgroup;
pp_next = &subgroup.pNext;
}
}
// Declare output locations for features
VkPhysicalDeviceFeatures2 features = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2};
VkPhysicalDevice8BitStorageFeatures storage_8bit = {
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES};
VkPhysicalDevice16BitStorageFeatures storage_16bit = {
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES};
VkPhysicalDeviceShaderFloat16Int8Features float16_int8 = {
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES};
// Set up linked list for feature query
{
void** pp_next = &features.pNext;
if (device.HasExtension("VK_KHR_8bit_storage")) {
*pp_next = &storage_8bit;
pp_next = &storage_8bit.pNext;
}
if (device.HasExtension("VK_KHR_16bit_storage")) {
*pp_next = &storage_16bit;
pp_next = &storage_16bit.pNext;
}
if (device.HasExtension("VK_KHR_shader_float16_int8")) {
*pp_next = &float16_int8;
pp_next = &float16_int8.pNext;
}
}
if (instance.HasExtension("VK_KHR_get_physical_device_properties2")) {
    // Preferred method: a single call retrieves all properties that can be queried.
auto vkGetPhysicalDeviceProperties2KHR = (PFN_vkGetPhysicalDeviceProperties2KHR)ICHECK_NOTNULL(
vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceProperties2KHR"));
vkGetPhysicalDeviceProperties2KHR(device, &properties);
auto vkGetPhysicalDeviceFeatures2KHR = (PFN_vkGetPhysicalDeviceFeatures2KHR)ICHECK_NOTNULL(
vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceFeatures2KHR"));
vkGetPhysicalDeviceFeatures2KHR(device, &features);
  } else {
    // Fallback: get as many features as we can from the Vulkan 1.0
    // API.  The corresponding vkGetPhysicalDeviceProperties call was
    // already made earlier.
    vkGetPhysicalDeviceFeatures(device, &features.features);
  }
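  // Any extension struct that was never chained in above remains
  // zero-initialized, so its feature booleans read as VK_FALSE and the
  // assignments below correctly report those features as unsupported.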
///////////////////////////////////////////////////////////////
// Fill member variables from Vulkan structures //
///////////////////////////////////////////////////////////////
supports_float16 = float16_int8.shaderFloat16;
supports_float32 = true;
supports_float64 = features.features.shaderFloat64;
supports_int8 = float16_int8.shaderInt8;
supports_int16 = features.features.shaderInt16;
supports_int32 = true;
supports_int64 = features.features.shaderInt64;
supports_8bit_buffer = storage_8bit.storageBuffer8BitAccess;
supports_16bit_buffer = storage_16bit.storageBuffer16BitAccess;
supports_storage_buffer_storage_class =
device.HasExtension("VK_KHR_storage_buffer_storage_class");
// Support is available based on these extensions, but allow it to
// be disabled based on an environment variable.
supports_push_descriptor = device.HasExtension("VK_KHR_push_descriptor") &&
device.HasExtension("VK_KHR_descriptor_update_template") &&
!support::BoolEnvironmentVar("TVM_VULKAN_DISABLE_PUSH_DESCRIPTOR");
// Support is available based on these extensions, but allow it to
// be disabled based on an environment variable.
supports_dedicated_allocation =
device.HasExtension("VK_KHR_get_memory_requirements2") &&
device.HasExtension("VK_KHR_dedicated_allocation") &&
!support::BoolEnvironmentVar("TVM_VULKAN_DISABLE_DEDICATED_ALLOCATION");
supports_integer_dot_product = device.HasExtension("VK_KHR_shader_integer_dot_product");
supports_cooperative_matrix = device.HasExtension("VK_NV_cooperative_matrix");
  // The check of VK_SHADER_STAGE_COMPUTE_BIT isn't technically
  // needed, since it will be set so long as at least one queue has
  // VK_QUEUE_COMPUTE_BIT.  It is included to avoid potential future
  // confusion.
supported_subgroup_operations =
(subgroup.supportedStages & VK_SHADER_STAGE_COMPUTE_BIT) ? subgroup.supportedOperations : 0;
max_num_threads = properties.properties.limits.maxComputeWorkGroupInvocations;
  // Even if we can't query it, warp size must be at least 1.
  //   thread_warp_size = std::max(subgroup.subgroupSize, 1U);
  // However, Vulkan's subgroup may not map directly to a warp and at
  // the moment can cause issues in the softmax allreduce on NVIDIA
  // GPUs, so the warp setting is disabled to be safe.
  thread_warp_size = 1U;
max_block_size_x = properties.properties.limits.maxComputeWorkGroupSize[0];
max_block_size_y = properties.properties.limits.maxComputeWorkGroupSize[1];
max_block_size_z = properties.properties.limits.maxComputeWorkGroupSize[2];
max_push_constants_size = properties.properties.limits.maxPushConstantsSize;
max_uniform_buffer_range = properties.properties.limits.maxUniformBufferRange;
max_storage_buffer_range = properties.properties.limits.maxStorageBufferRange;
max_per_stage_descriptor_storage_buffer =
properties.properties.limits.maxPerStageDescriptorStorageBuffers;
max_shared_memory_per_block = properties.properties.limits.maxComputeSharedMemorySize;
device_name = properties.properties.deviceName;
driver_version = properties.properties.driverVersion;
if (device.HasExtension("VK_KHR_driver_properties")) {
driver_name = driver.driverName;
}
switch (properties.properties.deviceType) {
case VK_PHYSICAL_DEVICE_TYPE_OTHER:
device_type = "other";
break;
case VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU:
device_type = "integrated";
break;
case VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU:
device_type = "discrete";
break;
case VK_PHYSICAL_DEVICE_TYPE_VIRTUAL_GPU:
device_type = "virtual";
break;
case VK_PHYSICAL_DEVICE_TYPE_CPU:
device_type = "cpu";
break;
default:
LOG(FATAL) << "Unknown vulkan device type: " << properties.properties.deviceType;
break;
}
// By default, use the maximum API version that the driver allows,
// so that any supported features can be used by TVM shaders.
// However, if we can query the conformance version, then limit to
// only using the api version that passes the vulkan conformance
// tests.
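  // For example, a driver that reports apiVersion 1.2 but
  // conformanceVersion 1.1.x is clamped down to VK_MAKE_VERSION(1, 1, 0).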
vulkan_api_version = properties.properties.apiVersion;
if (device.HasExtension("VK_KHR_driver_properties")) {
auto api_major = VK_VERSION_MAJOR(vulkan_api_version);
auto api_minor = VK_VERSION_MINOR(vulkan_api_version);
if ((api_major > driver.conformanceVersion.major) ||
((api_major == driver.conformanceVersion.major) &&
(api_minor > driver.conformanceVersion.minor))) {
vulkan_api_version =
VK_MAKE_VERSION(driver.conformanceVersion.major, driver.conformanceVersion.minor, 0);
}
}
// From "Versions and Formats" section of Vulkan spec.
max_spirv_version = 0x10000;
if (vulkan_api_version >= VK_API_VERSION_1_2) {
max_spirv_version = 0x10500;
} else if (device.HasExtension("VK_KHR_spirv_1_4")) {
max_spirv_version = 0x10400;
} else if (vulkan_api_version >= VK_API_VERSION_1_1) {
max_spirv_version = 0x10300;
}
}
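// Extension entry points are not exported directly by the Vulkan
// loader; they must be fetched at runtime through vkGetDeviceProcAddr
// (device-level) or vkGetInstanceProcAddr (instance-level).  The
// ICHECK_NOTNULL guards against drivers that advertise an extension
// but fail to provide its functions.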
VulkanDescriptorTemplateKHRFunctions::VulkanDescriptorTemplateKHRFunctions(VkDevice device) {
vkCreateDescriptorUpdateTemplateKHR = (PFN_vkCreateDescriptorUpdateTemplateKHR)ICHECK_NOTNULL(
vkGetDeviceProcAddr(device, "vkCreateDescriptorUpdateTemplateKHR"));
vkDestroyDescriptorUpdateTemplateKHR = (PFN_vkDestroyDescriptorUpdateTemplateKHR)ICHECK_NOTNULL(
vkGetDeviceProcAddr(device, "vkDestroyDescriptorUpdateTemplateKHR"));
vkUpdateDescriptorSetWithTemplateKHR = (PFN_vkUpdateDescriptorSetWithTemplateKHR)ICHECK_NOTNULL(
vkGetDeviceProcAddr(device, "vkUpdateDescriptorSetWithTemplateKHR"));
vkCmdPushDescriptorSetWithTemplateKHR = (PFN_vkCmdPushDescriptorSetWithTemplateKHR)ICHECK_NOTNULL(
vkGetDeviceProcAddr(device, "vkCmdPushDescriptorSetWithTemplateKHR"));
}
VulkanGetBufferMemoryRequirements2Functions::VulkanGetBufferMemoryRequirements2Functions(
VkDevice device) {
vkGetBufferMemoryRequirements2KHR = (PFN_vkGetBufferMemoryRequirements2KHR)ICHECK_NOTNULL(
vkGetDeviceProcAddr(device, "vkGetBufferMemoryRequirements2KHR"));
}
VulkanQueueInsertDebugUtilsLabelFunctions::VulkanQueueInsertDebugUtilsLabelFunctions(
VkInstance instance) {
vkQueueInsertDebugUtilsLabelEXT = (PFN_vkQueueInsertDebugUtilsLabelEXT)ICHECK_NOTNULL(
vkGetInstanceProcAddr(instance, "vkQueueInsertDebugUtilsLabelEXT"));
}
VulkanDevice::VulkanDevice(const VulkanInstance& instance, VkPhysicalDevice phy_device)
: physical_device_(phy_device) {
queue_family_index = SelectComputeQueueFamily();
if (queue_family_index == uint32_t(-1)) {
// The GPU doesn't support compute, cannot use
return;
}
enabled_extensions = SelectEnabledExtensions();
device_properties = VulkanDeviceProperties(instance, *this);
CreateVkDevice(instance);
  // Currently, any exception thrown after this point will prevent
  // vkDestroyDevice from being called in the destructor.  If this
  // becomes an issue, VulkanDevice could be split into two classes:
  // one that strictly holds the VkDevice, and one that holds the
  // ancillary handles that TVM needs.
vkGetDeviceQueue(device_, queue_family_index, 0, &queue);
  // Find suitable memory types for the staging and compute buffers.
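  // A VkBuffer must exist before its memoryTypeBits can be queried,
  // but the Vulkan spec guarantees those bits depend only on the
  // buffer's usage and flags, not its size, so a small throwaway
  // buffer is enough to learn which memory types can back the real
  // allocations made later.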
VkBuffer buffer;
VkMemoryRequirements req_staging, req_compute;
VkBufferCreateInfo info;
info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
info.pNext = nullptr;
info.flags = 0;
info.size = 1024;
info.queueFamilyIndexCount = 1;
info.pQueueFamilyIndices = &queue_family_index;
info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
// get staging requirement
info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT;
VULKAN_CALL(vkCreateBuffer(device_, &info, nullptr, &buffer));
vkGetBufferMemoryRequirements(device_, buffer, &req_staging);
vkDestroyBuffer(device_, buffer, nullptr);
// get compute requirement
info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
VULKAN_CALL(vkCreateBuffer(device_, &info, nullptr, &buffer));
vkGetBufferMemoryRequirements(device_, buffer, &req_compute);
vkDestroyBuffer(device_, buffer, nullptr);
  // Query the physical device memory properties and find a staging
  // memory type that is host visible; it does not need to be coherent.
int win_rank = -1;
VkPhysicalDeviceMemoryProperties prop;
vkGetPhysicalDeviceMemoryProperties(physical_device_, &prop);
for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
VkMemoryType ty = prop.memoryTypes[k];
int64_t heap_size = static_cast<int64_t>(prop.memoryHeaps[ty.heapIndex].size);
// host visible
if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT)) continue;
    // match staging copy requirement
if (!(req_staging.memoryTypeBits & (1 << k))) continue;
if (heap_size < 1024) continue;
int rank = 0;
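    // Prefer host-cached memory for faster CPU-side reads; the masked
    // bit value (nonzero vs. zero) serves directly as the rank, since
    // this is the only ranking criterion.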
rank += ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_CACHED_BIT;
if (rank > win_rank) {
win_rank = rank;
staging_mtype_index = k;
coherent_staging = ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT;
}
}
ICHECK_GE(win_rank, 0) << "Cannot find suitable staging memory on device.";
win_rank = -1;
for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
VkMemoryType ty = prop.memoryTypes[k];
int64_t heap_size = static_cast<int64_t>(prop.memoryHeaps[ty.heapIndex].size);
    // device local
    if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT)) continue;
    // match compute buffer requirement
    if (!(req_compute.memoryTypeBits & (1 << k))) continue;
if (heap_size < 1024) continue;
int rank = 0;
// prefer not host visible
rank += !(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT);
if (rank > win_rank) {
win_rank = rank;
compute_mtype_index = k;
compute_memory_size = heap_size;
}
}
ICHECK_GE(win_rank, 0) << "Cannot find suitable local memory on device.";
if (device_properties.supports_push_descriptor) {
descriptor_template_khr_functions =
std::make_unique<VulkanDescriptorTemplateKHRFunctions>(device_);
}
if (device_properties.supports_dedicated_allocation) {
get_buffer_memory_requirements_2_functions =
std::make_unique<VulkanGetBufferMemoryRequirements2Functions>(device_);
}
if (instance.HasExtension("VK_EXT_debug_utils")) {
queue_insert_debug_utils_label_functions =
std::make_unique<VulkanQueueInsertDebugUtilsLabelFunctions>(instance);
}
}
VulkanDevice::~VulkanDevice() {
  // Need to clear out anything that uses this device before calling
  // vkDestroyDevice.  This might be a sign that the VkDevice should be
  // held by a member variable rather than being owned directly by
  // VulkanDevice.
stream_per_thread.Clear();
staging_buffer_per_thread.Clear();
uniform_buffer_per_thread.Clear();
if (device_) {
vkDestroyDevice(device_, nullptr);
}
}
VulkanDevice::VulkanDevice(VulkanDevice&& other) { do_swap(std::move(other)); }
VulkanDevice& VulkanDevice::operator=(VulkanDevice&& other) {
do_swap(std::move(other));
return *this;
}
void VulkanDevice::do_swap(VulkanDevice&& other) {
if (this == &other) {
return;
}
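  // Acquire both queue mutexes atomically (std::lock avoids the
  // deadlock possible if two threads swapped in opposite order), then
  // adopt them into guards so they are released when this scope exits.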
std::lock(queue_mutex, other.queue_mutex);
std::lock_guard<std::mutex> lock_self(queue_mutex, std::adopt_lock);
std::lock_guard<std::mutex> lock_other(other.queue_mutex, std::adopt_lock);
std::swap(device_properties, other.device_properties);
std::swap(staging_mtype_index, other.staging_mtype_index);
std::swap(coherent_staging, other.coherent_staging);
std::swap(descriptor_template_khr_functions, other.descriptor_template_khr_functions);
std::swap(get_buffer_memory_requirements_2_functions,
other.get_buffer_memory_requirements_2_functions);
std::swap(queue_insert_debug_utils_label_functions,
other.queue_insert_debug_utils_label_functions);
std::swap(compute_mtype_index, other.compute_mtype_index);
std::swap(compute_memory_size, other.compute_memory_size);
std::swap(queue, other.queue);
std::swap(queue_family_index, other.queue_family_index);
std::swap(physical_device_, other.physical_device_);
std::swap(enabled_extensions, other.enabled_extensions);
std::swap(device_, other.device_);
}
bool VulkanDevice::SupportsCompute() const { return queue_family_index != uint32_t(-1); }
void VulkanDevice::QueueSubmit(VkSubmitInfo submit_info, VkFence fence) const {
// Multiple streams (on different threads) use the same VulkanDevice
// instance, so we need to externally synchronize accesses.
std::lock_guard<std::mutex> lock(queue_mutex);
VULKAN_CALL(vkQueueSubmit(queue, 1, &submit_info, fence));
}
uint32_t VulkanDevice::SelectComputeQueueFamily() const {
// Get a queue family that supports compute. We currently only use
// one queue from one family.
uint32_t queue_prop_count = 0;
vkGetPhysicalDeviceQueueFamilyProperties(physical_device_, &queue_prop_count, nullptr);
std::vector<VkQueueFamilyProperties> queue_props(queue_prop_count);
vkGetPhysicalDeviceQueueFamilyProperties(physical_device_, &queue_prop_count,
dmlc::BeginPtr(queue_props));
  // Prefer compute-only queues.  On certain devices that support them
  // (e.g. Mesa RADV), using compute-only queues gives better
  // responsiveness for other graphics workloads (e.g. the desktop).
for (uint32_t i = 0; i != queue_prop_count; ++i) {
if ((VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) != 0 &&
(VK_QUEUE_GRAPHICS_BIT & queue_props[i].queueFlags) == 0) {
return i;
}
}
  // Otherwise, fall back to a queue family that supports both compute
  // and graphics.
for (uint32_t i = 0; i != queue_prop_count; ++i) {
if ((VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) != 0 &&
(VK_QUEUE_GRAPHICS_BIT & queue_props[i].queueFlags) != 0) {
return i;
}
}
  // No queue family supports compute, so this GPU cannot be used.
return -1;
}
std::vector<const char*> VulkanDevice::SelectEnabledExtensions() const {
std::vector<const char*> required_extensions{};
std::vector<const char*> optional_extensions{"VK_KHR_driver_properties",
"VK_KHR_storage_buffer_storage_class",
"VK_KHR_8bit_storage",
"VK_KHR_16bit_storage",
"VK_KHR_shader_float16_int8",
"VK_KHR_push_descriptor",
"VK_KHR_descriptor_update_template",
"VK_KHR_get_memory_requirements2",
"VK_KHR_dedicated_allocation",
"VK_KHR_spirv_1_4",
"VK_KHR_shader_integer_dot_product",
"VK_NV_cooperative_matrix"};
uint32_t device_extension_prop_count;
VULKAN_CALL(vkEnumerateDeviceExtensionProperties(physical_device_, nullptr,
&device_extension_prop_count, nullptr));
std::vector<VkExtensionProperties> device_extension_prop(device_extension_prop_count);
VULKAN_CALL(vkEnumerateDeviceExtensionProperties(
physical_device_, nullptr, &device_extension_prop_count, device_extension_prop.data()));
return FindEnabledExtensions(device_extension_prop, required_extensions, optional_extensions);
}
bool VulkanDevice::HasExtension(const char* query) const {
return std::any_of(enabled_extensions.begin(), enabled_extensions.end(),
[&](const char* extension) { return std::strcmp(query, extension) == 0; });
}
void VulkanDevice::CreateVkDevice(const VulkanInstance& instance) {
// Enable all features we may use that a device supports.
VkPhysicalDeviceFeatures2 enabled_features = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2};
VkPhysicalDevice8BitStorageFeatures storage_8bit = {
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES};
VkPhysicalDevice16BitStorageFeatures storage_16bit = {
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES};
VkPhysicalDeviceShaderFloat16Int8Features float16_int8 = {
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES};
void** pp_next = &enabled_features.pNext;
bool needs_float16_int8 = false;
if (device_properties.supports_float16) {
float16_int8.shaderFloat16 = true;
needs_float16_int8 = true;
}
if (device_properties.supports_float64) {
enabled_features.features.shaderFloat64 = true;
}
if (device_properties.supports_int8) {
float16_int8.shaderInt8 = true;
needs_float16_int8 = true;
}
if (device_properties.supports_int16) {
enabled_features.features.shaderInt16 = true;
}
if (device_properties.supports_int64) {
enabled_features.features.shaderInt64 = true;
}
if (device_properties.supports_8bit_buffer) {
storage_8bit.storageBuffer8BitAccess = true;
*pp_next = &storage_8bit;
pp_next = &storage_8bit.pNext;
}
if (device_properties.supports_16bit_buffer) {
storage_16bit.storageBuffer16BitAccess = true;
*pp_next = &storage_16bit;
pp_next = &storage_16bit.pNext;
}
if (needs_float16_int8) {
*pp_next = &float16_int8;
pp_next = &float16_int8.pNext;
}
float priority = 1.0f;
struct VkDeviceQueueCreateInfo queue_create_info;
queue_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
queue_create_info.pNext = nullptr;
queue_create_info.flags = 0;
queue_create_info.queueFamilyIndex = queue_family_index;
queue_create_info.queueCount = 1;
queue_create_info.pQueuePriorities = &priority;
VkDeviceCreateInfo device_create_info;
device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
device_create_info.pNext = nullptr;
device_create_info.flags = 0;
device_create_info.queueCreateInfoCount = 1;
device_create_info.pQueueCreateInfos = &queue_create_info;
device_create_info.enabledLayerCount = 0;
device_create_info.ppEnabledLayerNames = nullptr;
device_create_info.enabledExtensionCount = enabled_extensions.size();
device_create_info.ppEnabledExtensionNames = enabled_extensions.data();
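  // Per the Vulkan spec, features may be enabled either through the
  // legacy pEnabledFeatures pointer or through a
  // VkPhysicalDeviceFeatures2 in the pNext chain, but not both.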
if (instance.HasExtension("VK_KHR_get_physical_device_properties2")) {
device_create_info.pEnabledFeatures = nullptr;
device_create_info.pNext = &enabled_features;
} else {
device_create_info.pNext = nullptr;
device_create_info.pEnabledFeatures = &enabled_features.features;
}
VULKAN_CALL(vkCreateDevice(physical_device_, &device_create_info, nullptr, &device_));
}
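// The non-const overload delegates to the const implementation and
// casts away the constness of the result, the usual pattern for
// avoiding a duplicated accessor body.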
VulkanStream& VulkanDevice::ThreadLocalStream() {
return const_cast<VulkanStream&>(const_cast<const VulkanDevice*>(this)->ThreadLocalStream());
}
const VulkanStream& VulkanDevice::ThreadLocalStream() const {
return stream_per_thread.GetOrMake(this);
}
VulkanStagingBuffer& VulkanDevice::ThreadLocalStagingBuffer(size_t min_size) {
auto usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT;
VulkanStagingBuffer& result =
staging_buffer_per_thread.GetOrMake(*this, min_size, usage, staging_mtype_index);
if (result.size < min_size) {
result = VulkanStagingBuffer(*this, min_size, usage, staging_mtype_index);
}
return result;
}
void VulkanDevice::AllocateThreadLocalUniformBuffer(size_t min_size) {
auto usage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT;
auto buffer_info = MakeBufferCreateInfo(min_size, usage);
auto prop = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT;
auto mem_type_index = FindMemoryType(*this, buffer_info, prop);
VulkanUniformBuffer& result =
uniform_buffer_per_thread.GetOrMake(*this, min_size, usage, mem_type_index);
if (result.size < min_size) {
result = VulkanUniformBuffer(*this, min_size, usage, mem_type_index);
}
}
VulkanStagingBuffer& VulkanDevice::ThreadLocalUniformBuffer(size_t min_size) {
VulkanStagingBuffer* buffer = uniform_buffer_per_thread.Get();
ICHECK(buffer) << "Vulkan uniform buffer requested, but not previously allocated.";
ICHECK_GE(buffer->size, min_size)
<< "Vulkan uniform buffer of size " << min_size << " requested, but only " << buffer->size
<< " was previously allocated.";
return *buffer;
}
uint32_t FindMemoryType(const VulkanDevice& device, VkBufferCreateInfo info,
VkMemoryPropertyFlags req_prop) {
VkBuffer buffer;
VULKAN_CALL(vkCreateBuffer(device, &info, nullptr, &buffer));
VkMemoryRequirements mem_reqs;
vkGetBufferMemoryRequirements(device, buffer, &mem_reqs);
uint32_t type_bits = mem_reqs.memoryTypeBits;
VkPhysicalDeviceMemoryProperties phy_mem_prop;
vkGetPhysicalDeviceMemoryProperties(device, &phy_mem_prop);
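  // Bit i of memoryTypeBits is set iff memory type i can back this
  // buffer; scan the bits and return the first permitted type that
  // also has every requested property flag.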
for (uint32_t i = 0; i < phy_mem_prop.memoryTypeCount; i++) {
if ((type_bits & 1) == 1 &&
(phy_mem_prop.memoryTypes[i].propertyFlags & req_prop) == req_prop) {
return i;
}
type_bits >>= 1;
}
LOG(FATAL) << "Requested memory type not found";
}
VulkanHostVisibleBuffer* GetOrAllocate(
int device_id, size_t size, VkBufferUsageFlags usage, uint32_t mem_type_index,
std::unordered_map<size_t, std::unique_ptr<VulkanHostVisibleBuffer>>* buffers_ptr,
bool sync_before_realloc) {
auto& device = VulkanDeviceAPI::Global()->device(device_id);
auto& buffers = *buffers_ptr;
bool needs_alloc = !buffers[device_id] || (buffers[device_id]->size < size);
bool is_realloc = buffers[device_id] && (buffers[device_id]->size < size);
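  // Reallocation replaces a buffer that in-flight GPU work may still
  // reference, so callers can request a stream synchronization before
  // the old buffer is destroyed.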
if (is_realloc && sync_before_realloc) {
device.ThreadLocalStream().Synchronize();
}
if (needs_alloc) {
auto new_buffer =
std::make_unique<VulkanHostVisibleBuffer>(device, size, usage, mem_type_index);
buffers[device_id] = std::move(new_buffer);
}
return buffers[device_id].get();
}
} // namespace vulkan
} // namespace runtime
} // namespace tvm