blob: cea10dcc571a105fae0b38a6f644bcd451eb7271 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c2/ControllerSocketProtocol.h"
#include <fstream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "asio/detached.hpp"
#include "asio/ssl/stream.hpp"
#include "c2/C2Payload.h"
#include "c2/C2Utils.h"
#include "controllers/SSLContextService.h"
#include "io/AsioStream.h"
#include "properties/Configuration.h"
#include "utils/ConfigurationUtils.h"
#include "utils/StringUtils.h"
#include "minifi-cpp/utils/gsl.h"
#include "utils/net/AsioSocketUtils.h"
namespace org::apache::nifi::minifi::c2 {
ControllerSocketProtocol::SocketRestartCommandProcessor::SocketRestartCommandProcessor(state::StateMonitor& update_sink, const std::shared_ptr<core::logging::Logger>& logger) :
update_sink_(update_sink),
logger_(logger) {
command_queue_.start();
command_processor_thread_ = std::thread([this] {
while (running_) {
CommandData command_data;
if (command_queue_.dequeueWait(command_data)) {
if (command_data.command == Command::FLOW_UPDATE) {
auto result = update_sink_.applyUpdate("ControllerSocketProtocol", command_data.data, true);
if (!result) {
logger_->log_error("Failed to apply flow update: {}", result.error());
}
} else if (command_data.command == Command::START) {
update_sink_.executeOnComponent(command_data.data, [](state::StateController& component) {
component.start();
});
}
}
is_socket_restarting_ = false;
}
});
}
ControllerSocketProtocol::SocketRestartCommandProcessor::~SocketRestartCommandProcessor() {
running_ = false;
command_queue_.stop();
if (command_processor_thread_.joinable()) {
command_processor_thread_.join();
}
}
ControllerSocketProtocol::ControllerSocketProtocol(state::StateMonitor& update_sink, std::shared_ptr<Configure> configuration,
const std::shared_ptr<ControllerSocketReporter>& controller_socket_reporter)
: update_sink_(update_sink),
controller_socket_reporter_(controller_socket_reporter),
configuration_(std::move(configuration)),
socket_restart_processor_(update_sink_, logger_) {
gsl_Expects(configuration_);
}
ControllerSocketProtocol::~ControllerSocketProtocol() {
stopListener();
}
void ControllerSocketProtocol::stopListener() {
if (acceptor_) {
asio::post(io_context_, [this] {
acceptor_->close();
});
}
if (server_thread_.joinable()) {
server_thread_.join();
}
io_context_.restart();
}
asio::awaitable<void> ControllerSocketProtocol::startAccept() {
while (true) {
auto [accept_error, socket] = co_await acceptor_->async_accept(utils::net::use_nothrow_awaitable);
if (accept_error) {
if (accept_error == asio::error::operation_aborted || accept_error == asio::error::bad_descriptor) {
logger_->log_debug("Controller socket accept aborted");
co_return;
}
logger_->log_error("Controller socket accept failed with the following message: '{}'", accept_error.message());
continue;
}
auto stream = std::make_unique<io::AsioStream<asio::ip::tcp::socket>>(std::move(socket));
co_spawn(io_context_, handleCommand(std::move(stream)), asio::detached);
}
}
asio::awaitable<void> ControllerSocketProtocol::handshakeAndHandleCommand(asio::ip::tcp::socket&& socket, std::shared_ptr<minifi::controllers::SSLContextServiceInterface> ssl_context_service) {
asio::ssl::context ssl_context = utils::net::getSslContext(*ssl_context_service, asio::ssl::context::tls_server);
ssl_context.set_options(utils::net::MINIFI_SSL_OPTIONS);
asio::ssl::stream<asio::ip::tcp::socket> ssl_socket(std::move(socket), ssl_context);
auto [handshake_error] = co_await ssl_socket.async_handshake(utils::net::HandshakeType::server, utils::net::use_nothrow_awaitable);
if (handshake_error) {
logger_->log_error("Controller socket handshake failed with the following message: '{}'", handshake_error.message());
co_return;
}
auto stream = std::make_unique<io::AsioStream<asio::ssl::stream<asio::ip::tcp::socket>>>(std::move(ssl_socket));
co_return co_await handleCommand(std::move(stream));
}
asio::awaitable<void> ControllerSocketProtocol::startAcceptSsl(std::shared_ptr<minifi::controllers::SSLContextServiceInterface> ssl_context_service) {
while (true) { // NOLINT(clang-analyzer-core.NullDereference) suppressing asio library linter warning
auto [accept_error, socket] = co_await acceptor_->async_accept(utils::net::use_nothrow_awaitable);
if (accept_error) {
if (accept_error == asio::error::operation_aborted || accept_error == asio::error::bad_descriptor) {
logger_->log_debug("Controller socket accept aborted");
co_return;
}
logger_->log_error("Controller socket accept failed with the following message: '{}'", accept_error.message());
continue;
}
co_spawn(io_context_, handshakeAndHandleCommand(std::move(socket), ssl_context_service), asio::detached);
}
}
void ControllerSocketProtocol::initialize() {
std::unique_lock<std::mutex> lock(initialization_mutex_);
std::shared_ptr<minifi::controllers::SSLContextServiceInterface> secure_context;
std::string secure_str;
if (configuration_->get(Configure::nifi_remote_input_secure, secure_str) && org::apache::nifi::minifi::utils::string::toBool(secure_str).value_or(false)) {
secure_context = std::make_shared<minifi::controllers::SSLContextService>("ControllerSocketProtocolSSL", configuration_);
secure_context->onEnable();
}
std::string limit_str;
const bool any_interface = configuration_->get(Configuration::controller_socket_local_any_interface, limit_str) && utils::string::toBool(limit_str).value_or(false);
// if host name isn't defined we will use localhost
std::string host = "localhost";
configuration_->get(Configuration::controller_socket_host, host);
std::string port;
stopListener();
if (configuration_->get(Configuration::controller_socket_port, port)) {
// if we have a localhost hostname and we did not manually specify any.interface we will bind only to the loopback adapter
if ((host == "localhost" || host == "127.0.0.1" || host == "::") && !any_interface) {
acceptor_ = std::make_unique<asio::ip::tcp::acceptor>(io_context_, asio::ip::tcp::endpoint(asio::ip::address_v4::loopback(), std::stoi(port)));
} else {
acceptor_ = std::make_unique<asio::ip::tcp::acceptor>(io_context_, asio::ip::tcp::endpoint(asio::ip::tcp::v4(), std::stoi(port)));
}
if (secure_context) {
co_spawn(io_context_, startAcceptSsl(std::move(secure_context)), asio::detached);
} else {
co_spawn(io_context_, startAccept(), asio::detached);
}
server_thread_ = std::thread([this] {
io_context_.run();
});
}
}
void ControllerSocketProtocol::setRoot(core::ProcessGroup* root) {
if (auto controller_socket_reporter = controller_socket_reporter_.lock()) {
controller_socket_reporter->setRoot(root);
}
}
void ControllerSocketProtocol::handleStart(io::BaseStream &stream) {
std::string component_str;
const auto size = stream.read(component_str);
if (!io::isError(size)) {
if (component_str == "FlowController") {
// Starting flow controller resets socket
socket_restart_processor_.enqueue({SocketRestartCommandProcessor::Command::START, component_str});
} else {
update_sink_.executeOnComponent(component_str, [](state::StateController& component) {
component.start();
});
}
} else {
logger_->log_error("Connection broke");
}
}
void ControllerSocketProtocol::handleStop(io::BaseStream &stream) {
std::string component_str;
const auto size = stream.read(component_str);
if (!io::isError(size)) {
update_sink_.executeOnComponent(component_str, [](state::StateController& component) {
component.stop();
});
} else {
logger_->log_error("Connection broke");
}
}
void ControllerSocketProtocol::handleClear(io::BaseStream &stream) {
std::string connection;
const auto size = stream.read(connection);
if (!io::isError(size)) {
update_sink_.clearConnection(connection);
}
}
void ControllerSocketProtocol::handleUpdate(io::BaseStream &stream) {
std::string what;
{
const auto size = stream.read(what);
if (io::isError(size)) {
logger_->log_error("Connection broke");
return;
}
}
if (what == "flow") {
std::string ff_loc;
{
const auto size = stream.read(ff_loc);
if (io::isError(size)) {
logger_->log_error("Connection broke");
return;
}
}
std::ifstream tf(ff_loc);
std::string flow_configuration((std::istreambuf_iterator<char>(tf)),
std::istreambuf_iterator<char>());
socket_restart_processor_.enqueue({SocketRestartCommandProcessor::Command::FLOW_UPDATE, flow_configuration});
}
}
void ControllerSocketProtocol::writeQueueSizesResponse(io::BaseStream &stream) {
std::string connection;
const auto size_ = stream.read(connection);
if (io::isError(size_)) {
logger_->log_error("Connection broke");
return;
}
std::unordered_map<std::string, ControllerSocketReporter::QueueSize> sizes;
if (auto controller_socket_reporter = controller_socket_reporter_.lock()) {
sizes = controller_socket_reporter->getQueueSizes();
}
std::stringstream response;
if (sizes.contains(connection)) {
response << sizes[connection].queue_size << " / " << sizes[connection].queue_size_max;
} else {
response << "not found";
}
io::BufferStream resp;
auto op = static_cast<uint8_t>(Operation::describe);
resp.write(&op, 1);
resp.write(response.str());
stream.write(resp.getBuffer());
}
void ControllerSocketProtocol::writeComponentsResponse(io::BaseStream &stream) {
std::vector<std::pair<std::string, bool>> components;
update_sink_.executeOnAllComponents([&components](state::StateController& component) {
components.emplace_back(component.getComponentName(), component.isRunning());
});
io::BufferStream resp;
auto op = static_cast<uint8_t>(Operation::describe);
resp.write(&op, 1);
resp.write(gsl::narrow<uint16_t>(components.size()));
for (const auto& [name, is_running] : components) {
resp.write(name);
resp.write(is_running ? "true" : "false");
}
stream.write(resp.getBuffer());
}
void ControllerSocketProtocol::writeConnectionsResponse(io::BaseStream &stream) {
io::BufferStream resp;
auto op = static_cast<uint8_t>(Operation::describe);
resp.write(&op, 1);
std::unordered_set<std::string> connections;
if (auto controller_socket_reporter = controller_socket_reporter_.lock()) {
connections = controller_socket_reporter->getConnections();
}
const auto size = gsl::narrow<uint16_t>(connections.size());
resp.write(size);
for (const auto &connection : connections) {
resp.write(connection, false);
}
stream.write(resp.getBuffer());
}
void ControllerSocketProtocol::writeGetFullResponse(io::BaseStream &stream) {
io::BufferStream resp;
auto op = static_cast<uint8_t>(Operation::describe);
resp.write(&op, 1);
std::unordered_set<std::string> full_connections;
if (auto controller_socket_reporter = controller_socket_reporter_.lock()) {
full_connections = controller_socket_reporter->getFullConnections();
}
const auto size = gsl::narrow<uint16_t>(full_connections.size());
resp.write(size);
for (const auto &connection : full_connections) {
resp.write(connection, false);
}
stream.write(resp.getBuffer());
}
void ControllerSocketProtocol::writeManifestResponse(io::BaseStream &stream) {
io::BufferStream resp;
auto op = static_cast<uint8_t>(Operation::describe);
resp.write(&op, 1);
std::string manifest;
if (auto controller_socket_reporter = controller_socket_reporter_.lock()) {
manifest = controller_socket_reporter->getAgentManifest();
}
resp.write(manifest, true);
stream.write(resp.getBuffer());
}
std::string ControllerSocketProtocol::getJstack() {
if (!update_sink_.isRunning()) {
return {};
}
std::stringstream result;
const auto traces = update_sink_.getTraces();
for (const auto& trace : traces) {
for (const auto& line : trace.getTraces()) {
result << trace.getName() << " -- " << line << "\n";
}
}
return result.str();
}
void ControllerSocketProtocol::writeJstackResponse(io::BaseStream &stream) {
io::BufferStream resp;
auto op = static_cast<uint8_t>(Operation::describe);
resp.write(&op, 1);
std::string jstack_response;
if (auto controller_socket_reporter = controller_socket_reporter_.lock()) {
jstack_response = getJstack();
}
resp.write(jstack_response, true);
stream.write(resp.getBuffer());
}
void ControllerSocketProtocol::writeFlowStatusResponse(io::BaseStream &stream) {
std::string query;
{
const auto size = stream.read(query);
if (io::isError(size)) {
logger_->log_error("Connection broke");
return;
}
}
auto request_strings = utils::string::splitAndTrimRemovingEmpty(query, ";");
std::vector<FlowStatusRequest> requests;
for (const auto& request_string : request_strings) {
try {
requests.push_back(FlowStatusRequest(request_string));
} catch (const std::exception& e) {
logger_->log_error("Invalid flow status request: {}", e.what());
}
}
io::BufferStream resp;
auto op = static_cast<uint8_t>(Operation::describe);
resp.write(&op, 1);
std::string flowstatus_response;
if (auto controller_socket_reporter = controller_socket_reporter_.lock()) {
flowstatus_response = controller_socket_reporter->getFlowStatus(requests);
}
resp.write(flowstatus_response, true);
stream.write(resp.getBuffer());
}
void ControllerSocketProtocol::handleDescribe(io::BaseStream &stream) {
std::string what;
const auto size = stream.read(what);
if (io::isError(size)) {
logger_->log_error("Connection broke");
return;
}
if (what == "queue") {
writeQueueSizesResponse(stream);
} else if (what == "components") {
writeComponentsResponse(stream);
} else if (what == "connections") {
writeConnectionsResponse(stream);
} else if (what == "getfull") {
writeGetFullResponse(stream);
} else if (what == "manifest") {
writeManifestResponse(stream);
} else if (what == "jstack") {
writeJstackResponse(stream);
} else if (what == "flowstatus") {
writeFlowStatusResponse(stream);
} else {
logger_->log_error("Unknown C2 describe parameter: {}", what);
}
}
void ControllerSocketProtocol::writeDebugBundleResponse(io::BaseStream &stream) {
auto files = update_sink_.getDebugInfo();
auto bundle = createDebugBundleArchive(files);
io::BufferStream resp;
auto op = static_cast<uint8_t>(Operation::transfer);
resp.write(&op, 1);
if (!bundle) {
logger_->log_error("Creating debug bundle failed: {}", bundle.error());
resp.write(static_cast<size_t>(0));
stream.write(resp.getBuffer());
return;
}
size_t bundle_size = bundle.value()->size();
resp.write(bundle_size);
static constexpr auto BUFFER_SIZE = utils::configuration::DEFAULT_BUFFER_SIZE;
std::array<std::byte, BUFFER_SIZE> out_buffer{};
while (bundle_size > 0) {
const auto next_write_size = (std::min)(bundle_size, BUFFER_SIZE);
const auto size_read = bundle.value()->read(std::as_writable_bytes(std::span(out_buffer).subspan(0, next_write_size)));
resp.write(reinterpret_cast<const uint8_t*>(out_buffer.data()), size_read);
bundle_size -= size_read;
}
stream.write(resp.getBuffer());
}
void ControllerSocketProtocol::handleTransfer(io::BaseStream &stream) {
std::string what;
const auto size = stream.read(what);
if (io::isError(size)) {
logger_->log_error("Connection broke");
return;
}
if (what == "debug") {
writeDebugBundleResponse(stream);
} else {
logger_->log_error("Unknown C2 transfer parameter: {}", what);
}
}
asio::awaitable<void> ControllerSocketProtocol::handleCommand(std::unique_ptr<io::BaseStream> stream) {
uint8_t head = 0;
if (stream->read(head) != 1) {
logger_->log_error("Connection broke");
co_return;
}
if (socket_restart_processor_.isSocketRestarting()) {
logger_->log_debug("Socket restarting, dropping command");
co_return;
}
auto op = static_cast<Operation>(head);
switch (op) {
case Operation::start:
handleStart(*stream);
break;
case Operation::stop:
handleStop(*stream);
break;
case Operation::clear:
handleClear(*stream);
break;
case Operation::update:
handleUpdate(*stream);
break;
case Operation::describe:
handleDescribe(*stream);
break;
case Operation::transfer:
handleTransfer(*stream);
break;
default:
logger_->log_error("Unhandled C2 operation: {}", head);
}
}
} // namespace org::apache::nifi::minifi::c2