blob: 083ace2bf1683ae592b3a592809fc338e91bb6de [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifdef WIN32
#include "MiNiFiWindowsService.h"
#include <Windows.h>
#include <Strsafe.h>
#include <tuple>
#include <tlhelp32.h>
#include "MainHelper.h"
#include "core/FlowConfiguration.h"
// Implemented in MiNiFiMain.cpp
// Invoked by the watcher thread started in CreateServiceTerminationThread() once its wait returns.
void SignalExitProcess();
// Name of the named, manual-reset event shared by the service process and the exe it launches:
// the service's stop handler signals it, and the launched exe opens it by the same name and waits on it.
static const char* const SERVICE_TERMINATION_EVENT_NAME = "Global\\MiNiFiServiceTermination";
static void OutputDebug(const char* format, ...) {
va_list args;
va_start(args, format);
char buf[256];
sprintf_s(buf, _countof(buf), "%s: %s", SERVICE_NAME, format);
char out[1024];
StringCbVPrintfA(out, sizeof(out), buf, args);
OutputDebugStringA(out);
va_end(args);
};
void RunAsServiceIfNeeded() {
static const int WAIT_TIME_EXE_TERMINATION = 5000;
static const int WAIT_TIME_EXE_RESTART = 60000;
static SERVICE_STATUS s_serviceStatus;
static SERVICE_STATUS_HANDLE s_statusHandle;
static HANDLE s_hProcess;
static HANDLE s_hEvent;
static auto Log = []() {
static std::shared_ptr<logging::Logger> s_logger = logging::LoggerConfiguration::getConfiguration().getLogger("service");
return s_logger;
};
SERVICE_TABLE_ENTRY serviceTable[] = {
{
SERVICE_NAME,
[](DWORD argc, LPTSTR *argv) {
setSyslogLogger();
Log()->log_trace("ServiceCtrlDispatcher");
s_hEvent = CreateEvent(0, TRUE, FALSE, SERVICE_TERMINATION_EVENT_NAME);
if (!s_hEvent) {
Log()->log_error("!CreateEvent lastError %x", GetLastError());
return;
}
s_statusHandle = RegisterServiceCtrlHandler(
SERVICE_NAME,
[](DWORD ctrlCode) {
Log()->log_trace("ServiceCtrlHandler ctrlCode %d", ctrlCode);
if (SERVICE_CONTROL_STOP == ctrlCode) {
Log()->log_trace("ServiceCtrlHandler ctrlCode = SERVICE_CONTROL_STOP");
// Set service status SERVICE_STOP_PENDING.
s_serviceStatus.dwControlsAccepted = 0;
s_serviceStatus.dwCurrentState = SERVICE_STOP_PENDING;
s_serviceStatus.dwWin32ExitCode = 0;
if (!SetServiceStatus(s_statusHandle, &s_serviceStatus)) {
Log()->log_error("!SetServiceStatus SERVICE_STOP_PENDING lastError %x", GetLastError());
}
bool exeTerminated = false;
SetEvent(s_hEvent);
Log()->log_info("Wait for exe termination");
switch (auto res = WaitForSingleObject(s_hProcess, WAIT_TIME_EXE_TERMINATION)) {
case WAIT_OBJECT_0:
Log()->log_info("Exe terminated");
exeTerminated = true;
break;
case WAIT_TIMEOUT:
Log()->log_error("WaitForSingleObject timeout %d", WAIT_TIME_EXE_TERMINATION);
break;
default:
Log()->log_error("!WaitForSingleObject return %d", res);
}
if (!exeTerminated) {
Log()->log_info("TerminateProcess");
if (TerminateProcess(s_hProcess, 0)) {
s_serviceStatus.dwControlsAccepted = 0;
s_serviceStatus.dwCurrentState = SERVICE_STOPPED;
s_serviceStatus.dwWin32ExitCode = 0;
if (!SetServiceStatus(s_statusHandle, &s_serviceStatus)) {
Log()->log_error("!SetServiceStatus SERVICE_STOPPED lastError %x", GetLastError());
}
} else {
Log()->log_error("!TerminateProcess lastError %x", GetLastError());
}
}
}
}
);
if (!s_statusHandle) {
Log()->log_error("!RegisterServiceCtrlHandler lastError %x", GetLastError());
return;
}
// Set service status SERVICE_START_PENDING.
ZeroMemory(&s_serviceStatus, sizeof(s_serviceStatus));
s_serviceStatus.dwServiceType = SERVICE_WIN32_OWN_PROCESS;
s_serviceStatus.dwControlsAccepted = 0;
s_serviceStatus.dwCurrentState = SERVICE_START_PENDING;
s_serviceStatus.dwWin32ExitCode = 0;
s_serviceStatus.dwServiceSpecificExitCode = 0;
if (!SetServiceStatus(s_statusHandle, &s_serviceStatus)) {
Log()->log_error("!SetServiceStatus SERVICE_START_PENDING lastError %x", GetLastError());
return;
}
// Set service status SERVICE_RUNNING.
s_serviceStatus.dwControlsAccepted = SERVICE_ACCEPT_STOP;
s_serviceStatus.dwCurrentState = SERVICE_RUNNING;
s_serviceStatus.dwWin32ExitCode = 0;
if (!SetServiceStatus(s_statusHandle, &s_serviceStatus)) {
Log()->log_error("!SetServiceStatus SERVICE_RUNNING lastError %x", GetLastError());
// Set service status SERVICE_START_PENDING.
s_serviceStatus.dwControlsAccepted = 0;
s_serviceStatus.dwCurrentState = SERVICE_STOPPED;
s_serviceStatus.dwWin32ExitCode = GetLastError();
if (!SetServiceStatus(s_statusHandle, &s_serviceStatus)) {
Log()->log_error("!SetServiceStatus SERVICE_STOPPED lastError %x", GetLastError());
}
return;
}
char filePath[MAX_PATH];
if (!GetModuleFileName(0, filePath, _countof(filePath))) {
Log()->log_error("!GetModuleFileName lastError %x", GetLastError());
return;
}
do {
Log()->log_debug("Start exe path %s", filePath);
STARTUPINFO startupInfo{};
startupInfo.cb = sizeof(startupInfo);
PROCESS_INFORMATION processInformation{};
if (!CreateProcess(filePath, 0, 0, 0, 0, FALSE, 0, 0, &startupInfo, &processInformation)) {
Log()->log_error("!CreateProcess lastError %x", GetLastError());
return;
}
s_hProcess = processInformation.hProcess;
Log()->log_info("%s started", filePath);
auto res = WaitForSingleObject(processInformation.hProcess, INFINITE);
CloseHandle(processInformation.hProcess);
CloseHandle(processInformation.hThread);
if (WAIT_FAILED == res) {
Log()->log_error("!WaitForSingleObject hProcess lastError %x", GetLastError());
} else if (WAIT_OBJECT_0 != res) {
Log()->log_error("!WaitForSingleObject hProcess return %d", res);
}
Log()->log_info("Sleep %d sec before restarting exe", WAIT_TIME_EXE_RESTART/1000);
res = WaitForSingleObject(s_hEvent, WAIT_TIME_EXE_RESTART);
if (WAIT_OBJECT_0 == res) {
Log()->log_info("Service was stopped, exe won't be restarted");
break;
}
if (WAIT_FAILED == res) {
Log()->log_error("!WaitForSingleObject s_hEvent lastError %x", GetLastError());
} if (WAIT_TIMEOUT != res) {
Log()->log_error("!WaitForSingleObject s_hEvent return %d", res);
}
} while (true);
s_serviceStatus.dwControlsAccepted = 0;
s_serviceStatus.dwCurrentState = SERVICE_STOPPED;
s_serviceStatus.dwWin32ExitCode = 0;
if (!SetServiceStatus(s_statusHandle, &s_serviceStatus)) {
Log()->log_error("!SetServiceStatus SERVICE_STOPPED lastError %x", GetLastError());
}
}
},
{0, 0}
};
if (!StartServiceCtrlDispatcher(serviceTable)) {
if (ERROR_FAILED_SERVICE_CONTROLLER_CONNECT == GetLastError()) {
// Run this exe as console.
return;
}
Log()->log_error("!StartServiceCtrlDispatcher lastError %x", GetLastError());
ExitProcess(1);
}
Log()->log_info("Service exit");
ExitProcess(0);
}
/**
 * Creates or opens the named service-termination event.
 *
 * @param isStartedByService out-parameter; set to true when the event already existed before this
 *        call or when the event could not be created, and to false when this call created it.
 * @return handle of the event, or nullptr if CreateEvent failed.
 */
HANDLE GetTerminationEventHandle(bool* isStartedByService) {
  *isStartedByService = true;

  const HANDLE terminationEvent = CreateEvent(0, TRUE, FALSE, SERVICE_TERMINATION_EVENT_NAME);
  if (terminationEvent) {
    // On success, the last error tells whether the named event was already there.
    const bool eventAlreadyExisted = (ERROR_ALREADY_EXISTS == GetLastError());
    if (!eventAlreadyExisted) {
      *isStartedByService = false;
    }
    return terminationEvent;
  }

  return nullptr;
}
/**
 * Starts a background thread that calls SignalExitProcess() as soon as either the termination event
 * is signaled or the parent (service) process exits.
 *
 * The parent process is located through a process snapshot; it is accepted only if its executable
 * name equals (case-insensitively) the file name of the current executable.
 *
 * @param logger logger used for diagnostics, also by the background thread.
 * @param terminationEventHandle event handle the thread waits on; it is not closed by this function.
 * @return true if the watcher thread was started; false if the parent process could not be
 *         identified or opened, or if the thread could not be created.
 */
bool CreateServiceTerminationThread(std::shared_ptr<logging::Logger> logger, HANDLE terminationEventHandle) {
  // Get hService and monitor it - if service is terminated, then terminate current exe, otherwise the exe becomes unmanageable when service is restarted.
  auto hService = [&logger]() -> HANDLE {
    auto hSnapShot = CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0);
    if (INVALID_HANDLE_VALUE == hSnapShot) {
      logger->log_error("!CreateToolhelp32Snapshot lastError %x", GetLastError());
      return nullptr;
    }

    // Looks up `processId` in the snapshot; leaves the outputs zero/empty when it is not found.
    auto getProcessInfo = [&logger, &hSnapShot](DWORD processId, DWORD& parentProcessId, std::string& parentProcessName) {
      parentProcessId = 0;
      parentProcessName.clear();

      PROCESSENTRY32 procentry{};
      procentry.dwSize = sizeof(procentry);

      if (!Process32First(hSnapShot, &procentry)) {
        logger->log_error("!Process32First lastError %x", GetLastError());
        return;
      }

      do {
        if (processId == procentry.th32ProcessID) {
          parentProcessId = procentry.th32ParentProcessID;
          parentProcessName = procentry.szExeFile;
          return;
        }
      } while (Process32Next(hSnapShot, &procentry));
    };

    // Find current process info, which contains parentProcessId.
    DWORD parentProcessId{};
    std::string parentProcessName;
    getProcessInfo(GetCurrentProcessId(), parentProcessId, parentProcessName);

    // Find parent process info (the service which started current process), which contains service name.
    DWORD parentParentProcessId{};
    getProcessInfo(parentProcessId, parentParentProcessId, parentProcessName);

    CloseHandle(hSnapShot);

    // Just in case check that service name == current process name.
    char filePath[MAX_PATH];
    if (!GetModuleFileName(0, filePath, _countof(filePath))) {
      logger->log_error("!GetModuleFileName lastError %x", GetLastError());
      return nullptr;
    }

    const auto pSlash = strrchr(filePath, '\\');
    if (!pSlash) {
      logger->log_error("Invalid filePath %s", filePath);
      return nullptr;
    }
    const std::string fileName = pSlash + 1;

    if (_stricmp(fileName.c_str(), parentProcessName.c_str())) {
      logger->log_error("Parent process %s != current process %s", parentProcessName.c_str(), fileName.c_str());
      return nullptr;
    }

    const auto hParentProcess = OpenProcess(SYNCHRONIZE, FALSE, parentProcessId);
    if (!hParentProcess) {
      logger->log_error("!OpenProcess lastError %x", GetLastError());
      return nullptr;
    }

    return hParentProcess;
  }();
  if (!hService)
    return false;

  // Heap-allocated so the data outlives this call; ownership passes to the thread once it is started.
  using ThreadInfo = std::tuple<std::shared_ptr<logging::Logger>, HANDLE, HANDLE>;
  auto pThreadInfo = new ThreadInfo(logger, terminationEventHandle, hService);

  const HANDLE hThread = reinterpret_cast<HANDLE>(_beginthreadex(
    0, 0,
    [](void* pPar) {
      const auto pThreadInfo = static_cast<ThreadInfo*>(pPar);
      const auto logger = std::get<0>(*pThreadInfo);
      const auto terminationEventHandle = std::get<1>(*pThreadInfo);
      const auto hService = std::get<2>(*pThreadInfo);
      delete pThreadInfo;

      HANDLE arHandle[] = { terminationEventHandle, hService };
      switch (auto res = WaitForMultipleObjects(_countof(arHandle), arHandle, FALSE, INFINITE)) {
        case WAIT_FAILED:
          logger->log_error("!WaitForMultipleObjects lastError %x", GetLastError());
          break;

        case WAIT_OBJECT_0:
          logger->log_info("Service event received");
          break;

        case WAIT_OBJECT_0 + 1:
          logger->log_info("Service is terminated");
          break;

        default:
          logger->log_info("WaitForMultipleObjects return %d", res);
      }

      // The parent-process handle was opened solely for this wait; release it.
      CloseHandle(hService);

      SignalExitProcess();

      return 0U;
    },
    pThreadInfo,
    0, 0));
  if (!hThread) {
    logger->log_error("!_beginthreadex lastError %x", GetLastError());

    // The thread never started, so nothing else will release these.
    delete pThreadInfo;
    CloseHandle(hService);

    return false;
  }

  // The thread is never joined; closing the handle does not stop it, it only avoids leaking the handle.
  CloseHandle(hThread);

  return true;
}
#endif