blob: def2522b60d34f34e20802b3cae1604e70b523db [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <numeric>
#include "../../../support/socket.h"
#include "../bcast_session.h"
#include "../message_queue.h"
namespace tvm {
namespace runtime {
using namespace tvm::support;
/*!
 * \brief Opcode placed at the front of every message the controller sends to a remote node.
 *
 * In this file each such message has the layout [action, worker_id, payload...].
 */
enum class DiscoSocketAction {
  /*! \brief Shut down the remote node; takes the same numeric value as DiscoAction::kShutDown. */
  kShutdown = static_cast<int>(DiscoAction::kShutDown),
  /*! \brief Forward the payload to one worker, or to all local workers when worker_id is -1. */
  kSend,
  /*! \brief Fetch the pending reply of one worker and send it back over the channel. */
  kReceive,
};
/*!
 * \brief DiscoChannel that exchanges packed arguments through a TCP socket.
 *
 * The channel keeps its own copy of the socket handle and a stream message queue that is
 * constructed from the address of that copy.
 */
class DiscoSocketChannel : public DiscoChannel {
 public:
  /*!
   * \brief Create a channel on top of an already connected socket.
   * \param socket The socket handle to copy into the channel.
   */
  explicit DiscoSocketChannel(const TCPSocket& socket)
      : socket_(socket), message_queue_(&socket_) {}

  // Neither copyable nor movable. NOTE(review): presumably because message_queue_ is built
  // from the address of socket_, so the object must not change location - confirm.
  DiscoSocketChannel(const DiscoSocketChannel& other) = delete;
  DiscoSocketChannel(DiscoSocketChannel&& other) = delete;

  /*! \brief Write a packed argument sequence to the socket. */
  void Send(const ffi::PackedArgs& args) {
    message_queue_.Send(args);
  }

  /*! \brief Read the next packed argument sequence from the socket. */
  ffi::PackedArgs Recv() {
    return message_queue_.Recv();
  }

  /*! \brief Write a reply; uses the same queue operation as Send. */
  void Reply(const ffi::PackedArgs& args) {
    message_queue_.Send(args);
  }

  /*! \brief Read a reply; uses the same queue operation as Recv. */
  ffi::PackedArgs RecvReply() {
    return message_queue_.Recv();
  }

 private:
  /*! \brief Socket handle owned by this channel; declared first so it is initialized first. */
  TCPSocket socket_;
  /*! \brief Message queue constructed from the address of socket_. */
  DiscoStreamMessageQueue message_queue_;
};
class SocketSessionObj : public BcastSessionObj {
public:
explicit SocketSessionObj(int num_nodes, int num_workers_per_node, int num_groups,
const ffi::String& host, int port)
: num_nodes_(num_nodes), num_workers_per_node_(num_workers_per_node) {
const auto f_create_local_session =
tvm::ffi::Function::GetGlobal("runtime.disco.create_socket_session_local_workers");
TVM_FFI_ICHECK(f_create_local_session.has_value())
<< "Cannot find function runtime.disco.create_socket_session_local_workers";
local_session_ = ((*f_create_local_session)(num_workers_per_node)).cast<BcastSession>();
DRef f_init_workers =
local_session_->GetGlobalFunc("runtime.disco.socket_session_init_workers");
local_session_->CallPacked(f_init_workers, num_nodes_, /*node_id=*/0, num_groups,
num_workers_per_node_);
Socket::Startup();
socket_.Create();
socket_.SetKeepAlive(true);
socket_.Bind(SockAddr(host.c_str(), port));
socket_.Listen();
LOG(INFO) << "SocketSession controller listening on " << host << ":" << port;
AnyView packed_args[4];
packed_args[0] = num_nodes;
packed_args[1] = num_workers_per_node;
packed_args[2] = num_groups;
for (int i = 0; i + 1 < num_nodes; ++i) {
SockAddr addr;
remote_sockets_.push_back(socket_.Accept(&addr));
remote_channels_.emplace_back(std::make_unique<DiscoSocketChannel>(remote_sockets_.back()));
packed_args[3] = i + 1;
// Send metadata to each remote node:
// - num_nodes
// - num_workers_per_node
// - num_groups
// - node_id
remote_channels_.back()->Send(ffi::PackedArgs(packed_args, 4));
LOG(INFO) << "Remote node " << addr.AsString() << " connected";
}
}
int64_t GetNumWorkers() final { return num_nodes_ * num_workers_per_node_; }
ffi::Any DebugGetFromRemote(int64_t reg_id, int worker_id) final {
int node_id = worker_id / num_workers_per_node_;
if (node_id == 0) {
return local_session_->DebugGetFromRemote(reg_id, worker_id);
} else {
AnyView packed_args[5];
ffi::PackedArgs::Fill(packed_args, static_cast<int>(DiscoSocketAction::kSend), worker_id,
static_cast<int>(DiscoAction::kDebugGetFromRemote), reg_id, worker_id);
remote_channels_[node_id - 1]->Send(ffi::PackedArgs(packed_args, 5));
ffi::PackedArgs args = this->RecvReplyPacked(worker_id);
TVM_FFI_ICHECK_EQ(args.size(), 2);
TVM_FFI_ICHECK(static_cast<DiscoAction>(args[0].cast<int>()) ==
DiscoAction::kDebugGetFromRemote);
ffi::Any result;
result = args[1];
return result;
}
}
void DebugSetRegister(int64_t reg_id, AnyView value, int worker_id) final {
int node_id = worker_id / num_workers_per_node_;
if (node_id == 0) {
local_session_->DebugSetRegister(reg_id, value, worker_id);
} else {
ObjectRef wrapped{nullptr};
if (auto opt_obj = value.as<ObjectRef>()) {
wrapped = DiscoDebugObject::Wrap(value);
value = wrapped;
}
{
AnyView packed_args[6];
ffi::PackedArgs::Fill(packed_args, static_cast<int>(DiscoSocketAction::kSend), worker_id,
static_cast<int>(DiscoAction::kDebugSetRegister), reg_id, worker_id,
value);
remote_channels_[node_id - 1]->Send(ffi::PackedArgs(packed_args, 6));
}
ffi::Any result;
ffi::PackedArgs args = this->RecvReplyPacked(worker_id);
TVM_FFI_ICHECK_EQ(args.size(), 1);
TVM_FFI_ICHECK(static_cast<DiscoAction>(args[0].cast<int>()) ==
DiscoAction::kDebugSetRegister);
}
}
void BroadcastPacked(const ffi::PackedArgs& args) final {
local_session_->BroadcastPacked(args);
std::vector<AnyView> packed_args(args.size() + 2);
ffi::PackedArgs::Fill(packed_args.data(), static_cast<int>(DiscoSocketAction::kSend), -1);
std::copy(args.data(), args.data() + args.size(), packed_args.begin() + 2);
for (auto& channel : remote_channels_) {
channel->Send(ffi::PackedArgs(packed_args.data(), packed_args.size()));
}
}
void SendPacked(int worker_id, const ffi::PackedArgs& args) final {
int node_id = worker_id / num_workers_per_node_;
if (node_id == 0) {
local_session_->SendPacked(worker_id, args);
return;
}
std::vector<AnyView> packed_args(args.size() + 2);
ffi::PackedArgs::Fill(packed_args.data(), static_cast<int>(DiscoSocketAction::kSend),
worker_id);
std::copy(args.data(), args.data() + args.size(), packed_args.begin() + 2);
remote_channels_[node_id - 1]->Send(ffi::PackedArgs(packed_args.data(), packed_args.size()));
}
ffi::PackedArgs RecvReplyPacked(int worker_id) final {
int node_id = worker_id / num_workers_per_node_;
if (node_id == 0) {
return local_session_->RecvReplyPacked(worker_id);
}
AnyView packed_args[2];
ffi::PackedArgs::Fill(packed_args, static_cast<int>(DiscoSocketAction::kReceive), worker_id);
remote_channels_[node_id - 1]->Send(ffi::PackedArgs(packed_args, 2));
return remote_channels_[node_id - 1]->Recv();
}
void AppendHostTensor(const Tensor& host_array) final {
local_session_->AppendHostTensor(host_array);
}
void Shutdown() final {
// local session will be implicitly shutdown by its destructor
std::vector<AnyView> packed_args(2);
ffi::PackedArgs::Fill(packed_args.data(), static_cast<int>(DiscoSocketAction::kShutdown), -1);
for (auto& channel : remote_channels_) {
channel->Send(ffi::PackedArgs(packed_args.data(), packed_args.size()));
}
for (auto& socket : remote_sockets_) {
socket.Close();
}
remote_sockets_.clear();
remote_channels_.clear();
if (!socket_.IsClosed()) {
socket_.Close();
}
Socket::Finalize();
}
~SocketSessionObj() { Shutdown(); }
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("runtime.disco.SocketSession", SocketSessionObj,
BcastSessionObj);
int num_nodes_;
int num_workers_per_node_;
TCPSocket socket_;
std::vector<TCPSocket> remote_sockets_;
std::vector<std::unique_ptr<DiscoSocketChannel>> remote_channels_;
BcastSession local_session_{nullptr};
};
class RemoteSocketSession {
public:
explicit RemoteSocketSession(const ffi::String& server_host, int server_port,
int num_local_workers) {
socket_.Create();
socket_.SetKeepAlive(true);
SockAddr server_addr{server_host.c_str(), server_port};
Socket::Startup();
if (!socket_.Connect(server_addr)) {
TVM_FFI_THROW(InternalError) << "Failed to connect to server " << server_addr.AsString()
<< ", errno = " << Socket::GetLastErrorCode();
}
channel_ = std::make_unique<DiscoSocketChannel>(socket_);
ffi::PackedArgs metadata = channel_->Recv();
TVM_FFI_ICHECK_EQ(metadata.size(), 4);
num_nodes_ = metadata[0].cast<int>();
num_workers_per_node_ = metadata[1].cast<int>();
num_groups_ = metadata[2].cast<int>();
node_id_ = metadata[3].cast<int>();
TVM_FFI_ICHECK_GE(num_local_workers, num_workers_per_node_);
InitLocalSession();
}
void MainLoop() {
while (true) {
ffi::PackedArgs args = channel_->Recv();
DiscoSocketAction action = static_cast<DiscoSocketAction>(args[0].cast<int>());
int worker_id = args[1].cast<int>();
int local_worker_id = worker_id - node_id_ * num_workers_per_node_;
switch (action) {
case DiscoSocketAction::kSend: {
args = args.Slice(2);
if (worker_id == -1) {
local_session_->BroadcastPacked(args);
} else {
local_session_->SendPacked(local_worker_id, args);
}
break;
}
case DiscoSocketAction::kReceive: {
args = local_session_->RecvReplyPacked(local_worker_id);
channel_->Reply(args);
break;
}
case DiscoSocketAction::kShutdown: {
local_session_->Shutdown();
LOG(INFO) << "Connection closed by remote controller.";
return;
}
default:
TVM_FFI_THROW(InternalError) << "Invalid action " << static_cast<int>(action);
}
}
}
~RemoteSocketSession() {
socket_.Close();
Socket::Finalize();
}
private:
void InitLocalSession() {
const auto f_create_local_session =
tvm::ffi::Function::GetGlobal("runtime.disco.create_socket_session_local_workers");
local_session_ = ((*f_create_local_session)(num_workers_per_node_)).cast<BcastSession>();
DRef f_init_workers =
local_session_->GetGlobalFunc("runtime.disco.socket_session_init_workers");
local_session_->CallPacked(f_init_workers, num_nodes_, node_id_, num_groups_,
num_workers_per_node_);
}
TCPSocket socket_;
BcastSession local_session_{nullptr};
std::unique_ptr<DiscoSocketChannel> channel_;
int num_nodes_{-1};
int node_id_{-1};
int num_groups_{-1};
int num_workers_per_node_{-1};
};
/*!
 * \brief Run a remote node: connect to the controller and serve it until shutdown.
 * \param server_host Host name or address of the controller.
 * \param server_port Port the controller listens on.
 * \param num_local_workers Number of workers available on this node.
 */
void RemoteSocketSessionEntryPoint(const ffi::String& server_host, int server_port,
                                   int num_local_workers) {
  // The session lives on the stack, so its destructor runs once MainLoop() returns.
  RemoteSocketSession remote_node(server_host, server_port, num_local_workers);
  remote_node.MainLoop();
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("runtime.disco.RemoteSocketSession", RemoteSocketSessionEntryPoint);
}
/*!
 * \brief Create a controller-side socket session.
 * \param num_nodes Total number of nodes, including the controller.
 * \param num_workers_per_node Number of workers on every node.
 * \param num_groups Number of worker groups.
 * \param host Address the controller listens on.
 * \param port Port the controller listens on.
 * \return A Session backed by a newly constructed SocketSessionObj.
 */
Session SocketSession(int num_nodes, int num_workers_per_node, int num_groups,
                      const ffi::String& host, int port) {
  return Session(
      ffi::make_object<SocketSessionObj>(num_nodes, num_workers_per_node, num_groups, host, port));
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<SocketSessionObj>();
refl::GlobalDef()
.def("runtime.disco.SocketSession", SocketSession)
.def("runtime.disco.socket_session_init_workers",
[](int num_nodes, int node_id, int num_groups, int num_workers_per_node) {
LOG(INFO) << "Initializing worker group with " << num_nodes << " nodes, "
<< num_workers_per_node << " workers per node, and " << num_groups
<< " groups.";
DiscoWorker* worker = DiscoWorker::ThreadLocal();
worker->num_groups = num_groups;
worker->worker_id = worker->worker_id + node_id * num_workers_per_node;
worker->num_workers = num_nodes * num_workers_per_node;
});
}
} // namespace runtime
} // namespace tvm