| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file spirv_utils.cc |
| * \brief Build SPIRV block |
| */ |
| // Use libspirv for parsing and validating code. |
| #include "spirv_utils.h" |
| |
| #if TVM_ENABLE_SPIRV |
| #include <libspirv.h> |
| |
| #include "codegen_spirv.h" |
| #endif |
| |
| #include <tvm/tirx/transform.h> |
| |
| #include <algorithm> |
| #include <fstream> |
| #include <sstream> |
| #include <vector> |
| |
| #include "../../runtime/vulkan/spirv_shader.h" |
| #include "../../support/utils.h" |
| |
| namespace tvm { |
| namespace codegen { |
| |
| #if TVM_ENABLE_SPIRV |
| |
| class SPIRVTools { |
| public: |
| explicit SPIRVTools(Target target) { |
| uint32_t vulkan_version = |
| target->GetAttr<Integer>("vulkan_api_version").value_or(VK_API_VERSION_1_0).IntValue(); |
| uint32_t spirv_version = |
| target->GetAttr<Integer>("max_spirv_version").value_or(0x10000).IntValue(); |
| |
| spv_target_env validation_version; |
| if (target->kind->name == "opencl") { |
| validation_version = SPV_ENV_OPENCL_2_2; |
| } else { |
| if (vulkan_version >= VK_API_VERSION_1_2) { |
| validation_version = SPV_ENV_VULKAN_1_2; |
| } else if (vulkan_version >= VK_API_VERSION_1_1 && spirv_version >= 0x10400) { |
| validation_version = SPV_ENV_VULKAN_1_1_SPIRV_1_4; |
| } else if (vulkan_version >= VK_API_VERSION_1_1) { |
| validation_version = SPV_ENV_VULKAN_1_1; |
| } else { |
| validation_version = SPV_ENV_VULKAN_1_0; |
| } |
| } |
| ctx_ = spvContextCreate(validation_version); |
| } |
| |
| ~SPIRVTools() { spvContextDestroy(ctx_); } |
| |
| std::string BinaryToText(const std::vector<uint32_t>& bin) { |
| spv_text text = nullptr; |
| spv_diagnostic diagnostic = nullptr; |
| spv_const_binary_t spv_bin{bin.data(), bin.size()}; |
| |
| spv_result_t res = |
| spvBinaryToText(ctx_, spv_bin.code, spv_bin.wordCount, |
| SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES | SPV_BINARY_TO_TEXT_OPTION_INDENT, |
| &text, &diagnostic); |
| |
| TVM_FFI_ICHECK_EQ(res, SPV_SUCCESS) |
| << " line=" << diagnostic->position.line << " column=" << diagnostic->position.column |
| << " index=" << diagnostic->position.index << " error:" << diagnostic->error; |
| spvDiagnosticDestroy(diagnostic); |
| |
| std::string ret(text->str); |
| spvTextDestroy(text); |
| return ret; |
| } |
| |
| void ValidateShader(const std::vector<uint32_t>& bin) { |
| spv_const_binary_t spv_bin{bin.data(), bin.size()}; |
| |
| spv_diagnostic diagnostic = nullptr; |
| spv_result_t res = spvValidate(ctx_, &spv_bin, &diagnostic); |
| |
| TVM_FFI_ICHECK_EQ(res, SPV_SUCCESS) |
| << " index=" << diagnostic->position.index << " error:" << diagnostic->error; |
| |
| spvDiagnosticDestroy(diagnostic); |
| } |
| |
| private: |
| spv_context ctx_; |
| }; |
| |
| std::pair<std::unordered_map<std::string, runtime::SPIRVShader>, std::string> LowerToSPIRV( |
| IRModule mod, Target target) { |
| using tvm::runtime::SPIRVShader; |
| |
| std::ostringstream code_data; |
| SPIRVTools spirv_tools(target); |
| std::unordered_map<std::string, SPIRVShader> smap; |
| |
| auto postproc = tvm::ffi::Function::GetGlobal("tvm_callback_vulkan_postproc"); |
| |
| mod = tirx::transform::PointerValueTypeRewrite()(std::move(mod)); |
| |
| CodeGenSPIRV cg(target); |
| |
| for (auto kv : mod->functions) { |
| TVM_FFI_ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenSPIRV: Can only take PrimFunc"; |
| auto f = Downcast<PrimFunc>(kv.second); |
| auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv); |
| TVM_FFI_ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) |
| << "CodeGenSPIRV: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; |
| auto global_symbol = f->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol); |
| TVM_FFI_ICHECK(global_symbol.has_value()) |
| << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute"; |
| |
| std::string f_name = global_symbol.value(); |
| std::string entry = f_name; |
| |
| SPIRVShader shader = cg.BuildFunction(f, entry); |
| |
| if (auto path = std::getenv("TVM_VULKAN_DEBUG_SHADER_SAVEPATH")) { |
| if (*path) { |
| std::stringstream ss; |
| ss << path << "/" << f_name << "_"; |
| std::string prefix = ss.str(); |
| |
| std::ofstream(prefix + "tirx.txt") << f; |
| std::ofstream(prefix + "spv.txt") << spirv_tools.BinaryToText(shader.data); |
| std::ofstream(prefix + "spv.spv", std::ios::binary) |
| .write(reinterpret_cast<const char*>(shader.data.data()), |
| sizeof(shader.data[0]) * shader.data.size()); |
| } |
| } |
| |
| if (!support::BoolEnvironmentVar("TVM_VULKAN_DISABLE_SHADER_VALIDATION")) { |
| spirv_tools.ValidateShader(shader.data); |
| } |
| |
| if (postproc) { |
| TVMFFIByteArray arr; |
| arr.data = reinterpret_cast<const char*>(shader.data.data()); |
| arr.size = shader.data.size() * sizeof(uint32_t); |
| std::string transformed = (*postproc)(&arr, target).cast<std::string>(); |
| TVM_FFI_ICHECK_EQ(transformed.length() % 4U, 0U); |
| shader.data.resize(transformed.size() / 4U); |
| std::copy(transformed.begin(), transformed.end(), |
| reinterpret_cast<char*>(shader.data.data())); |
| } |
| code_data << spirv_tools.BinaryToText(shader.data); |
| smap[f_name] = std::move(shader); |
| } |
| |
| return std::make_pair(smap, code_data.str()); |
| } |
| |
| #else |
| |
| std::pair<std::unordered_map<std::string, runtime::SPIRVShader>, std::string> LowerToSPIRV( |
| IRModule mod, Target target) { |
| TVM_FFI_THROW(InternalError) |
| << "LowerToSPIRV is called but SPIRV codegen is not enabled. Please set -DUSE_VULKAN=ON."; |
| return {}; |
| } |
| |
| #endif |
| |
| } // namespace codegen |
| } // namespace tvm |