| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file rpc_server.cc |
| * \brief RPC Server for TVM. |
| */ |
| #include <csignal> |
| #include <cstdio> |
| #include <cstdlib> |
| #if defined(__linux__) || defined(__ANDROID__) |
| #include <unistd.h> |
| #endif |
| #include <dmlc/logging.h> |
| |
| #include <cstring> |
| #include <iostream> |
| #include <sstream> |
| #include <vector> |
| |
| #include "../../src/support/socket.h" |
| #include "../../src/support/utils.h" |
| #include "rpc_server.h" |
| |
| #if defined(_WIN32) |
| #include "win32_process.h" |
| #endif |
| |
| using namespace std; |
| using namespace tvm::runtime; |
| using namespace tvm::support; |
| |
/*!
 * \brief kUsage Command-line help text for tvm_rpc.
 *
 * Logged at INFO level when no command is given, when the first argument is not
 * "server", and when an option value fails validation in ParseCmdArgs.
 * The Windows-only "--child_proc=" option read by ParseCmdArgs is intentionally
 * not listed here.
 */
static const string kUsage =
    "Command line usage\n"
    " server - Start the server\n"
    "--host - The hostname of the server, Default=0.0.0.0\n"
    "--port - The port of the RPC, Default=9090\n"
    "--port-end - The end search port of the RPC, Default=9099\n"
    "--tracker - The RPC tracker address in host:port format e.g. 10.1.1.2:9190 Default=\"\"\n"
    "--key - The key used to identify the device type in tracker. Default=\"\"\n"
    "--custom-addr - Custom IP Address to Report to RPC Tracker. Default=\"\"\n"
    "--work-dir - Custom work directory. Default=\"\"\n"
    "--silent - Whether to run in silent mode. Default=False\n"
    "\n"
    " Example\n"
    " ./tvm_rpc server --host=0.0.0.0 --port=9000 --port-end=9090 "
    " --tracker=127.0.0.1:9190 --key=rasp"
    "\n";
| |
/*!
 * \brief RpcServerArgs Settings gathered from the command line for the RPC server.
 *
 * Every member starts at the default shown in the usage text; ParseCmdArgs
 * overwrites a member only when the matching option is present.
 * \arg host        Address the server listens on, default "0.0.0.0".
 * \arg port        First port to try, default 9090.
 * \arg port_end    End of the port search range, default 9099.
 * \arg tracker     Tracker address; after validation it holds "('ip', port)". Empty by default.
 * \arg key         Device key reported to the tracker. Empty by default.
 * \arg custom_addr IP address reported to the tracker instead of host. Empty by default.
 * \arg work_dir    Custom work directory. Empty by default.
 * \arg silent      True when --silent was passed. False by default.
 * \arg mmap_path   (WIN32 only) Value of --child_proc=; non-empty in a child process.
 */
struct RpcServerArgs {
  std::string host{"0.0.0.0"};
  int port{9090};
  int port_end{9099};
  std::string tracker{};
  std::string key{};
  std::string custom_addr{};
  std::string work_dir{};
  bool silent{false};
#if defined(WIN32)
  std::string mmap_path{};
#endif
};
| |
| /*! |
| * \brief PrintArgs print the contents of RpcServerArgs |
| * \param args RpcServerArgs structure |
| */ |
| void PrintArgs(const RpcServerArgs& args) { |
| LOG(INFO) << "host = " << args.host; |
| LOG(INFO) << "port = " << args.port; |
| LOG(INFO) << "port_end = " << args.port_end; |
| LOG(INFO) << "tracker = " << args.tracker; |
| LOG(INFO) << "key = " << args.key; |
| LOG(INFO) << "custom_addr = " << args.custom_addr; |
| LOG(INFO) << "work_dir = " << args.work_dir; |
| LOG(INFO) << "silent = " << ((args.silent) ? ("True") : ("False")); |
| } |
| |
#if defined(__linux__) || defined(__ANDROID__)
/*!
 * \brief CtrlCHandler SIGINT handler: reports the interrupt on stderr and terminates
 *        the process with exit status 1.
 *
 * Only async-signal-safe calls (write, _exit) are made here. Logging through
 * streams and calling exit() (which runs atexit handlers and static destructors)
 * are not safe inside a signal handler and can deadlock if the signal interrupts
 * code that holds the allocator or a stream lock.
 * \param s The signal number (unused).
 */
void CtrlCHandler(int s) {
  (void)s;
  static const char kMessage[] = "\nUser pressed Ctrl+C, Exiting\n";
  // Best effort: if the write fails there is nothing useful left to do.
  ssize_t written = write(STDERR_FILENO, kMessage, sizeof(kMessage) - 1);
  (void)written;
  _exit(1);
}

/*!
 * \brief HandleCtrlC Register CtrlCHandler as the SIGINT (Ctrl+C) handler.
 */
void HandleCtrlC() {
  // Value-initialize so that no field of the platform's sigaction struct is
  // left indeterminate.
  struct sigaction sig_int_handler {};
  sig_int_handler.sa_handler = CtrlCHandler;
  sigemptyset(&sig_int_handler.sa_mask);
  sig_int_handler.sa_flags = 0;
  // sigaction() only fails (EINVAL) for invalid signals or SIGKILL/SIGSTOP,
  // which cannot happen for SIGINT, so the result is not checked.
  sigaction(SIGINT, &sig_int_handler, nullptr);
}
#endif
| /*! |
| * \brief GetCmdOption Parse and find the command option. |
| * \param argc arg counter |
| * \param argv arg values |
| * \param option command line option to search for. |
| * \param key whether the option itself is key |
| * \return value corresponding to option. |
| */ |
| string GetCmdOption(int argc, char* argv[], string option, bool key = false) { |
| string cmd; |
| for (int i = 1; i < argc; ++i) { |
| string arg = argv[i]; |
| if (arg.find(option) == 0) { |
| if (key) { |
| cmd = argv[i]; |
| return cmd; |
| } |
| // We assume "=" is the end of option. |
| ICHECK_EQ(*option.rbegin(), '='); |
| cmd = arg.substr(arg.find('=') + 1); |
| return cmd; |
| } |
| } |
| return cmd; |
| } |
| |
| /*! |
| * \brief ValidateTracker Check the tracker address format is correct and changes the format. |
| * \param tracker The tracker input. |
| * \return result of operation. |
| */ |
| bool ValidateTracker(string& tracker) { |
| vector<string> list = Split(tracker, ':'); |
| if ((list.size() != 2) || (!ValidateIP(list[0])) || (!IsNumber(list[1]))) { |
| return false; |
| } |
| ostringstream ss; |
| ss << "('" << list[0] << "', " << list[1] << ")"; |
| tracker = ss.str(); |
| return true; |
| } |
| |
| /*! |
| * \brief ParseCmdArgs parses the command line arguments. |
| * \param argc arg counter |
| * \param argv arg values |
| * \param args the output structure which holds the parsed values |
| */ |
| void ParseCmdArgs(int argc, char* argv[], struct RpcServerArgs& args) { |
| const string silent = GetCmdOption(argc, argv, "--silent", true); |
| if (!silent.empty()) { |
| args.silent = true; |
| // Only errors and fatal is logged |
| dmlc::InitLogging("--minloglevel=2"); |
| } |
| |
| const string host = GetCmdOption(argc, argv, "--host="); |
| if (!host.empty()) { |
| if (!ValidateIP(host)) { |
| LOG(WARNING) << "Wrong host address format."; |
| LOG(INFO) << kUsage; |
| exit(1); |
| } |
| args.host = host; |
| } |
| |
| const string port = GetCmdOption(argc, argv, "--port="); |
| if (!port.empty()) { |
| if (!IsNumber(port) || stoi(port) > 65535) { |
| LOG(WARNING) << "Wrong port number."; |
| LOG(INFO) << kUsage; |
| exit(1); |
| } |
| args.port = stoi(port); |
| } |
| |
| const string port_end = GetCmdOption(argc, argv, "--port-end="); |
| if (!port_end.empty()) { |
| if (!IsNumber(port_end) || stoi(port_end) > 65535) { |
| LOG(WARNING) << "Wrong port-end number."; |
| LOG(INFO) << kUsage; |
| exit(1); |
| } |
| args.port_end = stoi(port_end); |
| } |
| |
| string tracker = GetCmdOption(argc, argv, "--tracker="); |
| if (!tracker.empty()) { |
| if (!ValidateTracker(tracker)) { |
| LOG(WARNING) << "Wrong tracker address format."; |
| LOG(INFO) << kUsage; |
| exit(1); |
| } |
| args.tracker = tracker; |
| } |
| |
| const string key = GetCmdOption(argc, argv, "--key="); |
| if (!key.empty()) { |
| args.key = key; |
| } |
| |
| const string custom_addr = GetCmdOption(argc, argv, "--custom-addr="); |
| if (!custom_addr.empty()) { |
| if (!ValidateIP(custom_addr)) { |
| LOG(WARNING) << "Wrong custom address format."; |
| LOG(INFO) << kUsage; |
| exit(1); |
| } |
| args.custom_addr = custom_addr; |
| } |
| #if defined(WIN32) |
| const string mmap_path = GetCmdOption(argc, argv, "--child_proc="); |
| if (!mmap_path.empty()) { |
| args.mmap_path = mmap_path; |
| dmlc::InitLogging("--minloglevel=0"); |
| } |
| #endif |
| const string work_dir = GetCmdOption(argc, argv, "--work-dir="); |
| if (!work_dir.empty()) { |
| args.work_dir = work_dir; |
| } |
| } |
| |
| /*! |
| * \brief RpcServer Starts the RPC server. |
| * \param argc arg counter |
| * \param argv arg values |
| * \return result of operation. |
| */ |
| int RpcServer(int argc, char* argv[]) { |
| RpcServerArgs args; |
| |
| /* parse the command line args */ |
| ParseCmdArgs(argc, argv, args); |
| PrintArgs(args); |
| |
| LOG(INFO) << "Starting CPP Server, Press Ctrl+C to stop."; |
| #if defined(__linux__) || defined(__ANDROID__) |
| // Ctrl+C handler |
| HandleCtrlC(); |
| #endif |
| |
| #if defined(WIN32) |
| if (!args.mmap_path.empty()) { |
| int ret = 0; |
| |
| try { |
| ChildProcSocketHandler(args.mmap_path); |
| } catch (const std::exception&) { |
| ret = -1; |
| } |
| |
| return ret; |
| } |
| #endif |
| |
| RPCServerCreate(args.host, args.port, args.port_end, args.tracker, args.key, args.custom_addr, |
| args.work_dir, args.silent); |
| return 0; |
| } |
| |
| /*! |
| * \brief main The main function. |
| * \param argc arg counter |
| * \param argv arg values |
| * \return result of operation. |
| */ |
| int main(int argc, char* argv[]) { |
| if (argc <= 1) { |
| LOG(INFO) << kUsage; |
| return 0; |
| } |
| |
| // Runs WSAStartup on Win32, no-op on POSIX |
| Socket::Startup(); |
| #if defined(_WIN32) |
| SetEnvironmentVariableA("CUDA_CACHE_DISABLE", "1"); |
| #endif |
| |
| if (0 == strcmp(argv[1], "server")) { |
| return RpcServer(argc, argv); |
| } |
| |
| LOG(INFO) << kUsage; |
| |
| return 0; |
| } |