blob: 8eabfd5c3cbf1539e543e6bdc0bd2c4b56869f64 [file]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_RUNTIME_VULKAN_VULKAN_WRAPPED_FUNC_H_
#define TVM_RUNTIME_VULKAN_VULKAN_WRAPPED_FUNC_H_
#include <array>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>
#include "../../../runtime/metadata.h"
#include "../../../runtime/pack_args.h"
#include "../../../runtime/thread_storage_scope.h"
#include "spirv_shader.h"
#include "vulkan/vulkan_core.h"
#include "vulkan_common.h"
#include "vulkan_device.h"
namespace tvm {
namespace runtime {
namespace vulkan {
/*!
 * \brief The set of Vulkan objects that back one kernel's pipeline.
 *
 * Instances are handed out as std::shared_ptr by
 * VulkanModuleNode::GetPipeline(device_id, func_name, ...), so one
 * VulkanPipeline corresponds to one (device, kernel) pair.
 * Every handle defaults to VK_NULL_HANDLE and `device` defaults to nullptr.
 */
struct VulkanPipeline {
  /*! \brief Device these handles were created on (raw pointer; presumably non-owning). */
  VulkanDevice* device = nullptr;
  /*! \brief Shader module built from the kernel's SPIR-V. */
  VkShaderModule shader = VK_NULL_HANDLE;
  /*! \brief Layout of the kernel's descriptor set. */
  VkDescriptorSetLayout descriptor_set_layout = VK_NULL_HANDLE;
  /*! \brief Pool that descriptor_set is allocated from. */
  VkDescriptorPool descriptor_pool = VK_NULL_HANDLE;
  /*! \brief Descriptor set bound when dispatching the kernel. */
  VkDescriptorSet descriptor_set = VK_NULL_HANDLE;
  /*! \brief Pipeline layout combining the descriptor set layout (and any push constants). */
  VkPipelineLayout pipeline_layout = VK_NULL_HANDLE;
  /*! \brief The pipeline object itself. */
  VkPipeline pipeline = VK_NULL_HANDLE;
  /*! \brief Optional descriptor update template (VK_KHR_descriptor_update_template). */
  VkDescriptorUpdateTemplateKHR descriptor_update_template = VK_NULL_HANDLE;
  /*! \brief Presumably: true when scalar arguments go through a uniform buffer object. */
  bool use_ubo = false;
};
class VulkanModuleNode;
/*!
 * \brief Wraps one kernel of a VulkanModuleNode so it can be exposed as a
 *  packed function.
 *
 * Init() records which kernel to run and how its arguments are laid out;
 * operator() performs the invocation (defined out of line). Per-device
 * pipeline state is cached in `scache_`, which is `mutable` so it can be
 * filled lazily from the const call operator.
 */
class VulkanWrappedFunc {
 public:
  /*!
   * \brief Configure the wrapper for a single kernel.
   * \param m The module that provides the kernel.
   * \param sptr Resource holder stored alongside `m`; presumably keeps the
   *  module alive while this wrapper exists (confirm in the implementation).
   * \param func_name Name of the kernel.
   * \param num_buffer_args Number of buffer arguments.
   * \param num_pack_args Number of packed (non-buffer) arguments.
   * \param launch_param_tags Launch parameter tags; presumably used to build
   *  `launch_param_config_`.
   */
  void Init(VulkanModuleNode* m, ffi::ObjectPtr<ffi::Object> sptr, const std::string& func_name,
            size_t num_buffer_args, size_t num_pack_args,
            const ffi::Array<ffi::String>& launch_param_tags);

  /*!
   * \brief Invoke the wrapped kernel.
   * \param args The packed call arguments.
   * \param rv Slot for the return value.
   * \param pack_args Packed scalar arguments. NOTE(review): assumed to hold
   *  `num_pack_args` entries; confirm against the caller.
   */
  void operator()(ffi::PackedArgs args, ffi::Any* rv, const ArgUnion64* pack_args) const;

 private:
  // Module that provides the kernel (raw pointer, presumably non-owning).
  // Null until Init() is called instead of holding an indeterminate value.
  VulkanModuleNode* m_{nullptr};
  // The resource holder.
  ffi::ObjectPtr<ffi::Object> sptr_;
  // The name of the function.
  std::string func_name_;
  // Number of buffer arguments; zero until Init() is called.
  size_t num_buffer_args_{0};
  // Number of packed arguments; zero until Init() is called.
  size_t num_pack_args_{0};
  // Launch parameter configuration.
  LaunchParamConfig launch_param_config_;
  // Pipeline state cache, one slot per device.
  // Marked mutable so the const operator() can initialize it lazily.
  mutable std::array<std::shared_ptr<VulkanPipeline>, kVulkanMaxNumDevice> scache_;
};
class VulkanModuleNode final : public ffi::ModuleObj {
public:
explicit VulkanModuleNode(std::unordered_map<std::string, SPIRVShader> internal_smap,
ffi::Map<ffi::String, ffi::Bytes> smap, ffi::String fmt,
ffi::Map<ffi::String, FunctionInfo> fmap,
ffi::Map<ffi::String, ffi::String> source)
: internal_smap_(std::move(internal_smap)),
smap_(std::move(smap)),
fmt_(std::move(fmt)),
fmap_(std::move(fmap)),
source_(std::move(source)) {}
~VulkanModuleNode();
const char* kind() const final { return "vulkan"; }
/*! \brief Get the property of the runtime module. */
int GetPropertyMask() const final {
return ffi::Module::kBinarySerializable | ffi::Module::kRunnable;
}
ffi::Optional<ffi::Function> GetFunction(const ffi::String& name) final;
std::shared_ptr<VulkanPipeline> GetPipeline(size_t device_id, const std::string& func_name,
size_t num_pack_args);
ffi::Bytes SaveToBytes() const final;
ffi::String InspectSource(const ffi::String& format) const final;
private:
// Deserialized SPIRV shaders, used by GetPipeline at runtime.
std::unordered_map<std::string, SPIRVShader> internal_smap_;
// Per-kernel serialized SPIRVShader bytes, kept for byte-identical
// SaveToBytes vs target::VulkanFallbackModuleNode. Both forms carry
// the same shaders; internal_smap_ is the deserialized cache.
ffi::Map<ffi::String, ffi::Bytes> smap_;
// The format identifier — always "vulkan" today.
ffi::String fmt_;
// function information table.
ffi::Map<ffi::String, FunctionInfo> fmap_;
// In-memory source map for InspectSource — never serialized.
ffi::Map<ffi::String, ffi::String> source_;
// Guards accesses to `ecache_`
std::mutex mutex_;
std::array<std::unordered_map<std::string, std::shared_ptr<VulkanPipeline>>, kVulkanMaxNumDevice>
ecache_;
};
} // namespace vulkan
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_VULKAN_VULKAN_WRAPPED_FUNC_H_