blob: a94c45fc94ce57da3fa8cbc8b3cf08aa1acf197f [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file rpc_server.cc
* \brief RPC Server for TVM.
*/
#include <csignal>
#include <cstdio>
#include <cstdlib>
#if defined(__linux__) || defined(__ANDROID__)
#include <unistd.h>
#endif
#include <dmlc/logging.h>
#include <cstring>
#include <iostream>
#include <sstream>
#include <vector>
#include "../../src/support/socket.h"
#include "../../src/support/utils.h"
#include "rpc_server.h"
#if defined(_WIN32)
#include "win32_process.h"
#endif
using namespace std;
using namespace tvm::runtime;
using namespace tvm::support;
// Help text logged when no/unknown command is given or an argument fails validation.
static const string kUsage =
    "Command line usage\n"
    " server - Start the server\n"
    "--host - The hostname of the server, Default=0.0.0.0\n"
    "--port - The port of the RPC, Default=9090\n"
    "--port-end - The end search port of the RPC, Default=9099\n"
    "--tracker - The RPC tracker address in host:port format e.g. 10.1.1.2:9190 Default=\"\"\n"
    "--key - The key used to identify the device type in tracker. Default=\"\"\n"
    "--custom-addr - Custom IP Address to Report to RPC Tracker. Default=\"\"\n"
    "--work-dir - Custom work directory. Default=\"\"\n"
    "--silent - Whether to run in silent mode. Default=False\n"
    "\n Example\n"
    " ./tvm_rpc server --host=0.0.0.0 --port=9000 --port-end=9090  --tracker=127.0.0.1:9190 --key=rasp\n";
/*!
 * \brief RpcServerArgs Settings for the RPC server, filled from the command line by ParseCmdArgs.
 * \arg host Address the server binds to. Default="0.0.0.0"
 * \arg port First port to try. Default=9090
 * \arg port_end Last port of the search range. Default=9099
 * \arg tracker RPC tracker address; given as host:port (e.g. 10.77.1.234:9190). Default=""
 * \arg key Key identifying the device type in the tracker. Default=""
 * \arg custom_addr IP address reported to the RPC tracker instead of the detected one. Default=""
 * \arg work_dir Custom work directory. Default=""
 * \arg silent Whether to run in silent mode. Default=false
 * \arg mmap_path (Windows only) value of --child_proc=; empty unless given.
 *      NOTE(review): guarded by WIN32 while win32_process.h is included under _WIN32;
 *      presumably the build defines WIN32 on Windows — confirm for non-MSVC toolchains.
 */
struct RpcServerArgs {
  std::string host{"0.0.0.0"};
  int port{9090};
  int port_end{9099};
  std::string tracker{};
  std::string key{};
  std::string custom_addr{};
  std::string work_dir{};
  bool silent{false};
#if defined(WIN32)
  std::string mmap_path{};
#endif
};
/*!
 * \brief PrintArgs Log the user-visible fields of RpcServerArgs at INFO level,
 *        one "name = value" line per field (silent is shown as True/False).
 * \param args The parsed server arguments.
 */
void PrintArgs(const RpcServerArgs& args) {
  const char* const silent_flag = args.silent ? "True" : "False";
  LOG(INFO) << "host = " << args.host;
  LOG(INFO) << "port = " << args.port;
  LOG(INFO) << "port_end = " << args.port_end;
  LOG(INFO) << "tracker = " << args.tracker;
  LOG(INFO) << "key = " << args.key;
  LOG(INFO) << "custom_addr = " << args.custom_addr;
  LOG(INFO) << "work_dir = " << args.work_dir;
  LOG(INFO) << "silent = " << silent_flag;
}
#if defined(__linux__) || defined(__ANDROID__)
/*!
 * \brief CtrlCHandler SIGINT handler: reports that Ctrl+C was pressed and terminates the
 *        process with exit status 1.
 * \param s The signal number (unused).
 * \note Only async-signal-safe calls (write, _exit) are made. Logging streams and exit()
 *       (which runs atexit handlers and static destructors) may need locks that the
 *       interrupted thread already holds, which can deadlock inside a signal handler.
 *       Because the message goes straight to stderr, it is not filtered by the log level.
 */
void CtrlCHandler(int s) {
  (void)s;
  static const char kMessage[] = "\nUser pressed Ctrl+C, Exiting\n";
  // The result is deliberately ignored: there is nothing useful to do on failure while exiting.
  const ssize_t written = write(STDERR_FILENO, kMessage, sizeof(kMessage) - 1);
  (void)written;
  _exit(1);
}
/*!
* \brief HandleCtrlC Register for handling Ctrl+C event.
*/
void HandleCtrlC() {
// Ctrl+C handler
struct sigaction sigIntHandler;
sigIntHandler.sa_handler = CtrlCHandler;
sigemptyset(&sigIntHandler.sa_mask);
sigIntHandler.sa_flags = 0;
sigaction(SIGINT, &sigIntHandler, nullptr);
}
#endif
/*!
 * \brief GetCmdOption Return the first command line argument (argv[1..]) that begins with option.
 * \param argc arg counter
 * \param argv arg values; argv[0] is skipped
 * \param option prefix to look for; must end with '=' when key is false (checked with ICHECK)
 * \param key when true, return the whole matching argument instead of its value
 * \return the whole matching argument if key is true, otherwise the text after the first '='
 *         of that argument; an empty string when no argument begins with option.
 */
string GetCmdOption(int argc, char* argv[], string option, bool key = false) {
  for (int idx = 1; idx < argc; ++idx) {
    const string current = argv[idx];
    if (current.find(option) != 0) {
      continue;
    }
    if (key) {
      return current;
    }
    // Value options are spelled "--name=", so '=' marks where the value starts.
    ICHECK_EQ(*option.rbegin(), '=');
    return current.substr(current.find('=') + 1);
  }
  return string();
}
/*!
 * \brief ValidateTracker Check that the tracker address has the form "<ip>:<port>" with a port
 *        no larger than 65535 and, on success, rewrite it in place as "('<ip>', <port>)".
 * \param tracker The tracker input; modified only when validation succeeds.
 * \return true if the address is valid, false otherwise.
 */
bool ValidateTracker(string& tracker) {
  vector<string> list = Split(tracker, ':');
  if ((list.size() != 2) || (!ValidateIP(list[0])) || (!IsNumber(list[1]))) {
    return false;
  }
  // Reject ports beyond the TCP range, like the --port/--port-end checks do. Accumulate digit by
  // digit and stop early so that arbitrarily long digit strings cannot overflow.
  int port = 0;
  for (const char c : list[1]) {
    if (c < '0' || c > '9') {
      return false;
    }
    port = port * 10 + (c - '0');
    if (port > 65535) {
      return false;
    }
  }
  ostringstream ss;
  ss << "('" << list[0] << "', " << list[1] << ")";
  tracker = ss.str();
  return true;
}
/*!
 * \brief ExitWithUsage Report an invalid command line argument, print the usage and exit(1).
 * \param message Warning text describing the invalid argument.
 */
[[noreturn]] static void ExitWithUsage(const char* message) {
  LOG(WARNING) << message;
  LOG(INFO) << kUsage;
  exit(1);
}

/*!
 * \brief ParsePortNumber Convert a decimal string into a TCP port number without throwing.
 * \param text The string to convert.
 * \return The port in [0, 65535], or -1 when text is empty, contains a character other than
 *         '0'-'9', or its value exceeds 65535.
 */
static int ParsePortNumber(const string& text) {
  if (text.empty() || text.find_first_not_of("0123456789") != string::npos) {
    return -1;
  }
  int value = 0;
  for (const char c : text) {
    value = value * 10 + (c - '0');
    // Bail out as soon as the range is exceeded: long inputs never overflow here, whereas
    // std::stoi would throw std::out_of_range and terminate instead of printing the usage.
    if (value > 65535) {
      return -1;
    }
  }
  return value;
}

/*!
 * \brief ParseCmdArgs parses the command line arguments.
 *        On an invalid value a warning and the usage text are logged and the process exits
 *        with status 1.
 * \param argc arg counter
 * \param argv arg values
 * \param args the output structure which holds the parsed values
 */
void ParseCmdArgs(int argc, char* argv[], struct RpcServerArgs& args) {
  const string silent = GetCmdOption(argc, argv, "--silent", true);
  if (!silent.empty()) {
    args.silent = true;
    // Only errors and fatal is logged
    dmlc::InitLogging("--minloglevel=2");
  }
  const string host = GetCmdOption(argc, argv, "--host=");
  if (!host.empty()) {
    if (!ValidateIP(host)) {
      ExitWithUsage("Wrong host address format.");
    }
    args.host = host;
  }
  const string port = GetCmdOption(argc, argv, "--port=");
  if (!port.empty()) {
    const int port_value = ParsePortNumber(port);
    if (port_value < 0) {
      ExitWithUsage("Wrong port number.");
    }
    args.port = port_value;
  }
  const string port_end = GetCmdOption(argc, argv, "--port-end=");
  if (!port_end.empty()) {
    const int port_end_value = ParsePortNumber(port_end);
    if (port_end_value < 0) {
      ExitWithUsage("Wrong port-end number.");
    }
    args.port_end = port_end_value;
  }
  string tracker = GetCmdOption(argc, argv, "--tracker=");
  if (!tracker.empty()) {
    // ValidateTracker also rewrites tracker into the format stored in args.
    if (!ValidateTracker(tracker)) {
      ExitWithUsage("Wrong tracker address format.");
    }
    args.tracker = tracker;
  }
  const string key = GetCmdOption(argc, argv, "--key=");
  if (!key.empty()) {
    args.key = key;
  }
  const string custom_addr = GetCmdOption(argc, argv, "--custom-addr=");
  if (!custom_addr.empty()) {
    if (!ValidateIP(custom_addr)) {
      ExitWithUsage("Wrong custom address format.");
    }
    args.custom_addr = custom_addr;
  }
#if defined(WIN32)
  const string mmap_path = GetCmdOption(argc, argv, "--child_proc=");
  if (!mmap_path.empty()) {
    args.mmap_path = mmap_path;
    dmlc::InitLogging("--minloglevel=0");
  }
#endif
  const string work_dir = GetCmdOption(argc, argv, "--work-dir=");
  if (!work_dir.empty()) {
    args.work_dir = work_dir;
  }
}
/*!
 * \brief RpcServer Parse the arguments, log them and start the RPC server.
 * \param argc arg counter
 * \param argv arg values
 * \return 0 on normal return; -1 (Windows child-process mode only) if
 *         ChildProcSocketHandler throws a std::exception, which is logged at ERROR level.
 */
int RpcServer(int argc, char* argv[]) {
  RpcServerArgs args;
  /* parse the command line args */
  ParseCmdArgs(argc, argv, args);
  PrintArgs(args);
  LOG(INFO) << "Starting CPP Server, Press Ctrl+C to stop.";
#if defined(__linux__) || defined(__ANDROID__)
  // Ctrl+C handler
  HandleCtrlC();
#endif
#if defined(WIN32)
  // mmap_path comes from --child_proc=; presumably set when this process was spawned as a
  // child by the Windows server — confirm against win32_process.
  if (!args.mmap_path.empty()) {
    try {
      ChildProcSocketHandler(args.mmap_path);
    } catch (const std::exception& e) {
      // Report the failure instead of silently returning an error code.
      LOG(ERROR) << "Child process socket handler failed: " << e.what();
      return -1;
    }
    return 0;
  }
#endif
  RPCServerCreate(args.host, args.port, args.port_end, args.tracker, args.key, args.custom_addr,
                  args.work_dir, args.silent);
  return 0;
}
/*!
 * \brief main Entry point of tvm_rpc; dispatches on the command given in argv[1].
 * \param argc arg counter
 * \param argv arg values; argv[1] == "server" starts the RPC server
 * \return the result of RpcServer for "server"; otherwise 0 after logging the usage text.
 */
int main(int argc, char* argv[]) {
  if (argc > 1) {
    // Runs WSAStartup on Win32, no-op on POSIX
    Socket::Startup();
#if defined(_WIN32)
    SetEnvironmentVariableA("CUDA_CACHE_DISABLE", "1");
#endif
    const string command = argv[1];
    if (command == "server") {
      return RpcServer(argc, argv);
    }
  }
  // No command, or an unknown one.
  LOG(INFO) << kUsage;
  return 0;
}