blob: a7ff1a5a035ae8ee09ebe641a1dc89e19980917b [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/ffi/function.h>
#include <tvm/runtime/disco/builtin.h>
#include <tvm/runtime/disco/disco_worker.h>
#include <tvm/runtime/disco/session.h>
#include "../../support/process_id.h"
#include "./protocol.h"
namespace tvm {
namespace runtime {
/*!
 * \brief Look up the DiscoWorker registered in the slot returned by
 *        ThreadLocalDiscoWorker::Get().
 * \return The registered worker pointer; never null.
 * \note Raises a ValueError via TVM_FFI_CHECK when no worker is registered.
 */
TVM_DLL DiscoWorker* DiscoWorker::ThreadLocal() {
  auto&& tls_entry = ThreadLocalDiscoWorker::Get();
  DiscoWorker* ret = tls_entry->worker;
  // A null entry means nobody registered a worker in this slot.
  TVM_FFI_CHECK(ret, ValueError) << "The current thread is not a DiscoWorker thread";
  return ret;
}
void DiscoWorker::SetRegister(int reg_id, ffi::AnyView value) {
TVM_FFI_ICHECK(0 <= reg_id && reg_id < static_cast<int>(register_file.size()));
ffi::Any& rv = register_file.at(reg_id);
if (rv.type_index() == ffi::TypeIndex::kTVMFFITensor &&
value.type_index() == ffi::TypeIndex::kTVMFFITensor) {
Tensor dst = rv.cast<Tensor>();
Tensor src = value.cast<Tensor>();
dst.CopyFrom(src);
} else {
rv = value;
}
}
/*!
 * \brief Command loop of a DiscoWorker together with one handler per DiscoAction.
 *
 * Every member is static and receives the worker explicitly as `self`.
 */
struct DiscoWorker::Impl {
  /*!
   * \brief Receive commands from `self->channel` and dispatch them until
   *        a kShutDown action arrives.
   *
   * Each message is decoded as (action, reg_id, ...); the meaning of the
   * remaining arguments depends on the action. Action values that are not
   * listed in the switch leave the loop running without any effect.
   *
   * \param self The worker that owns the channel and the register file.
   */
  static void MainLoop(DiscoWorker* self) {
    // Register the worker so that DiscoWorker::ThreadLocal() can return it.
    ThreadLocalDiscoWorker::Get()->worker = self;
    using namespace tvm;
    while (true) {
      ffi::PackedArgs args = self->channel->Recv();
      DiscoAction action = static_cast<DiscoAction>(args[0].cast<int>());
      int64_t reg_id = args[1].cast<int64_t>();
      switch (action) {
        case DiscoAction::kShutDown: {
          Shutdown(self);
          return;
        }
        case DiscoAction::kKillReg: {
          // Drop whatever the register holds.
          GetReg(self, reg_id) = nullptr;
          break;
        }
        case DiscoAction::kGetGlobalFunc: {
          GetGlobalFunc(self, reg_id, args[2].cast<std::string>());
          break;
        }
        case DiscoAction::kCallPacked: {
          // args[2] names the register holding the callee; args[3..] are its arguments.
          int func_reg_id = args[2].cast<int>();
          TVM_FFI_ICHECK_LT(func_reg_id, self->register_file.size());
          ffi::Function func = GetReg(self, func_reg_id).cast<ffi::Function>();
          TVM_FFI_ICHECK(func.defined());
          CallPacked(self, reg_id, func, args.Slice(3));
          break;
        }
        case DiscoAction::kCopyFromWorker0: {
          CopyFromWorker0(self, reg_id);
          break;
        }
        case DiscoAction::kCopyToWorker0: {
          CopyToWorker0(self, reg_id);
          break;
        }
        case DiscoAction::kSyncWorker: {
          // For this action the second slot carries the target worker id.
          SyncWorker(self, reg_id);
          break;
        }
        case DiscoAction::kDebugGetFromRemote: {
          int worker_id = args[2].cast<int>();
          DebugGetFromRemote(self, reg_id, worker_id);
          break;
        }
        case DiscoAction::kDebugSetRegister: {
          int worker_id = args[2].cast<int>();
          ffi::AnyView value = args[3];
          DebugSetRegister(self, reg_id, worker_id, value);
          break;
        }
      }
    }
  }

  /*! \brief Handler for kShutDown; currently performs no work. */
  static void Shutdown(DiscoWorker* self) {}

  /*!
   * \brief Look up a global function by name and store it in a register.
   * \param self The worker whose register file is updated.
   * \param reg_id Destination register; nothing is stored when it is 0.
   * \param name Name passed to ffi::Function::GetGlobal.
   * \note Raises a ValueError when the function cannot be found.
   */
  static void GetGlobalFunc(DiscoWorker* self, int64_t reg_id, const std::string& name) {
    const auto pf = tvm::ffi::Function::GetGlobal(name);
    TVM_FFI_CHECK(pf.has_value(), ValueError) << "Cannot find global function: " << name;
    if (reg_id != 0) {
      GetReg(self, reg_id) = *pf;
    }
  }

  /*!
   * \brief Pop the oldest tensor from `self->worker_zero_data->host_arrays`.
   *
   * The queue is accessed while holding `queue_mutex_`.
   *
   * \param self The worker whose `worker_zero_data` is consulted.
   * \return The tensor that was at the front of the queue.
   * \note Fails with an internal check error when the queue is empty, instead
   *       of calling front()/pop() on an empty container.
   */
  static Tensor GetTensorFromHost(DiscoWorker* self) {
    std::lock_guard<std::mutex> lock(self->worker_zero_data->queue_mutex_);
    auto& pending = self->worker_zero_data->host_arrays;
    TVM_FFI_ICHECK(!pending.empty()) << "No host tensor is queued for worker 0 to consume";
    Tensor array = pending.front();
    pending.pop();
    return array;
  }

  /*!
   * \brief On worker 0, copy the tensor in register `reg_id` into the next
   *        queued host tensor. Other workers do nothing.
   * \param self The current worker.
   * \param reg_id Register holding the source tensor.
   */
  static void CopyFromWorker0(DiscoWorker* self, int64_t reg_id) {
    if (self->worker_id == 0) {
      Tensor tgt = GetTensorFromHost(self);
      Tensor src = GetReg(self, reg_id).cast<Tensor>();
      tgt.CopyFrom(src);
    }
  }

  /*!
   * \brief On worker 0, copy the next queued host tensor into the tensor in
   *        register `reg_id`. Other workers do nothing.
   * \param self The current worker.
   * \param reg_id Register holding the destination tensor.
   */
  static void CopyToWorker0(DiscoWorker* self, int64_t reg_id) {
    if (self->worker_id == 0) {
      Tensor src = GetTensorFromHost(self);
      Tensor tgt = GetReg(self, reg_id).cast<Tensor>();
      tgt.CopyFrom(src);
    }
  }

  /*!
   * \brief If this worker is the addressed one, run ::tvm::runtime::SyncWorker()
   *        and reply with (kSyncWorker, worker_id) on the channel.
   * \param self The current worker.
   * \param worker_id Id of the worker that should synchronize and reply.
   */
  static void SyncWorker(DiscoWorker* self, int worker_id) {
    if (worker_id == self->worker_id) {
      ::tvm::runtime::SyncWorker();
      ffi::AnyView packed_args[2];
      ffi::PackedArgs::Fill(packed_args, static_cast<int>(DiscoAction::kSyncWorker), worker_id);
      self->channel->Reply(ffi::PackedArgs(packed_args, 2));
    }
  }

  /*!
   * \brief If this worker is the addressed one, reply with the content of
   *        register `reg_id` as (kDebugGetFromRemote, value).
   *
   * Values that are ObjectRef instances are wrapped with
   * DiscoDebugObject::Wrap before being sent.
   *
   * \param self The current worker.
   * \param reg_id Register to read.
   * \param worker_id Id of the worker that should reply.
   */
  static void DebugGetFromRemote(DiscoWorker* self, int64_t reg_id, int worker_id) {
    if (worker_id == self->worker_id) {
      ffi::Any rv = GetReg(self, reg_id);
      if (rv.as<ObjectRef>()) {
        rv = DiscoDebugObject::Wrap(rv);
      }
      ffi::AnyView packed_args[2];
      ffi::PackedArgs::Fill(packed_args, static_cast<int>(DiscoAction::kDebugGetFromRemote), rv);
      self->channel->Reply(ffi::PackedArgs(packed_args, 2));
    }
  }

  /*!
   * \brief If this worker is the addressed one, synchronize, store `value`
   *        through DiscoWorker::SetRegister and reply with (kDebugSetRegister).
   * \param self The current worker.
   * \param reg_id Register to overwrite (SetRegister takes an `int` id).
   * \param worker_id Id of the worker that should perform the update.
   * \param value The value to store.
   */
  static void DebugSetRegister(DiscoWorker* self, int reg_id, int worker_id, ffi::AnyView value) {
    if (worker_id == self->worker_id) {
      ::tvm::runtime::SyncWorker();
      self->SetRegister(reg_id, value);
      ffi::AnyView packed_args[1];
      ffi::PackedArgs::Fill(packed_args, static_cast<int>(DiscoAction::kDebugSetRegister));
      self->channel->Reply(ffi::PackedArgs(packed_args, 1));
    }
  }

  /*!
   * \brief Invoke `func` and store its result in register `ret_reg_id`.
   *
   * Arguments that are DRef handles are replaced, in place, by the content of
   * the register they name before the call is made.
   *
   * \param self The current worker.
   * \param ret_reg_id Register receiving the return value.
   * \param func The function to call.
   * \param args Arguments of the call; the underlying buffer is modified.
   */
  static void CallPacked(DiscoWorker* self, int64_t ret_reg_id, ffi::Function func,
                         const ffi::PackedArgs& args) {
    // NOTE: casting away constness is only safe because `args` is not used
    // anywhere else; that holds for the single call site in MainLoop.
    ffi::AnyView* args_vec = const_cast<ffi::AnyView*>(args.data());
    // Translate args into the remote calling convention.
    for (int i = 0; i < args.size(); ++i) {
      if (auto opt_dref = args_vec[i].as<DRef>()) {
        DRef dref = opt_dref.value();
        args_vec[i] = GetReg(self, dref->reg_id);
      }
    }
    ffi::Any rv;
    func.CallPacked(ffi::PackedArgs(args_vec, args.size()), &rv);
    GetReg(self, ret_reg_id) = std::move(rv);
  }

  /*!
   * \brief Access register `reg_id`, growing the register file when needed.
   * \param self The worker whose register file is accessed.
   * \param reg_id Non-negative register index.
   * \return Reference to the register slot.
   * \note Fails with an internal check error on a negative index rather than
   *       indexing out of bounds. NOTE(review): a later call that grows
   *       `register_file` presumably invalidates previously returned
   *       references - confirm against the container type.
   */
  static ffi::Any& GetReg(DiscoWorker* self, int64_t reg_id) {
    TVM_FFI_ICHECK(reg_id >= 0) << "Register id must be non-negative, got " << reg_id;
    if (reg_id >= static_cast<int64_t>(self->register_file.size())) {
      self->register_file.resize(reg_id + 1);
    }
    return self->register_file[reg_id];
  }
};
void DiscoWorker::MainLoop() { DiscoWorker::Impl::MainLoop(this); }
} // namespace runtime
} // namespace tvm