blob: 217baf133bf17015f0c3827b377a238a40089501 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file rpc_server.cc
* \brief RPC Server implementation.
*/
#include <tvm/runtime/registry.h>
#if defined(__linux__) || defined(__ANDROID__) || defined(__APPLE__)
#include <signal.h>
#include <sys/select.h>
#include <sys/wait.h>
#endif
#include <cerrno>
#include <chrono>
#include <future>
#include <iostream>
#include <set>
#include <string>
#include <thread>
#include "../../src/runtime/rpc/rpc_endpoint.h"
#include "../../src/runtime/rpc/rpc_socket_impl.h"
#include "../../src/support/socket.h"
#include "rpc_env.h"
#include "rpc_server.h"
#include "rpc_tracker_client.h"
#if defined(_WIN32)
#include "win32_process.h"
#endif
using namespace std::chrono;
namespace tvm {
namespace runtime {
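/*!
* \brief Wait for any child process to end, retrying when the wait is
* interrupted by a signal (EINTR). Aborts the process on any other
* waitpid() failure.
* \param status Receives the status value reported by waitpid().
* \return The pid of the child process that ended.
*/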
#if defined(__linux__) || defined(__ANDROID__) || defined(__APPLE__)
static pid_t waitPidEintr(int* status) {
  for (;;) {
    // Block until any child changes state.
    const pid_t reaped = waitpid(-1, status, 0);
    if (reaped != -1) {
      return reaped;
    }
    // A signal interrupting the wait is not an error: simply try again.
    if (errno != EINTR) {
      perror("waitpid");
      abort();
    }
  }
}
#endif
#ifdef __ANDROID__
/*!
 * \brief Extract the next whitespace-delimited token from a string stream.
 *
 * Works on the stream's underlying string starting at the current read
 * position, and moves the read position just past the returned token.
 *
 * \param iss The stream to read from.
 * \return The next token, or an empty string when no token is left or the
 *         stream is in a failed state.
 */
static std::string getNextString(std::stringstream* iss) {
  const std::string str = iss->str();
  // tellg() reports -1 on a failed stream; converting that to size_t would
  // yield a huge offset and make substr() throw, so bail out early instead.
  const std::streamoff offset = iss->tellg();
  if (offset < 0) {
    return std::string();
  }
  const size_t len = str.size();
  // isspace() is only defined for values representable as unsigned char (or
  // EOF); bytes >= 0x80 must therefore be converted before the call.
  const auto is_space = [](char c) { return isspace(static_cast<unsigned char>(c)) != 0; };
  size_t start = static_cast<size_t>(offset);
  // Skip leading spaces.
  while (start < len && is_space(str[start])) ++start;
  size_t end = start;
  while (end < len && !is_space(str[end])) ++end;
  iss->seekg(static_cast<std::streamoff>(end));
  return str.substr(start, end - start);
}
#endif
/*!
* \brief RPCServer RPC Server class.
*
* \param host The hostname of the server, Default=0.0.0.0
*
* \param port_search_start The low end of the search range for an
* available port for the RPC, Default=9090
*
* \param port_search_end The high end of the search range for an
* available port for the RPC, Default=9099
*
* \param tracker_addr The address of RPC tracker in host:port format
* (e.g. "10.77.1.234:9190")
*
* \param key The key used to identify the device type in tracker.
*
* \param custom_addr Custom IP Address to Report to RPC Tracker.
*
* \param work_dir The working directory passed to the RPC environment
* (RPCEnv) when a connection is served.
*/
class RPCServer {
public:
/*!
* \brief Constructor.
*/
RPCServer(std::string host, int port_search_start, int port_search_end, std::string tracker_addr,
          std::string key, std::string custom_addr, std::string work_dir)
    // String arguments are taken by value and moved into the members.
    // The initializers follow the member declaration order; my_port_ stays 0
    // until Start() stores the port returned by TryBindHost().
    : host_{std::move(host)},
      port_search_start_{port_search_start},
      my_port_{0},
      port_search_end_{port_search_end},
      tracker_addr_{std::move(tracker_addr)},
      key_{std::move(key)},
      custom_addr_{std::move(custom_addr)},
      work_dir_{std::move(work_dir)} {}
/*!
* \brief Destructor.
*/
~RPCServer() {
  // Free the resources. Each socket is closed in its own try block: with a
  // single shared try, an exception thrown while closing the first socket
  // would skip closing the second one and leave it open.
  const auto close_quietly = [](support::TCPSocket* sock) {
    try {
      sock->Close();
    } catch (...) {
      // Best-effort cleanup: a destructor must not let exceptions escape.
    }
  };
  close_quietly(&tracker_sock_);
  close_quietly(&listen_sock_);
}
/*!
* \brief Start Creates the RPC listen process and execution.
*/
void Start() {
listen_sock_.Create();
my_port_ = listen_sock_.TryBindHost(host_, port_search_start_, port_search_end_);
LOG(INFO) << "bind to " << host_ << ":" << my_port_;
listen_sock_.Listen(1);
std::future<void> proc(std::async(std::launch::async, &RPCServer::ListenLoopProc, this));
proc.get();
// Close the listen socket
listen_sock_.Close();
}
private:
/*!
* \brief ListenLoopProc The listen process.
*/
void ListenLoopProc() {
TrackerClient tracker(tracker_addr_, key_, custom_addr_, my_port_);
while (true) {
support::TCPSocket conn;
support::SockAddr addr("0.0.0.0", 0);
std::string opts;
try {
// step 1: setup tracker and report to tracker
tracker.TryConnect();
// step 2: wait for in-coming connections
AcceptConnection(&tracker, &conn, &addr, &opts);
} catch (const char* msg) {
LOG(WARNING) << "Socket exception: " << msg;
// close tracker resource
tracker.Close();
continue;
} catch (const std::exception& e) {
// close tracker resource
tracker.Close();
LOG(WARNING) << "Exception standard: " << e.what();
continue;
}
int timeout = GetTimeOutFromOpts(opts);
#if defined(__linux__) || defined(__ANDROID__) || defined(__APPLE__)
// step 3: serving
if (timeout != 0) {
const pid_t timer_pid = fork();
if (timer_pid == 0) {
// Timer process
sleep(timeout);
_exit(0);
}
const pid_t worker_pid = fork();
if (worker_pid == 0) {
// Worker process
ServerLoopProc(conn, addr, work_dir_);
_exit(0);
}
int status = 0;
const pid_t finished_first = waitPidEintr(&status);
if (finished_first == timer_pid) {
kill(worker_pid, SIGTERM);
} else if (finished_first == worker_pid) {
kill(timer_pid, SIGTERM);
} else {
LOG(INFO) << "Child pid=" << finished_first << " unexpected, but still continue.";
}
int status_second = 0;
waitPidEintr(&status_second);
// Logging.
if (finished_first == timer_pid) {
LOG(INFO) << "Child pid=" << worker_pid << " killed (timeout = " << timeout
<< "), Process status = " << status_second;
} else if (finished_first == worker_pid) {
LOG(INFO) << "Child pid=" << timer_pid << " killed, Process status = " << status_second;
}
} else {
auto pid = fork();
if (pid == 0) {
ServerLoopProc(conn, addr, work_dir_);
_exit(0);
}
// Wait for the result
int status = 0;
wait(&status);
LOG(INFO) << "Child pid=" << pid << " exited, Process status =" << status;
}
#elif defined(WIN32)
auto start_time = high_resolution_clock::now();
try {
SpawnRPCChild(conn.sockfd, seconds(timeout));
} catch (const std::exception&) {
}
auto dur = high_resolution_clock::now() - start_time;
LOG(INFO) << "Serve Time " << duration_cast<milliseconds>(dur).count() << "ms";
#else
LOG(WARNING) << "Unknown platform. It is not known how to bring up the subprocess."
<< " RPC will be launched in the main thread.";
ServerLoopProc(conn, addr, work_dir_);
#endif
// close from our side.
LOG(INFO) << "Socket Connection Closed";
conn.Close();
}
}
/*!
* \brief AcceptConnection Accepts the RPC Server connection.
* \param tracker Tracker details.
* \param conn_sock New connection information.
* \param addr New connection address information.
* \param opts Parsed options for socket
* \param ping_period Timeout for select call waiting
*/
void AcceptConnection(TrackerClient* tracker, support::TCPSocket* conn_sock,
                      support::SockAddr* addr, std::string* opts, int ping_period = 2) {
  const char* CLIENT_HEADER = "client:";
  const char* SERVER_HEADER = "server:";
  // key_ does not change while accepting, so the reply key is built once.
  const std::string server_key = SERVER_HEADER + key_;
  std::string matchkey;
  // Report resource to tracker and get key
  tracker->ReportResourceAndGetKey(my_port_, &matchkey);

  while (true) {
    tracker->WaitConnectionAndUpdateKey(listen_sock_, my_port_, ping_period, &matchkey);
    support::TCPSocket conn = listen_sock_.Accept(addr);
    try {
      int code = kRPCMagic;
      ICHECK_EQ(conn.RecvAll(&code, sizeof(code)), sizeof(code));
      if (code != kRPCMagic) {
        conn.Close();
        // Logged as a warning so that the `continue` below is reachable and the
        // server simply waits for the next client.
        LOG(WARNING) << "Client connected is not TVM RPC server";
        continue;
      }

      int keylen = 0;
      ICHECK_EQ(conn.RecvAll(&keylen, sizeof(keylen)), sizeof(keylen));
      // matchkey may have been refreshed by the tracker, so rebuild per client.
      const std::string expect_header = CLIENT_HEADER + matchkey;
      // keylen comes from the peer: reject non-positive values explicitly, since a
      // negative int would otherwise turn into a huge size_t and reach resize().
      if (keylen <= 0 || static_cast<size_t>(keylen) < expect_header.length()) {
        conn.Close();
        LOG(INFO) << "Wrong client header length";
        continue;
      }

      std::string remote_key;
      remote_key.resize(keylen);
      ICHECK_EQ(conn.RecvAll(&remote_key[0], keylen), keylen);

      // First token: "client:<matchkey>"; second token: the option string.
      std::stringstream ssin(remote_key);
      std::string arg0;
#ifndef __ANDROID__
      ssin >> arg0;
#else
      arg0 = getNextString(&ssin);
#endif

      if (arg0 != expect_header) {
        code = kRPCMismatch;
        ICHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code));
        conn.Close();
        LOG(WARNING) << "Mismatch key from " << addr->AsString();
        continue;
      }

      code = kRPCSuccess;
      ICHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code));
      keylen = int(server_key.length());
      ICHECK_EQ(conn.SendAll(&keylen, sizeof(keylen)), sizeof(keylen));
      ICHECK_EQ(conn.SendAll(server_key.c_str(), keylen), keylen);
      LOG(INFO) << "Connection success " << addr->AsString();
#ifndef __ANDROID__
      ssin >> *opts;
#else
      *opts = getNextString(&ssin);
#endif
      // Hand the accepted socket to the caller, who closes it after serving.
      *conn_sock = conn;
      return;
    } catch (...) {
      // A failed handshake step must not leak the accepted socket: it is only
      // a local copy here and nobody else would close it.
      try {
        conn.Close();
      } catch (...) {
      }
      throw;
    }
  }
}
/*!
* \brief ServerLoopProc The Server loop process.
* \param sock The socket information
* \param addr The socket address information
* \param work_dir The working directory passed to RPCEnv
*/
static void ServerLoopProc(support::TCPSocket sock, support::SockAddr addr,
                           std::string work_dir) {
  // Set up the RPC environment for the given working directory.
  const RPCEnv env(work_dir);
  // Serve on the connection's socket descriptor until the loop returns.
  const int fd = static_cast<int>(sock.sockfd);
  RPCServerLoop(fd);
  LOG(INFO) << "Finish serving " << addr.AsString();
  // Clean up the environment once serving is done.
  env.CleanUp();
}
/*!
* \brief GetTimeOutFromOpts Parse and get the timeout option.
* \param opts The option string
* \return The number following the last "-timeout=" in opts, or 0 when
* the option is absent.
*/
int GetTimeOutFromOpts(const std::string& opts) const {
  const std::string flag("-timeout=");
  // Use the last occurrence of the flag.
  const size_t flag_at = opts.rfind(flag);
  if (flag_at == std::string::npos) {
    // No timeout requested.
    return 0;
  }
  // Everything after the flag must be a number.
  const std::string value = opts.substr(flag_at + flag.size());
  ICHECK(support::IsNumber(value)) << "Timeout is not valid";
  return std::stoi(value);
}
/*! \brief Host passed to TryBindHost() when the listening socket is bound. */
std::string host_;
/*! \brief Start of the port search range passed to TryBindHost(). */
int port_search_start_;
/*! \brief Port returned by TryBindHost() in Start(); 0 before that. Also handed to the tracker client. */
int my_port_;
/*! \brief End of the port search range passed to TryBindHost(). */
int port_search_end_;
/*! \brief Tracker address forwarded to TrackerClient. */
std::string tracker_addr_;
/*! \brief Key forwarded to TrackerClient and sent to clients prefixed with "server:". */
std::string key_;
/*! \brief Custom address forwarded to TrackerClient. */
std::string custom_addr_;
/*! \brief Working directory handed to ServerLoopProc(), which passes it to RPCEnv. */
std::string work_dir_;
/*! \brief Listening socket: created, bound and put into listen mode in Start(). */
support::TCPSocket listen_sock_;
/*!
 * \brief NOTE(review): only closed in the destructor; no visible code in this
 * class creates or otherwise uses this socket - confirm whether it is needed.
 */
support::TCPSocket tracker_sock_;
};
#if defined(WIN32)
/*!
* \brief ServerLoopFromChild The Server loop process.
* \param socket The socket information
*/
void ServerLoopFromChild(SOCKET socket) {
  // Wrap the raw socket handle received by the child.
  tvm::support::TCPSocket child_sock(socket);
  // RPC environment with its default settings.
  const RPCEnv env{};
  // Server loop
  const int fd = static_cast<int>(child_sock.sockfd);
  RPCServerLoop(fd);
  // Close the connection first, then clean up the environment.
  child_sock.Close();
  env.CleanUp();
}
#endif
/*!
* \brief RPCServerCreate Creates the RPC Server.
* \param host The hostname of the server, Default=0.0.0.0
* \param port The port of the RPC, Default=9090
* \param port_end The end search port of the RPC, Default=9099
* \param tracker_addr The address of RPC tracker in host:port format e.g. 10.77.1.234:9190
* Default=""
* \param key The key used to identify the device type in tracker. Default=""
* \param custom_addr Custom IP Address to Report to RPC Tracker. Default=""
* \param work_dir The working directory passed on to the RPC server.
* \param silent Whether run in silent mode. Default=True
*/
void RPCServerCreate(std::string host, int port, int port_end, std::string tracker_addr,
std::string key, std::string custom_addr, std::string work_dir, bool silent) {
if (silent) {
// Only errors and fatal is logged
dmlc::InitLogging("--minloglevel=2");
}
// Start the rpc server
RPCServer rpc(std::move(host), port, port_end, std::move(tracker_addr), std::move(key),
std::move(custom_addr), std::move(work_dir));
rpc.Start();
}
TVM_REGISTER_GLOBAL("rpc.ServerCreate").set_body([](TVMArgs args, TVMRetValue* rv) {
  // Unpack the eight positional arguments in the order RPCServerCreate declares them.
  std::string host = args[0];
  int port = args[1];
  int port_end = args[2];
  std::string tracker_addr = args[3];
  std::string key = args[4];
  std::string custom_addr = args[5];
  std::string work_dir = args[6];
  bool silent = args[7];
  RPCServerCreate(std::move(host), port, port_end, std::move(tracker_addr), std::move(key),
                  std::move(custom_addr), std::move(work_dir), silent);
});
} // namespace runtime
} // namespace tvm