blob: 1528f03d8e499d1bea586c7ee0cf36a843ee1714 [file]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <cuda.h>
#include <nvshmem.h>
#include <nvshmemx.h>
#include <tvm/ffi/extra/json.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/disco/disco_worker.h>
#include "../../cuda/cuda_common.h"
namespace tvm {
namespace runtime {
/*!
 * \brief Request a new NVSHMEM unique id and encode it as an ffi::Shape.
 *
 * Layout of the returned shape (UNIQUEID_PADDING + 1 entries):
 *   - element 0 is the `version` field of the unique id;
 *   - elements 1..UNIQUEID_PADDING are the entries of `internal`,
 *     each widened to int64_t.
 *
 * \return The encoded unique id, in the layout InitNVSHMEM decodes.
 */
ffi::Shape InitNVSHMEMUID() {
  nvshmemx_uniqueid_t uid;
  // Do not silently continue on failure: otherwise an uninitialized id would
  // be encoded and handed to every worker.
  const int status = nvshmemx_get_uniqueid(&uid);
  TVM_FFI_ICHECK(status == 0) << "nvshmemx_get_uniqueid failed with status " << status;

  std::vector<int64_t> uid_64;
  uid_64.reserve(UNIQUEID_PADDING + 1);
  uid_64.push_back(static_cast<int64_t>(uid.version));
  for (int i = 0; i < UNIQUEID_PADDING; ++i) {
    uid_64.push_back(static_cast<int64_t>(uid.internal[i]));
  }
  return ffi::Shape(uid_64);
}
/*!
 * \brief Initialize NVSHMEM on the calling thread from an encoded unique id.
 *
 * \param uid_64 Encoded unique id with UNIQUEID_PADDING + 1 entries:
 *        element 0 is the id version, the remaining entries are the
 *        `internal` values of the id.
 * \param num_workers Number of ranks passed to NVSHMEM for this initialization.
 * \param worker_id_start Offset added to the disco worker id (or used directly
 *        when there is no thread-local disco worker) to form this rank.
 *
 * Raises ValueError when `uid_64` has the wrong length. An internal check
 * fails when an NVSHMEM setup call reports a non-zero status, or when the
 * worker's existing default device disagrees with the device chosen here.
 */
void InitNVSHMEM(ffi::Shape uid_64, int num_workers, int worker_id_start) {
  DiscoWorker* worker = ThreadLocalDiscoWorker::Get()->worker;
  // Without a disco worker the caller-provided start id is used as the rank itself.
  const int worker_id = worker_id_start + (worker == nullptr ? 0 : worker->worker_id);

  // The encoded id carries the version in front of the UNIQUEID_PADDING payload entries.
  TVM_FFI_CHECK_EQ(uid_64.size(), UNIQUEID_PADDING + 1, ValueError)
      << "The length of unique_id must be " << (UNIQUEID_PADDING + 1) << ", but got "
      << uid_64.size() << ".";

  nvshmemx_init_attr_t attr = NVSHMEMX_INIT_ATTR_INITIALIZER;
  nvshmemx_uniqueid_t uid;
  uid.version = static_cast<int>(uid_64[0]);
  for (int i = 0; i < UNIQUEID_PADDING; ++i) {
    uid.internal[i] = static_cast<char>(uid_64[i + 1]);
  }

  // FIXME: this is a hack to avoid the issue of NVSHMEM using Multi-process-per-GPU to initialize
  // NOTE(review): the result of this call is deliberately left unchecked, unlike the
  // cudaSetDevice call after initialization. Presumably worker_id may not be a valid
  // local device ordinal in every deployment -- confirm before adding a check.
  cudaSetDevice(worker_id);

  int status = nvshmemx_set_attr_uniqueid_args(worker_id, num_workers, &uid, &attr);
  TVM_FFI_ICHECK(status == 0) << "nvshmemx_set_attr_uniqueid_args failed with status "
                              << status;
  status = nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);
  TVM_FFI_ICHECK(status == 0) << "nvshmemx_init_attr failed with status " << status;

  // Bind this thread to the device whose ordinal equals the PE index reported
  // for NVSHMEMX_TEAM_NODE.
  int mype_node = nvshmem_team_my_pe(NVSHMEMX_TEAM_NODE);
  CUDA_CALL(cudaSetDevice(mype_node));

  if (worker != nullptr) {
    if (worker->default_device.device_type == DLDeviceType::kDLCPU) {
      // A CPU default device is replaced by the CUDA device selected above.
      worker->default_device = Device{DLDeviceType::kDLCUDA, mype_node};
    } else {
      TVM_FFI_ICHECK(worker->default_device.device_type == DLDeviceType::kDLCUDA &&
                     worker->default_device.device_id == mype_node)
          << "The default device of the worker is inconsistent with the device used for NVSHMEM. "
          << "The default device is " << worker->default_device
          << ", but the device used for NVSHMEM is " << Device{DLDeviceType::kDLCUDA, mype_node}
          << ".";
    }
  }
  LOG_INFO << "NVSHMEM init finished: mype=" << nvshmem_my_pe() << " "
           << ", npes=" << nvshmem_n_pes();
}
/*!
 * \brief Decode JSON-encoded arguments and forward them to InitNVSHMEM.
 *
 * The JSON text must be an object providing:
 *   - "uid": array of integers (the encoded unique id),
 *   - "npes": integer, forwarded as the number of workers,
 *   - "pe_start": integer, forwarded as the starting worker id.
 *
 * \param args The JSON text to decode.
 */
void InitNVSHMEMWrapper(ffi::String args) {
  namespace json = tvm::ffi::json;

  ffi::String parse_error;
  json::Value parsed = json::Parse(args, &parse_error);
  if (!parse_error.empty()) {
    TVM_FFI_THROW(InternalError) << "JSON parse error: " << parse_error;
  }
  TVM_FFI_ICHECK(parsed.as<json::Object>()) << "JSON is not an object";
  json::Object config = parsed.cast<json::Object>();

  // Copy the JSON integers into a plain vector so they can back an ffi::Shape.
  json::Array uid_entries = config["uid"].cast<json::Array>();
  std::vector<int64_t> uid_values;
  uid_values.reserve(uid_entries.size());
  for (const ffi::Any& entry : uid_entries) {
    uid_values.push_back(entry.cast<int64_t>());
  }
  ffi::Shape uid_shape(uid_values);

  // JSON integers are read as int64_t and narrowed to the int parameters of InitNVSHMEM.
  int npes = static_cast<int>(config["npes"].cast<int64_t>());
  int pe_start = static_cast<int>(config["pe_start"].cast<int64_t>());
  InitNVSHMEM(uid_shape, npes, pe_start);
}
void NVSHMEMXCumoduleInit(void* cuModule) {
CUmodule mod = static_cast<CUmodule>(cuModule);
auto status = nvshmemx_init_status();
// The NVSHMEM library must have completed device initialization prior to
// nvshmemx_cumodule_init. If not, we skip the cumodule initialization.
if (status == NVSHMEM_STATUS_IS_INITIALIZED || status == NVSHMEM_STATUS_LIMITED_MPG ||
status == NVSHMEM_STATUS_FULL_MPG) {
// NOTE: we do not check the return value of nvshmemx_cumodule_init.
// The reason is because that the input cuModule might not use any NVSHMEM functions,
// in which case the nvshmemx_cumodule_init will fail.
// A set of guards to check if the module has NVSHMEM symbol to avoid the
// "gpgpu named symbol not found" error.
CUdeviceptr d_ptr;
size_t d_size;
const char* kNvshmemDeviceSymbols[] = {
"nvshmemi_device_state_d", "nvshmem_i_device_state_d",
"nvshmemi_device_team_state_d", "nvshmemi_device_heap_base_d",
"nvshmemi_device_heap_size_d", "nvshmemi_device_heap_d",
};
for (const char* sym : kNvshmemDeviceSymbols) {
if (cuModuleGetGlobal(&d_ptr, &d_size, mod, sym) == CUDA_SUCCESS) {
nvshmemx_cumodule_init(mod);
return;
}
}
}
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("runtime.disco.nvshmem.init_nvshmem_uid", InitNVSHMEMUID)
.def("runtime.disco.nvshmem.init_nvshmem", InitNVSHMEM)
.def("runtime.disco.nvshmem.init_nvshmem_wrapper", InitNVSHMEMWrapper)
.def("runtime.nvshmem.cumodule_init", NVSHMEMXCumoduleInit);
}
} // namespace runtime
} // namespace tvm