blob: aca1fef90c94990d30c97c61d91cf67262be9606 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/base.h>
#include <tvm/runtime/disco/disco_worker.h>
#include <tvm/runtime/object.h>
#include <memory>
#include <sstream>
#include <utility>
#include <vector>
#include "../../support/pipe.h"
#include "../minrpc/rpc_reference.h"
#include "./bcast_session.h"
#include "./disco_worker_thread.h"
#include "./message_queue.h"
#include "./protocol.h"
namespace tvm {
namespace runtime {
/*!
 * \brief A DiscoChannel built from two file descriptors, one per direction.
 *
 * Each descriptor is wrapped in a support::Pipe, and each pipe is in turn
 * wrapped in a DiscoStreamMessageQueue that performs the actual Send/Recv.
 * The same class is instantiated on both ends of the link: the controller
 * side uses Send/RecvReply and the worker side uses Recv/Reply.
 */
class DiscoProcessChannel final : public DiscoChannel {
 public:
  /*!
   * \brief Wrap the two descriptors into pipes and message queues.
   * \param controler_to_worker_fd Descriptor carrying controller-to-worker traffic.
   * \param worker_to_controler_fd Descriptor carrying worker-to-controller traffic.
   */
  DiscoProcessChannel(int64_t controler_to_worker_fd, int64_t worker_to_controler_fd)
      : controller_to_worker_pipe_(controler_to_worker_fd),
        worker_to_controller_pipe_(worker_to_controler_fd),
        controler_to_worker_(&controller_to_worker_pipe_),
        worker_to_controler_(&worker_to_controller_pipe_) {}

  // The queues keep pointers to the pipe members of this very object, so the
  // channel must never be relocated or duplicated.
  DiscoProcessChannel(DiscoProcessChannel&&) = delete;
  DiscoProcessChannel(const DiscoProcessChannel&) = delete;

  /*! \brief Write `args` onto the controller-to-worker queue. */
  void Send(const ffi::PackedArgs& args) {
    this->controler_to_worker_.Send(args);
  }

  /*! \brief Read the next message from the controller-to-worker queue. */
  ffi::PackedArgs Recv() {
    return this->controler_to_worker_.Recv();
  }

  /*! \brief Write `args` onto the worker-to-controller queue. */
  void Reply(const ffi::PackedArgs& args) {
    this->worker_to_controler_.Send(args);
  }

  /*! \brief Read the next message from the worker-to-controller queue. */
  ffi::PackedArgs RecvReply() {
    return this->worker_to_controler_.Recv();
  }

  // NOTE: the pipes are declared before the queues so that they are
  // constructed first; the queues are handed the pipes' addresses.
  /*! \brief Pipe over the controller-to-worker descriptor. */
  support::Pipe controller_to_worker_pipe_;
  /*! \brief Pipe over the worker-to-controller descriptor. */
  support::Pipe worker_to_controller_pipe_;
  /*! \brief Message queue layered on controller_to_worker_pipe_. */
  DiscoStreamMessageQueue controler_to_worker_;
  /*! \brief Message queue layered on worker_to_controller_pipe_. */
  DiscoStreamMessageQueue worker_to_controler_;
};
class ProcessSessionObj final : public BcastSessionObj {
public:
explicit ProcessSessionObj(int num_workers, int num_groups, ffi::Function process_pool)
: process_pool_(process_pool),
worker_0_(
std::make_unique<DiscoWorkerThread>(0, num_workers, num_groups, &worker_zero_data_)) {
std::vector<int64_t> read_fds;
std::vector<int64_t> write_fds;
read_fds.reserve(num_workers - 1);
write_fds.reserve(num_workers - 1);
for (int i = 1; i < num_workers; ++i) {
ffi::Shape fds = process_pool(i).cast<ffi::Shape>();
CHECK_EQ(fds.size(), 2) << "ValueError: process_pool(" << i << ") should return a tuple of "
<< "size 2, but got a tuple of size " << fds.size() << ".";
read_fds.push_back(fds[0]);
write_fds.push_back(fds[1]);
}
for (int i = 0; i < num_workers - 1; ++i) {
workers_.emplace_back(std::make_unique<DiscoProcessChannel>(write_fds[i], read_fds[i]));
}
}
void Kill() {
if (this->worker_0_ != nullptr) {
this->Shutdown();
this->worker_0_.reset();
this->workers_.clear();
this->process_pool_(0);
}
}
~ProcessSessionObj() { Kill(); }
int64_t GetNumWorkers() { return workers_.size() + 1; }
ffi::Any DebugGetFromRemote(int64_t reg_id, int worker_id) {
if (worker_id == 0) {
this->SyncWorker(worker_id);
return worker_0_->worker->register_file.at(reg_id);
}
{
ffi::AnyView packed_args[3];
ffi::PackedArgs::Fill(packed_args, static_cast<int>(DiscoAction::kDebugGetFromRemote), reg_id,
worker_id);
workers_[worker_id - 1]->Send(ffi::PackedArgs(packed_args, 3));
}
ffi::PackedArgs args = this->RecvReplyPacked(worker_id);
ICHECK_EQ(args.size(), 2);
ICHECK(static_cast<DiscoAction>(args[0].cast<int>()) == DiscoAction::kDebugGetFromRemote);
ffi::Any result;
result = args[1];
return result;
}
void DebugSetRegister(int64_t reg_id, ffi::AnyView value, int worker_id) {
if (worker_id == 0) {
this->SyncWorker(worker_id);
worker_0_->worker->SetRegister(reg_id, value);
return;
}
ObjectRef wrapped{nullptr};
if (value.as<ObjectRef>()) {
wrapped = DiscoDebugObject::Wrap(value);
value = wrapped;
}
{
ffi::AnyView packed_args[4];
ffi::PackedArgs::Fill(packed_args, static_cast<int>(DiscoAction::kDebugSetRegister), reg_id,
worker_id, value);
SendPacked(worker_id, ffi::PackedArgs(packed_args, 4));
}
ffi::Any result;
ffi::PackedArgs args = this->RecvReplyPacked(worker_id);
ICHECK_EQ(args.size(), 1);
ICHECK(static_cast<DiscoAction>(args[0].cast<int>()) == DiscoAction::kDebugSetRegister);
}
void BroadcastPacked(const ffi::PackedArgs& args) final {
worker_0_->channel->Send(args);
for (std::unique_ptr<DiscoProcessChannel>& channel : workers_) {
channel->Send(args);
}
}
void SendPacked(int worker_id, const ffi::PackedArgs& args) final {
if (worker_id == 0) {
worker_0_->channel->Send(args);
} else {
workers_.at(worker_id - 1)->Send(args);
}
}
ffi::PackedArgs RecvReplyPacked(int worker_id) final {
if (worker_id == 0) {
return worker_0_->channel->RecvReply();
}
return this->workers_.at(worker_id - 1)->RecvReply();
}
DiscoChannel* GetWorkerChannel(int worker_id) {
if (worker_id == 0) {
return worker_0_->channel.get();
}
return workers_.at(worker_id - 1).get();
}
ffi::Function process_pool_;
std::unique_ptr<DiscoWorkerThread> worker_0_;
std::vector<std::unique_ptr<DiscoProcessChannel>> workers_;
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("runtime.disco.ProcessSession", ProcessSessionObj, SessionObj);
};
/*!
 * \brief Create a session backed by ProcessSessionObj.
 * \param num_workers Total number of workers; must be positive and divisible by `num_group`.
 * \param num_group Number of worker groups; must be positive.
 * \param process_pool_creator Name of a globally registered function. It is called
 *  with (num_workers, num_group, entrypoint) and must return the `process_pool`
 *  function handed to ProcessSessionObj.
 * \param entrypoint Forwarded unchanged to the pool creator.
 * \return The newly created session.
 */
Session Session::ProcessSession(int num_workers, int num_group, ffi::String process_pool_creator,
                                ffi::String entrypoint) {
  // Validate the counts first: `num_workers % num_group` below is undefined
  // behaviour when `num_group` is zero.
  CHECK(num_group > 0) << "ValueError: The number of worker groups should be positive, but got "
                       << num_group << ".";
  CHECK(num_workers > 0) << "ValueError: The number of workers should be positive, but got "
                         << num_workers << ".";
  CHECK_EQ(num_workers % num_group, 0)
      << "The number of workers should be divisible by the number of worker group.";
  const auto pf = tvm::ffi::Function::GetGlobal(process_pool_creator);
  CHECK(pf) << "ValueError: Cannot find function " << process_pool_creator
            << " in the registry. Please check if it is registered.";
  auto process_pool = (*pf)(num_workers, num_group, entrypoint).cast<ffi::Function>();
  // The local handle is not needed afterwards, so hand it over instead of copying.
  auto n = ffi::make_object<ProcessSessionObj>(num_workers, num_group, std::move(process_pool));
  return Session(n);
}
void WorkerProcess(int worker_id, int num_workers, int num_group, int64_t read_fd,
int64_t write_fd) {
CHECK_EQ(num_workers % num_group, 0)
<< "The number of workers should be divisible by the number of worker group.";
DiscoProcessChannel channel(read_fd, write_fd);
DiscoWorker worker(worker_id, num_workers, num_group, nullptr, &channel);
worker.MainLoop();
}
// Static registration of the two global functions defined in this file.
TVM_FFI_STATIC_INIT_BLOCK() {
  // Session factory.
  tvm::ffi::reflection::GlobalDef().def("runtime.disco.SessionProcess", Session::ProcessSession);
  // Worker process body.
  tvm::ffi::reflection::GlobalDef().def("runtime.disco.WorkerProcess", WorkerProcess);
}
} // namespace runtime
} // namespace tvm