blob: 487ad23e16b97dc53258635389a1f131e73db028 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file rocm_module.cc
*/
#include "rocm_module.h"

#include <hip/hip_runtime_api.h>
#include <tvm/runtime/registry.h>

#include <algorithm>
#include <array>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "../file_utils.h"
#include "../meta_data.h"
#include "../pack_args.h"
#include "../thread_storage_scope.h"
#include "rocm_common.h"
namespace tvm {
namespace runtime {
// Module to support thread-safe multi-GPU execution.
// hipModule_t is a per-GPU module
// The runtime will contain a per-device module table
// The modules will be lazily loaded
class ROCMModuleNode : public runtime::ModuleNode {
public:
explicit ROCMModuleNode(std::string data, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap,
std::string hip_source, std::string assembly)
: data_(data), fmt_(fmt), fmap_(fmap), hip_source_(hip_source), assembly_(assembly) {
std::fill(module_.begin(), module_.end(), nullptr);
}
// destructor
~ROCMModuleNode() {
for (size_t i = 0; i < module_.size(); ++i) {
if (module_[i] != nullptr) {
ROCM_CALL(hipSetDevice(static_cast<int>(i)));
ROCM_DRIVER_CALL(hipModuleUnload(module_[i]));
}
}
}
const char* type_key() const final { return "hip"; }
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;
void SaveToFile(const std::string& file_name, const std::string& format) final {
std::string fmt = GetFileFormat(file_name, format);
std::string meta_file = GetMetaFilePath(file_name);
// note: llvm and asm formats are not laodable, so we don't save them
ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_;
SaveMetaDataToFile(meta_file, fmap_);
SaveBinaryToFile(file_name, data_);
}
void SaveToBinary(dmlc::Stream* stream) final {
stream->Write(fmt_);
stream->Write(fmap_);
stream->Write(data_);
}
std::string GetSource(const std::string& format) final {
if (format == fmt_) {
return data_;
}
if (format == "llvm" || format == "") {
return hip_source_;
}
if (format == "asm") {
return assembly_;
}
return "";
}
// get a CUfunction from primary context in device_id
hipFunction_t GetFunc(int device_id, const std::string& func_name) {
std::lock_guard<std::mutex> lock(mutex_);
// must recheck under the lock scope
if (module_[device_id] == nullptr) {
ROCM_DRIVER_CALL(hipModuleLoadData(&(module_[device_id]), data_.c_str()));
}
hipFunction_t func;
hipError_t result = hipModuleGetFunction(&func, module_[device_id], func_name.c_str());
if (result != hipSuccess) {
LOG(FATAL) << "ROCMError: hipModuleGetFunction " << func_name
<< " failed with error: " << hipGetErrorString(result);
}
return func;
}
// get a global var from primary context in device_id
hipDeviceptr_t GetGlobal(int device_id, const std::string& global_name, size_t expect_nbytes) {
std::lock_guard<std::mutex> lock(mutex_);
// must recheck under the lock scope
if (module_[device_id] == nullptr) {
ROCM_DRIVER_CALL(hipModuleLoadData(&(module_[device_id]), data_.c_str()));
}
hipDeviceptr_t global = nullptr;
size_t nbytes = 0;
ROCM_DRIVER_CALL(hipModuleGetGlobal(&global, &nbytes, module_[device_id], global_name.c_str()));
ICHECK_EQ(nbytes, expect_nbytes);
return global;
}
private:
// the binary data
std::string data_;
// The format
std::string fmt_;
// function information table.
std::unordered_map<std::string, FunctionInfo> fmap_;
// The hip source.
std::string hip_source_;
// The gcn asm.
std::string assembly_;
// the internal modules per GPU, to be lazily initialized.
std::array<hipModule_t, kMaxNumGPUs> module_;
// internal mutex when updating the module
std::mutex mutex_;
};
// a wrapped function class to get packed func.
class ROCMWrappedFunc {
public:
// initialize the ROCM function.
void Init(ROCMModuleNode* m, ObjectPtr<Object> sptr, const std::string& func_name,
size_t num_void_args, const std::vector<std::string>& launch_param_tags) {
m_ = m;
sptr_ = sptr;
func_name_ = func_name;
std::fill(fcache_.begin(), fcache_.end(), nullptr);
launch_param_config_.Init(num_void_args, launch_param_tags);
}
// invoke the function with void arguments
void operator()(TVMArgs args, TVMRetValue* rv, void* packed_args, size_t packed_nbytes) const {
int device_id;
ROCM_CALL(hipGetDevice(&device_id));
if (fcache_[device_id] == nullptr) {
fcache_[device_id] = m_->GetFunc(device_id, func_name_);
}
hipStream_t strm = static_cast<hipStream_t>(ROCMThreadEntry::ThreadLocal()->stream);
ThreadWorkLoad wl = launch_param_config_.Extract(args);
void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, packed_args, HIP_LAUNCH_PARAM_BUFFER_SIZE,
&packed_nbytes, HIP_LAUNCH_PARAM_END};
// HIP supports only extra_args.
ROCM_DRIVER_CALL(hipModuleLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1),
wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1),
wl.block_dim(2), wl.dyn_shmem_size, strm, nullptr,
reinterpret_cast<void**>(&config)));
}
private:
// internal module
ROCMModuleNode* m_;
// the resource holder
ObjectPtr<Object> sptr_;
// The name of the function.
std::string func_name_;
// Device function cache per device.
// mark as mutable, to enable lazy initialization
mutable std::array<hipFunction_t, kMaxNumGPUs> fcache_;
// launch parameters configuration
LaunchParamConfig launch_param_config_;
};
PackedFunc ROCMModuleNode::GetFunction(const std::string& name,
const ObjectPtr<Object>& sptr_to_self) {
ICHECK_EQ(sptr_to_self.get(), this);
ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main";
auto it = fmap_.find(name);
if (it == fmap_.end()) return PackedFunc();
const FunctionInfo& info = it->second;
ROCMWrappedFunc f;
f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags);
return PackFuncPackedArg(f, info.arg_types);
}
Module ROCMModuleCreate(std::string data, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap, std::string hip_source,
std::string assembly) {
auto n = make_object<ROCMModuleNode>(data, fmt, fmap, hip_source, assembly);
return Module(n);
}
Module ROCMModuleLoadFile(const std::string& file_name, const std::string& format) {
std::string data;
std::unordered_map<std::string, FunctionInfo> fmap;
std::string fmt = GetFileFormat(file_name, format);
std::string meta_file = GetMetaFilePath(file_name);
LoadBinaryFromFile(file_name, &data);
LoadMetaDataFromFile(meta_file, &fmap);
return ROCMModuleCreate(data, fmt, fmap, std::string(), std::string());
}
Module ROCMModuleLoadBinary(void* strm) {
dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
std::string data;
std::unordered_map<std::string, FunctionInfo> fmap;
std::string fmt;
stream->Read(&fmt);
stream->Read(&fmap);
stream->Read(&data);
return ROCMModuleCreate(data, fmt, fmap, std::string(), std::string());
}
// Register the loaders as global functions. The "hsaco" and "hip" names are
// bound to the same implementations: ROCMModuleLoadBinary for the loadbinary
// entries and ROCMModuleLoadFile for the loadfile entries.
TVM_REGISTER_GLOBAL("runtime.module.loadbinary_hsaco").set_body_typed(ROCMModuleLoadBinary);
TVM_REGISTER_GLOBAL("runtime.module.loadbinary_hip").set_body_typed(ROCMModuleLoadBinary);
TVM_REGISTER_GLOBAL("runtime.module.loadfile_hsaco").set_body_typed(ROCMModuleLoadFile);
TVM_REGISTER_GLOBAL("runtime.module.loadfile_hip").set_body_typed(ROCMModuleLoadFile);
} // namespace runtime
} // namespace tvm