blob: bbf8367903bb22f64f2113878b986e84b6c96c29 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef WIN32_LEAN_AND_MEAN
#define WIN32_LEAN_AND_MEAN
#endif
#include "win32_process.h"
#include <conio.h>
#include <dmlc/logging.h>
#include <winsock2.h>
#include <ws2tcpip.h>
#include <cstdio>
#include <memory>
#include <stdexcept>
#include <string>
#include "rpc_server.h"
using namespace std::chrono;
using namespace tvm::runtime;
namespace {
// Prefix of the named WIN32 file mapping that carries IPC information to a child process
const std::string kMemoryMapPrefix{"/MAPPED_FILE/TVM_RPC"};
// Suffix appended to the mapping name to form the name of the event the parent signals
const std::string kParent{"parent"};
// Suffix appended to the mapping name to form the name of the event the child signals
const std::string kChild{"child"};
// How long the parent and the child each wait on the other's WIN32 event
const milliseconds kEventTimeout{2000};
// Incremented per spawned child so each one gets distinct mmap path and event names
int child_counter_ = 0;
/*!
 * \brief HandleDeleter Deleter used by UniqueHandle to close a WIN32 HANDLE
 * \param handle The WIN32 HANDLE to close; nullptr and INVALID_HANDLE_VALUE are left alone
 */
struct HandleDeleter {
  void operator()(HANDLE handle) const {
    const bool is_valid = handle != nullptr && handle != INVALID_HANDLE_VALUE;
    if (is_valid) {
      CloseHandle(handle);
    }
  }
};
/*!
 * \brief UniqueHandle Owning smart pointer that closes its WIN32 HANDLE via HandleDeleter
 */
using UniqueHandle = std::unique_ptr<void, HandleDeleter>;
/*!
 * \brief MakeUniqueHandle Wraps a raw WIN32 HANDLE in a UniqueHandle
 * \param handle The WIN32 HANDLE to manage
 * \return An empty UniqueHandle when handle is nullptr or INVALID_HANDLE_VALUE, so callers can
 *         test every failed WIN32 call the same way; otherwise a UniqueHandle owning handle
 */
UniqueHandle MakeUniqueHandle(HANDLE handle) {
  const bool is_invalid = handle == nullptr || handle == INVALID_HANDLE_VALUE;
  return is_invalid ? UniqueHandle() : UniqueHandle(handle);
}
/*!
 * \brief GetSocket Gets the socket info from the parent process and duplicates the socket
 * \param mmap_path The path to the memory mapped info set by the parent
 * \return The duplicated socket, or INVALID_SOCKET if WSASocket() failed
 * \note Failing to open the events, to receive the parent's signal in time, or to read the
 *       file mapping is reported with LOG(FATAL).
 */
SOCKET GetSocket(const std::string& mmap_path) {
  const std::string parent_event_name = mmap_path + kParent;
  const std::string child_event_name = mmap_path + kChild;
  // Open the events: we only wait on the parent's event and only signal the child's event
  const UniqueHandle parent_file_mapping_event =
      MakeUniqueHandle(OpenEventA(SYNCHRONIZE, false, parent_event_name.c_str()));
  if (!parent_file_mapping_event) {
    LOG(FATAL) << "OpenEvent() failed: " << GetLastError();
  }
  const UniqueHandle child_file_mapping_event =
      MakeUniqueHandle(OpenEventA(EVENT_MODIFY_STATE, false, child_event_name.c_str()));
  if (!child_file_mapping_event) {
    LOG(FATAL) << "OpenEvent() failed: " << GetLastError();
  }
  // Wait for the parent to set the event, notifying WSAPROTOCOL_INFO is ready to be read
  const DWORD wait_result = WaitForSingleObject(parent_file_mapping_event.get(),
                                                static_cast<DWORD>(kEventTimeout.count()));
  if (wait_result == WAIT_TIMEOUT) {
    // GetLastError() carries no information on a timeout, so report the timeout itself
    LOG(FATAL) << "WaitForSingleObject() timed out waiting for the parent process";
  } else if (wait_result != WAIT_OBJECT_0) {
    LOG(FATAL) << "WaitForSingleObject() failed: " << GetLastError();
  }
  const UniqueHandle file_map =
      MakeUniqueHandle(OpenFileMappingA(FILE_MAP_READ | FILE_MAP_WRITE, false, mmap_path.c_str()));
  if (!file_map) {
    // Fail here with the real cause instead of letting MapViewOfFile() fail on a null handle
    LOG(FATAL) << "OpenFileMapping() failed: " << GetLastError();
  }
  void* map_view = MapViewOfFile(file_map.get(), FILE_MAP_READ | FILE_MAP_WRITE, 0, 0, 0);
  if (map_view == nullptr) {
    LOG(FATAL) << "MapViewOfFile() failed: " << GetLastError();
  }
  WSAPROTOCOL_INFO protocol_info;
  memcpy(&protocol_info, map_view, sizeof(WSAPROTOCOL_INFO));
  UnmapViewOfFile(map_view);
  // Creates the duplicate socket, that was created in the parent
  const SOCKET sock_duplicated =
      WSASocket(FROM_PROTOCOL_INFO, FROM_PROTOCOL_INFO, FROM_PROTOCOL_INFO, &protocol_info, 0, 0);
  if (sock_duplicated == INVALID_SOCKET) {
    LOG(WARNING) << "WSASocket() failed: " << WSAGetLastError();
  }
  // Let the parent know we are finished duplicating the socket. This is signalled even on
  // WSASocket() failure so the parent stops waiting; the caller checks for INVALID_SOCKET.
  SetEvent(child_file_mapping_event.get());
  return sock_duplicated;
}
} // Anonymous namespace
namespace tvm {
namespace runtime {
/*!
 * \brief SpawnRPCChild Spawns a child process with a given timeout to run
 * \param fd The client socket to duplicate in the child
 * \param timeout The time in seconds to wait for the child to complete before termination.
 *        A value <= 0 waits for the child without a time limit.
 * \note If the child was started but the socket cannot be handed over, the child is terminated
 *       before LOG(FATAL) so it is not left running.
 */
void SpawnRPCChild(SOCKET fd, seconds timeout) {
  STARTUPINFOA startup_info;
  memset(&startup_info, 0, sizeof(startup_info));
  startup_info.cb = sizeof(startup_info);
  // Named kernel objects are shared across processes in the session, so include this process'
  // id to keep concurrently running servers from colliding on the same names.
  const std::string file_map_path = kMemoryMapPrefix + std::to_string(GetCurrentProcessId()) +
                                    "_" + std::to_string(child_counter_++);
  const std::string parent_event_name = file_map_path + kParent;
  const std::string child_event_name = file_map_path + kChild;
  // Create an event to let the child know the socket info was set to the mmap file
  // (manual reset, initially unsignaled)
  const UniqueHandle parent_file_mapping_event =
      MakeUniqueHandle(CreateEventA(nullptr, true, false, parent_event_name.c_str()));
  if (!parent_file_mapping_event) {
    LOG(FATAL) << "CreateEvent for parent file mapping failed: " << GetLastError();
  }
  // An event to let the parent know the socket info was read from the mmap file
  // (manual reset, initially unsignaled)
  const UniqueHandle child_file_mapping_event =
      MakeUniqueHandle(CreateEventA(nullptr, true, false, child_event_name.c_str()));
  if (!child_file_mapping_event) {
    LOG(FATAL) << "CreateEvent for child file mapping failed: " << GetLastError();
  }
  // Get the full path of the current executable
  char current_executable[MAX_PATH];
  const DWORD executable_length = GetModuleFileNameA(nullptr, current_executable, MAX_PATH);
  // A return value equal to the buffer size means the path was truncated
  if (executable_length == 0 || executable_length >= MAX_PATH) {
    LOG(FATAL) << "GetModuleFileName() failed: " << GetLastError();
  }
  // Quote the executable: with a null lpApplicationName, CreateProcessA splits the command line
  // at the first whitespace, so an unquoted path containing spaces would resolve incorrectly.
  std::string child_command_line = "\"";
  child_command_line += current_executable;
  child_command_line += "\" server --child_proc=";
  child_command_line += file_map_path;
  PROCESS_INFORMATION child_process_info;
  // CreateProcessA requires a mutable, NUL-terminated buffer; std::string's storage qualifies
  if (!CreateProcessA(nullptr, &child_command_line[0], nullptr, nullptr, false, CREATE_NO_WINDOW,
                      nullptr, nullptr, &startup_info, &child_process_info)) {
    LOG(INFO) << "Create child process failed: " << GetLastError();
    return;
  }
  // Child process and thread handles must be closed, so wrapped in RAII
  const UniqueHandle child_process_handle = MakeUniqueHandle(child_process_info.hProcess);
  const UniqueHandle child_process_thread_handle = MakeUniqueHandle(child_process_info.hThread);
  // Get info needed to duplicate the socket
  WSAPROTOCOL_INFO protocol_info;
  if (WSADuplicateSocket(fd, child_process_info.dwProcessId, &protocol_info) == SOCKET_ERROR) {
    // Capture the error before TerminateProcess() can overwrite it
    const int duplicate_error = WSAGetLastError();
    TerminateProcess(child_process_handle.get(), 0);
    LOG(FATAL) << "WSADuplicateSocket(): failed. Error =" << duplicate_error;
  }
  // Create a mmap file to store the info needed for duplicating the SOCKET in the child proc
  const UniqueHandle file_map =
      MakeUniqueHandle(CreateFileMappingA(INVALID_HANDLE_VALUE, nullptr, PAGE_READWRITE, 0,
                                          sizeof(WSAPROTOCOL_INFO), file_map_path.c_str()));
  // Read the error immediately; ERROR_ALREADY_EXISTS is reported alongside a valid handle
  const DWORD file_map_error = GetLastError();
  if (!file_map) {
    TerminateProcess(child_process_handle.get(), 0);
    LOG(FATAL) << "CreateFileMapping() failed: " << file_map_error;
  }
  if (file_map_error == ERROR_ALREADY_EXISTS) {
    TerminateProcess(child_process_handle.get(), 0);
    LOG(FATAL) << "CreateFileMapping(): mapping file already exists";
  }
  void* map_view = MapViewOfFile(file_map.get(), FILE_MAP_READ | FILE_MAP_WRITE, 0, 0, 0);
  if (map_view == nullptr) {
    // Capture the error before TerminateProcess() can overwrite it
    const DWORD map_view_error = GetLastError();
    TerminateProcess(child_process_handle.get(), 0);
    LOG(FATAL) << "MapViewOfFile() failed: " << map_view_error;
  }
  memcpy(map_view, &protocol_info, sizeof(WSAPROTOCOL_INFO));
  UnmapViewOfFile(map_view);
  // Let child proc know the mmap file is ready to be read
  SetEvent(parent_file_mapping_event.get());
  // Wait for the child to finish reading mmap file
  if (WaitForSingleObject(child_file_mapping_event.get(),
                          static_cast<DWORD>(kEventTimeout.count())) != WAIT_OBJECT_0) {
    TerminateProcess(child_process_handle.get(), 0);
    LOG(FATAL) << "WaitForSingleObject for child file mapping timed out. Terminating child "
                  "process.";
  }
  DWORD process_timeout = INFINITE;
  if (timeout.count() > 0) {
    const milliseconds::rep timeout_ms = duration_cast<milliseconds>(timeout).count();
    // INFINITE (0xFFFFFFFF) is a sentinel: clamp just below it so a very large timeout is
    // neither truncated to a short wait nor silently turned into "wait forever".
    process_timeout = timeout_ms >= static_cast<milliseconds::rep>(INFINITE)
                          ? INFINITE - 1
                          : static_cast<DWORD>(timeout_ms);
  }
  // Wait for child process to exit, or hit configured timeout
  if (WaitForSingleObject(child_process_handle.get(), process_timeout) != WAIT_OBJECT_0) {
    LOG(INFO) << "Child process timeout. Terminating.";
    TerminateProcess(child_process_handle.get(), 0);
  }
}
/*!
* \brief ChildProcSocketHandler Ran from the child process and runs server to handle the client
* socket \param mmap_path The memory mapped file path that will contain the information to
* duplicate the client socket from the parent
*/
void ChildProcSocketHandler(const std::string& mmap_path) {
SOCKET socket;
// Set high thread priority to avoid the thread scheduler from
// interfering with any measurements in the RPC server.
SetThreadPriority(GetCurrentThread(), THREAD_PRIORITY_TIME_CRITICAL);
if ((socket = GetSocket(mmap_path)) != INVALID_SOCKET) {
tvm::runtime::ServerLoopFromChild(socket);
} else {
LOG(FATAL) << "GetSocket() failed";
}
}
} // namespace runtime
} // namespace tvm