blob: aa75145ce2fc13d616c2a365ee64ef7420c58e8b [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file rpc_socket_impl.cc
* \brief Socket based RPC implementation.
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/support/io.h>
#include <tvm/support/serializer.h>
#include <memory>
#include "../../support/socket.h"
#include "rpc_endpoint.h"
#include "rpc_local_session.h"
#include "rpc_session.h"
namespace tvm {
namespace runtime {
/*!
 * \brief RPCChannel implementation that transfers bytes over a TCP socket.
 *
 * The channel closes the wrapped socket in its destructor, so it must be the
 * only owner of that descriptor. Copying is disabled: a copy would share the
 * same descriptor and close it a second time.
 */
class SockChannel final : public RPCChannel {
 public:
  /*!
   * \brief Wrap an already-connected socket.
   * \param sock The socket to take over; it is closed when this channel is destroyed.
   */
  explicit SockChannel(support::TCPSocket sock) : sock_(sock) {}

  // Copies would double-close the underlying descriptor.
  SockChannel(const SockChannel&) = delete;
  SockChannel& operator=(const SockChannel&) = delete;

  ~SockChannel() {
    try {
      // BadSocket can throw; a destructor must not let that escape.
      if (!sock_.BadSocket()) {
        sock_.Close();
      }
    } catch (...) {
    }
  }

  /*!
   * \brief Send up to size bytes.
   * \return Number of bytes actually sent (may be less than size).
   * \note A -1 result from the socket is reported through support::Socket::Error.
   */
  size_t Send(const void* data, size_t size) final {
    ssize_t n = sock_.Send(data, size);
    if (n == -1) {
      support::Socket::Error("SockChannel::Send");
    }
    return static_cast<size_t>(n);
  }

  /*!
   * \brief Receive up to size bytes.
   * \return Number of bytes actually received.
   * \note A -1 result from the socket is reported through support::Socket::Error.
   */
  size_t Recv(void* data, size_t size) final {
    ssize_t n = sock_.Recv(data, size);
    if (n == -1) {
      support::Socket::Error("SockChannel::Recv");
    }
    return static_cast<size_t>(n);
  }

 private:
  /*! \brief The owned socket; closed in the destructor. */
  support::TCPSocket sock_;
};
/*!
 * \brief Open a TCP connection to a remote RPC peer and perform the key handshake.
 *
 * Handshake as implemented here:
 * 1. Send kRPCMagic, the local key length and the local key bytes.
 * 2. Read back a status code: kRPCMagic means accepted, kRPCMagic + 1 means the
 *    server already has the key, and kRPCMagic + 2 means no matching server was
 *    found; anything else is rejected.
 * 3. On success, read the remote key length and the remote key bytes.
 *
 * \param url Host name or address of the remote peer.
 * \param port TCP port of the remote peer.
 * \param key Key sent to the peer during the handshake.
 * \param enable_logging Not used by this function.
 * \param init_seq Arguments forwarded to RPCEndpoint::InitRemoteSession.
 * \return The endpoint that owns the connected socket.
 * \throws InternalError if connecting fails, the handshake is rejected or
 *         malformed, or the peer is not a TVM RPC server. The socket is closed
 *         before any such failure propagates.
 */
std::shared_ptr<RPCEndpoint> RPCConnect(std::string url, int port, std::string key,
                                        bool enable_logging, ffi::PackedArgs init_seq) {
  support::TCPSocket sock;
  support::SockAddr addr(url.c_str(), port);
  sock.Create(addr.ss_family());
  std::string remote_key;
  try {
    TVM_FFI_ICHECK(sock.Connect(addr)) << "Connect to " << addr.AsString() << " failed";
    // hand shake
    int code = kRPCMagic;
    int keylen = static_cast<int>(key.length());
    TVM_FFI_ICHECK_EQ(sock.SendAll(&code, sizeof(code)), sizeof(code));
    TVM_FFI_ICHECK_EQ(sock.SendAll(&keylen, sizeof(keylen)), sizeof(keylen));
    if (keylen != 0) {
      TVM_FFI_ICHECK_EQ(sock.SendAll(key.c_str(), keylen), keylen);
    }
    TVM_FFI_ICHECK_EQ(sock.RecvAll(&code, sizeof(code)), sizeof(code));
    if (code == kRPCMagic + 2) {
      TVM_FFI_THROW(InternalError) << "URL " << url << ":" << port
                                   << " cannot find server that matches key=" << key;
    } else if (code == kRPCMagic + 1) {
      TVM_FFI_THROW(InternalError) << "URL " << url << ":" << port
                                   << " server already have key=" << key;
    } else if (code != kRPCMagic) {
      TVM_FFI_THROW(InternalError) << "URL " << url << ":" << port << " is not TVM RPC server";
    }
    TVM_FFI_ICHECK_EQ(sock.RecvAll(&keylen, sizeof(keylen)), sizeof(keylen));
    // keylen comes from the peer: a negative value would turn into a huge
    // size_t in resize() below, so reject it explicitly.
    TVM_FFI_ICHECK(keylen >= 0) << "URL " << url << ":" << port
                                << " sent an invalid remote key length " << keylen;
    if (keylen != 0) {
      remote_key.resize(keylen);
      TVM_FFI_ICHECK_EQ(sock.RecvAll(&remote_key[0], keylen), keylen);
    }
  } catch (...) {
    // Nothing owns the descriptor yet, so release it before propagating.
    // Swallow a failure from Close itself so the original error is preserved.
    try {
      sock.Close();
    } catch (...) {
    }
    throw;
  }
  // From here on the SockChannel owns the socket and closes it on destruction.
  std::unique_ptr<RPCChannel> channel = std::make_unique<SockChannel>(sock);
  auto endpt = RPCEndpoint::Create(std::move(channel), key, remote_key);
  endpt->InitRemoteSession(init_seq);
  return endpt;
}
/*!
 * \brief Connect to a remote RPC peer as a client and wrap the session in a module.
 *
 * The key sent during the handshake is the given key prefixed with "client:".
 *
 * \param url Host name or address of the remote peer.
 * \param port TCP port of the remote peer.
 * \param key Key (without the "client:" prefix).
 * \param enable_logging Forwarded to RPCConnect.
 * \param init_seq Forwarded to RPCConnect for remote session initialization.
 * \return A module backed by the client RPC session.
 */
ffi::Module RPCClientConnect(std::string url, int port, std::string key, bool enable_logging,
                             ffi::PackedArgs init_seq) {
  std::string client_key = "client:" + key;
  auto endpoint = RPCConnect(url, port, client_key, enable_logging, init_seq);
  auto session = CreateClientSession(endpoint);
  return CreateRPCSessionModule(std::move(session));
}
// TVM_DLL needed for MSVC
/*!
 * \brief Run the RPC server loop on an already-open socket descriptor.
 *
 * The SockChannel created here closes the socket in its destructor.
 *
 * \param sockfd Descriptor of the socket to serve on.
 */
TVM_DLL void RPCServerLoop(int sockfd) {
  auto native_fd = static_cast<support::TCPSocket::SockType>(sockfd);
  support::TCPSocket server_sock(native_fd);
  std::unique_ptr<RPCChannel> channel = std::make_unique<SockChannel>(server_sock);
  auto endpoint = RPCEndpoint::Create(std::move(channel), "SockServerLoop", "");
  endpoint->ServerLoop();
}
/*!
 * \brief Run the RPC server loop over a pair of user-provided I/O callbacks.
 *
 * \param fsend Callback used by the CallbackChannel to send bytes.
 * \param frecv Callback used by the CallbackChannel to receive bytes.
 */
void RPCServerLoop(ffi::Function fsend, ffi::Function frecv) {
  std::unique_ptr<RPCChannel> channel = std::make_unique<CallbackChannel>(fsend, frecv);
  auto endpoint = RPCEndpoint::Create(std::move(channel), "SockServerLoop", "");
  endpoint->ServerLoop();
}
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef()
      // rpc.Connect(url, port, key, enable_logging, *init_seq) -> client session module.
      .def_packed("rpc.Connect",
                  [](ffi::PackedArgs args, ffi::Any* rv) {
                    const std::string url = args[0].cast<std::string>();
                    const int port = args[1].cast<int>();
                    const std::string key = args[2].cast<std::string>();
                    const bool enable_logging = args[3].cast<bool>();
                    // Everything after the first four arguments is the init sequence.
                    auto init_seq = args.Slice(4);
                    *rv = RPCClientConnect(url, port, key, enable_logging, init_seq);
                  })
      // rpc.ServerLoop(sockfd) or rpc.ServerLoop(fsend, frecv).
      .def_packed("rpc.ServerLoop", [](ffi::PackedArgs args, ffi::Any*) {
        auto sockfd = args[0].as<int64_t>();
        if (!sockfd) {
          // First argument is not an integer: treat the arguments as send/recv callbacks.
          RPCServerLoop(args[0].cast<tvm::ffi::Function>(), args[1].cast<tvm::ffi::Function>());
          return;
        }
        RPCServerLoop(static_cast<int>(sockfd.value()));
      });
}
class SimpleSockHandler : public support::Stream {
// Things that will interface with user directly.
public:
explicit SimpleSockHandler(int sockfd)
: sock_(static_cast<support::TCPSocket::SockType>(sockfd)) {}
using support::Stream::Read;
using support::Stream::ReadArray;
using support::Stream::Write;
using support::Stream::WriteArray;
// Unused here, implemented for microTVM framing layer.
void MessageStart(uint64_t packet_nbytes) {}
void MessageDone() {}
// Internal supporting.
// Override methods that inherited from support::Stream.
private:
size_t Read(void* data, size_t size) final { return sock_.Recv(data, size); }
size_t Write(const void* data, size_t size) final { return sock_.Send(data, size); }
// Things of current class.
private:
support::TCPSocket sock_;
};
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("rpc.ReturnException", [](int sockfd, ffi::String msg) {
auto handler = SimpleSockHandler(sockfd);
RPCReference::ReturnException(msg.c_str(), &handler);
return;
});
}
} // namespace runtime
} // namespace tvm