/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
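/*!
 * \brief Builtin functions of the Disco distributed runtime: module loading,
 *  tensor allocation, and CCL-backed collective communication.
 */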
#include <dlpack/dlpack.h>
#include <tvm/ffi/container/shape.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/disco/builtin.h>
#include <tvm/runtime/disco/disco_worker.h>
#include <tvm/runtime/disco/session.h>
#include <tvm/runtime/vm/vm.h>
#include <sstream>

#include "./utils.h"

namespace tvm {
namespace runtime {
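
/*!
 * \brief A mutex-guarded cache of dynamically loaded libraries, ensuring that
 *  each shared library on disk is loaded at most once per process.
 */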
class DSOLibraryCache {
public:
ffi::Module Open(const std::string& library_path) {
std::lock_guard<std::mutex> lock(mutex_);
auto it = cache_.find(library_path);
if (it == cache_.end()) {
ffi::Module lib = ffi::Module::LoadFromFile(library_path);
cache_.emplace(library_path, lib);
return lib;
}
return it->second;
}
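
  /*! \brief Mapping from library path to the module loaded from it. */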
std::unordered_map<std::string, ffi::Module> cache_;
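  /*! \brief Guards concurrent access to the cache. */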
std::mutex mutex_;
};
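
/*!
 * \brief Load a library as a VM module. If the library contains a Relax VM
 *  executable, create the VM and initialize it on the given device; otherwise
 *  return the library module itself.
 */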
ffi::Module LoadVMModule(std::string path, ffi::Optional<Device> device) {
static DSOLibraryCache cache;
ffi::Module dso_mod = cache.Open(path);
Device dev = UseDefaultDeviceIfNone(device);
ffi::Optional<ffi::Function> vm_load_executable = dso_mod->GetFunction("vm_load_executable");
if (!vm_load_executable.has_value()) {
    // Not built by the Relax VM; return the library module as-is.
return dso_mod;
}
auto mod = (*vm_load_executable)().cast<ffi::Module>();
ffi::Optional<ffi::Function> vm_initialization = mod->GetFunction("vm_initialization");
if (!vm_initialization.has_value()) {
LOG(FATAL) << "ValueError: File `" << path
<< "` is not built by RelaxVM, because `vm_initialization` does not exist";
}
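  // Initialize the VM with a pooled allocator on both the target device and
  // the CPU host (kDLCPU, device id 0).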
(*vm_initialization)(static_cast<int>(dev.device_type), static_cast<int>(dev.device_id),
static_cast<int>(AllocatorType::kPooled), static_cast<int>(kDLCPU), 0,
static_cast<int>(AllocatorType::kPooled));
return mod;
}
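
/*!
 * \brief Allocate an uninitialized tensor on the given device, defaulting to
 *  the worker's default device when none is given.
 */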
Tensor DiscoEmptyTensor(ffi::Shape shape, DataType dtype, ffi::Optional<Device> device) {
return Tensor::Empty(shape, dtype, UseDefaultDeviceIfNone(device));
}
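
/*!
 * \brief Look up a collective function of the worker's CCL backend, i.e. the
 *  global function `runtime.disco.<ccl>.<name>`.
 */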
ffi::Function GetCCLFunc(const char* name) {
std::string ccl = DiscoWorker::ThreadLocal()->ccl;
std::string pf_name = "runtime.disco." + ccl + "." + name;
const auto pf = tvm::ffi::Function::GetGlobal(pf_name);
CHECK(pf.has_value()) << "ValueError: Cannot find the `" << name << "` function for `" << ccl
<< "` via `" << pf_name << "`";
return *pf;
}
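
/*!
 * \brief All-reduce `send` into `recv` with the given reduction, either within
 *  the worker's group or across all workers.
 */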
void AllReduce(Tensor send, ReduceKind reduce_kind, bool in_group, Tensor recv) {
GetCCLFunc("allreduce")(send, static_cast<int>(reduce_kind), in_group, recv);
}
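
/*! \brief All-gather `send` from every participating worker into `recv`. */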
void AllGather(Tensor send, bool in_group, Tensor recv) {
GetCCLFunc("allgather")(send, in_group, recv);
}
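
/*! \brief Broadcast `send` from worker 0 (of the group or of the whole session) into every worker's `recv`. */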
TVM_DLL void BroadcastFromWorker0(Tensor send, bool in_group, Tensor recv) {
GetCCLFunc("broadcast_from_worker0")(send, in_group, recv);
}
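
/*!
 * \brief Scatter slices of `send` from worker 0 to the participating workers;
 *  workers other than worker 0 pass std::nullopt for `send`.
 */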
TVM_DLL void ScatterFromWorker0(ffi::Optional<Tensor> send, bool in_group, Tensor recv) {
GetCCLFunc("scatter_from_worker0")(send, in_group, recv);
}
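
/*!
 * \brief Gather `send` from the participating workers into `recv` on worker 0;
 *  only worker 0 provides `recv`.
 */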
void GatherToWorker0(Tensor send, bool in_group, ffi::Optional<Tensor> recv) {
GetCCLFunc("gather_to_worker0")(send, in_group, recv);
}
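
// Point-to-point transfers between individual workers and between worker groups.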
void RecvFromWorker0(Tensor buffer) { GetCCLFunc("recv_from_worker0")(buffer); }
void SendToNextGroup(Tensor buffer) { GetCCLFunc("send_to_next_group")(buffer); }
void RecvFromPrevGroup(Tensor buffer) { GetCCLFunc("recv_from_prev_group")(buffer); }
void SendToWorker(Tensor buffer, int receiver_id) {
GetCCLFunc("send_to_worker")(buffer, receiver_id);
}
void RecvFromWorker(Tensor buffer, int sender_id) {
GetCCLFunc("recv_from_worker")(buffer, sender_id);
}
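
/*! \brief The zero-based id of the calling worker. */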
int WorkerId() { return DiscoWorker::ThreadLocal()->worker_id; }
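
/*! \brief Synchronize the calling worker through the CCL backend; a no-op when no CCL is configured. */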
void SyncWorker() {
  if (!DiscoWorker::ThreadLocal()->ccl.empty()) {
    GetCCLFunc("sync_worker")();
  }
}
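
// Register the builtins above as global FFI functions under the `runtime.disco.` prefix.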
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("runtime.disco.load_vm_module", LoadVMModule)
.def("runtime.disco.empty",
[](ffi::Shape shape, DataType dtype, ffi::Optional<Device> device, bool worker0_only,
bool in_group) -> ffi::Optional<Tensor> {
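             // When `worker0_only` is set, only worker 0 (of the group when
             // `in_group`, of the whole session otherwise) allocates; all
             // other workers return std::nullopt.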
int worker_id = WorkerId();
int group_size =
DiscoWorker::ThreadLocal()->num_workers / DiscoWorker::ThreadLocal()->num_groups;
bool is_worker0 =
(worker_id == 0 && !in_group) || (in_group && worker_id % group_size == 0);
if (worker0_only && !is_worker0) {
return std::nullopt;
} else {
return DiscoEmptyTensor(shape, dtype, device);
}
})
.def("runtime.disco.allreduce",
[](Tensor send, ffi::Shape reduce_kind, bool in_group, Tensor recv) {
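             // The reduce kind arrives encoded as a Shape; decode it and check
             // that it names a valid ReduceKind before dispatching.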
int kind = IntegerFromShape(reduce_kind);
CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " << kind;
AllReduce(send, static_cast<ReduceKind>(kind), in_group, recv);
})
.def("runtime.disco.allgather", AllGather)
.def("runtime.disco.broadcast_from_worker0", BroadcastFromWorker0)
.def("runtime.disco.scatter_from_worker0", ScatterFromWorker0)
.def("runtime.disco.gather_to_worker0", GatherToWorker0)
.def("runtime.disco.recv_from_worker0", RecvFromWorker0)
.def("runtime.disco.send_to_next_group", SendToNextGroup)
.def("runtime.disco.recv_from_prev_group", RecvFromPrevGroup)
.def("runtime.disco.send_to_worker", SendToWorker)
.def("runtime.disco.recv_from_worker", RecvFromWorker)
.def("runtime.disco.worker_id", []() -> ffi::Shape { return ffi::Shape({WorkerId()}); })
.def("runtime.disco.worker_rank", []() -> int64_t { return WorkerId(); })
.def("runtime.disco.device",
[]() -> Device { return DiscoWorker::ThreadLocal()->default_device; })
.def("runtime.disco.bind_worker_to_cpu_core", [](ffi::Shape cpu_ids) {
int worker_id = WorkerId();
ICHECK_LT(worker_id, static_cast<int>(cpu_ids.size()));
const auto f_set_thread_affinity = tvm::ffi::Function::GetGlobalRequired(
"tvm.runtime.threading.set_current_thread_affinity");
f_set_thread_affinity(ffi::Shape{cpu_ids[worker_id]});
});
}

}  // namespace runtime
}  // namespace tvm