| /** |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "c2/ControllerSocketProtocol.h" |
| |
| #include <fstream> |
| #include <sstream> |
| #include <string> |
| #include <utility> |
| #include <vector> |
| |
| #include "asio/detached.hpp" |
| #include "asio/ssl/stream.hpp" |
| #include "c2/C2Payload.h" |
| #include "c2/C2Utils.h" |
| #include "controllers/SSLContextService.h" |
| #include "io/AsioStream.h" |
| #include "properties/Configuration.h" |
| #include "utils/ConfigurationUtils.h" |
| #include "utils/StringUtils.h" |
| #include "minifi-cpp/utils/gsl.h" |
| #include "utils/net/AsioSocketUtils.h" |
| |
| namespace org::apache::nifi::minifi::c2 { |
| |
| ControllerSocketProtocol::SocketRestartCommandProcessor::SocketRestartCommandProcessor(state::StateMonitor& update_sink, const std::shared_ptr<core::logging::Logger>& logger) : |
| update_sink_(update_sink), |
| logger_(logger) { |
| command_queue_.start(); |
| command_processor_thread_ = std::thread([this] { |
| while (running_) { |
| CommandData command_data; |
| if (command_queue_.dequeueWait(command_data)) { |
| if (command_data.command == Command::FLOW_UPDATE) { |
| auto result = update_sink_.applyUpdate("ControllerSocketProtocol", command_data.data, true); |
| if (!result) { |
| logger_->log_error("Failed to apply flow update: {}", result.error()); |
| } |
| } else if (command_data.command == Command::START) { |
| update_sink_.executeOnComponent(command_data.data, [](state::StateController& component) { |
| component.start(); |
| }); |
| } |
| } |
| is_socket_restarting_ = false; |
| } |
| }); |
| } |
| |
| ControllerSocketProtocol::SocketRestartCommandProcessor::~SocketRestartCommandProcessor() { |
| running_ = false; |
| command_queue_.stop(); |
| if (command_processor_thread_.joinable()) { |
| command_processor_thread_.join(); |
| } |
| } |
| |
| ControllerSocketProtocol::ControllerSocketProtocol(state::StateMonitor& update_sink, std::shared_ptr<Configure> configuration, |
| const std::shared_ptr<ControllerSocketReporter>& controller_socket_reporter) |
| : update_sink_(update_sink), |
| controller_socket_reporter_(controller_socket_reporter), |
| configuration_(std::move(configuration)), |
| socket_restart_processor_(update_sink_, logger_) { |
| gsl_Expects(configuration_); |
| } |
| |
| ControllerSocketProtocol::~ControllerSocketProtocol() { |
| stopListener(); |
| } |
| |
| void ControllerSocketProtocol::stopListener() { |
| if (acceptor_) { |
| asio::post(io_context_, [this] { |
| acceptor_->close(); |
| }); |
| } |
| if (server_thread_.joinable()) { |
| server_thread_.join(); |
| } |
| io_context_.restart(); |
| } |
| |
| asio::awaitable<void> ControllerSocketProtocol::startAccept() { |
| while (true) { |
| auto [accept_error, socket] = co_await acceptor_->async_accept(utils::net::use_nothrow_awaitable); |
| if (accept_error) { |
| if (accept_error == asio::error::operation_aborted || accept_error == asio::error::bad_descriptor) { |
| logger_->log_debug("Controller socket accept aborted"); |
| co_return; |
| } |
| logger_->log_error("Controller socket accept failed with the following message: '{}'", accept_error.message()); |
| continue; |
| } |
| auto stream = std::make_unique<io::AsioStream<asio::ip::tcp::socket>>(std::move(socket)); |
| co_spawn(io_context_, handleCommand(std::move(stream)), asio::detached); |
| } |
| } |
| |
| asio::awaitable<void> ControllerSocketProtocol::handshakeAndHandleCommand(asio::ip::tcp::socket&& socket, std::shared_ptr<minifi::controllers::SSLContextServiceInterface> ssl_context_service) { |
| asio::ssl::context ssl_context = utils::net::getSslContext(*ssl_context_service, asio::ssl::context::tls_server); |
| ssl_context.set_options(utils::net::MINIFI_SSL_OPTIONS); |
| asio::ssl::stream<asio::ip::tcp::socket> ssl_socket(std::move(socket), ssl_context); |
| |
| auto [handshake_error] = co_await ssl_socket.async_handshake(utils::net::HandshakeType::server, utils::net::use_nothrow_awaitable); |
| if (handshake_error) { |
| logger_->log_error("Controller socket handshake failed with the following message: '{}'", handshake_error.message()); |
| co_return; |
| } |
| |
| auto stream = std::make_unique<io::AsioStream<asio::ssl::stream<asio::ip::tcp::socket>>>(std::move(ssl_socket)); |
| co_return co_await handleCommand(std::move(stream)); |
| } |
| |
| asio::awaitable<void> ControllerSocketProtocol::startAcceptSsl(std::shared_ptr<minifi::controllers::SSLContextServiceInterface> ssl_context_service) { |
| while (true) { // NOLINT(clang-analyzer-core.NullDereference) suppressing asio library linter warning |
| auto [accept_error, socket] = co_await acceptor_->async_accept(utils::net::use_nothrow_awaitable); |
| if (accept_error) { |
| if (accept_error == asio::error::operation_aborted || accept_error == asio::error::bad_descriptor) { |
| logger_->log_debug("Controller socket accept aborted"); |
| co_return; |
| } |
| logger_->log_error("Controller socket accept failed with the following message: '{}'", accept_error.message()); |
| continue; |
| } |
| |
| co_spawn(io_context_, handshakeAndHandleCommand(std::move(socket), ssl_context_service), asio::detached); |
| } |
| } |
| |
| void ControllerSocketProtocol::initialize() { |
| std::unique_lock<std::mutex> lock(initialization_mutex_); |
| std::shared_ptr<minifi::controllers::SSLContextServiceInterface> secure_context; |
| std::string secure_str; |
| if (configuration_->get(Configure::nifi_remote_input_secure, secure_str) && org::apache::nifi::minifi::utils::string::toBool(secure_str).value_or(false)) { |
| secure_context = std::make_shared<minifi::controllers::SSLContextService>("ControllerSocketProtocolSSL", configuration_); |
| secure_context->onEnable(); |
| } |
| |
| std::string limit_str; |
| const bool any_interface = configuration_->get(Configuration::controller_socket_local_any_interface, limit_str) && utils::string::toBool(limit_str).value_or(false); |
| |
| // if host name isn't defined we will use localhost |
| std::string host = "localhost"; |
| configuration_->get(Configuration::controller_socket_host, host); |
| |
| std::string port; |
| stopListener(); |
| if (configuration_->get(Configuration::controller_socket_port, port)) { |
| // if we have a localhost hostname and we did not manually specify any.interface we will bind only to the loopback adapter |
| if ((host == "localhost" || host == "127.0.0.1" || host == "::") && !any_interface) { |
| acceptor_ = std::make_unique<asio::ip::tcp::acceptor>(io_context_, asio::ip::tcp::endpoint(asio::ip::address_v4::loopback(), std::stoi(port))); |
| } else { |
| acceptor_ = std::make_unique<asio::ip::tcp::acceptor>(io_context_, asio::ip::tcp::endpoint(asio::ip::tcp::v4(), std::stoi(port))); |
| } |
| |
| if (secure_context) { |
| co_spawn(io_context_, startAcceptSsl(std::move(secure_context)), asio::detached); |
| } else { |
| co_spawn(io_context_, startAccept(), asio::detached); |
| } |
| server_thread_ = std::thread([this] { |
| io_context_.run(); |
| }); |
| } |
| } |
| |
| void ControllerSocketProtocol::setRoot(core::ProcessGroup* root) { |
| if (auto controller_socket_reporter = controller_socket_reporter_.lock()) { |
| controller_socket_reporter->setRoot(root); |
| } |
| } |
| |
| void ControllerSocketProtocol::handleStart(io::BaseStream &stream) { |
| std::string component_str; |
| const auto size = stream.read(component_str); |
| if (!io::isError(size)) { |
| if (component_str == "FlowController") { |
| // Starting flow controller resets socket |
| socket_restart_processor_.enqueue({SocketRestartCommandProcessor::Command::START, component_str}); |
| } else { |
| update_sink_.executeOnComponent(component_str, [](state::StateController& component) { |
| component.start(); |
| }); |
| } |
| } else { |
| logger_->log_error("Connection broke"); |
| } |
| } |
| |
| void ControllerSocketProtocol::handleStop(io::BaseStream &stream) { |
| std::string component_str; |
| const auto size = stream.read(component_str); |
| if (!io::isError(size)) { |
| update_sink_.executeOnComponent(component_str, [](state::StateController& component) { |
| component.stop(); |
| }); |
| } else { |
| logger_->log_error("Connection broke"); |
| } |
| } |
| |
| void ControllerSocketProtocol::handleClear(io::BaseStream &stream) { |
| std::string connection; |
| const auto size = stream.read(connection); |
| if (!io::isError(size)) { |
| update_sink_.clearConnection(connection); |
| } |
| } |
| |
| void ControllerSocketProtocol::handleUpdate(io::BaseStream &stream) { |
| std::string what; |
| { |
| const auto size = stream.read(what); |
| if (io::isError(size)) { |
| logger_->log_error("Connection broke"); |
| return; |
| } |
| } |
| if (what == "flow") { |
| std::string ff_loc; |
| { |
| const auto size = stream.read(ff_loc); |
| if (io::isError(size)) { |
| logger_->log_error("Connection broke"); |
| return; |
| } |
| } |
| std::ifstream tf(ff_loc); |
| std::string flow_configuration((std::istreambuf_iterator<char>(tf)), |
| std::istreambuf_iterator<char>()); |
| socket_restart_processor_.enqueue({SocketRestartCommandProcessor::Command::FLOW_UPDATE, flow_configuration}); |
| } |
| } |
| |
| void ControllerSocketProtocol::writeQueueSizesResponse(io::BaseStream &stream) { |
| std::string connection; |
| const auto size_ = stream.read(connection); |
| if (io::isError(size_)) { |
| logger_->log_error("Connection broke"); |
| return; |
| } |
| std::unordered_map<std::string, ControllerSocketReporter::QueueSize> sizes; |
| if (auto controller_socket_reporter = controller_socket_reporter_.lock()) { |
| sizes = controller_socket_reporter->getQueueSizes(); |
| } |
| std::stringstream response; |
| if (sizes.contains(connection)) { |
| response << sizes[connection].queue_size << " / " << sizes[connection].queue_size_max; |
| } else { |
| response << "not found"; |
| } |
| io::BufferStream resp; |
| auto op = static_cast<uint8_t>(Operation::describe); |
| resp.write(&op, 1); |
| resp.write(response.str()); |
| stream.write(resp.getBuffer()); |
| } |
| |
| void ControllerSocketProtocol::writeComponentsResponse(io::BaseStream &stream) { |
| std::vector<std::pair<std::string, bool>> components; |
| update_sink_.executeOnAllComponents([&components](state::StateController& component) { |
| components.emplace_back(component.getComponentName(), component.isRunning()); |
| }); |
| io::BufferStream resp; |
| auto op = static_cast<uint8_t>(Operation::describe); |
| resp.write(&op, 1); |
| resp.write(gsl::narrow<uint16_t>(components.size())); |
| for (const auto& [name, is_running] : components) { |
| resp.write(name); |
| resp.write(is_running ? "true" : "false"); |
| } |
| |
| stream.write(resp.getBuffer()); |
| } |
| |
| void ControllerSocketProtocol::writeConnectionsResponse(io::BaseStream &stream) { |
| io::BufferStream resp; |
| auto op = static_cast<uint8_t>(Operation::describe); |
| resp.write(&op, 1); |
| std::unordered_set<std::string> connections; |
| if (auto controller_socket_reporter = controller_socket_reporter_.lock()) { |
| connections = controller_socket_reporter->getConnections(); |
| } |
| |
| const auto size = gsl::narrow<uint16_t>(connections.size()); |
| resp.write(size); |
| for (const auto &connection : connections) { |
| resp.write(connection, false); |
| } |
| stream.write(resp.getBuffer()); |
| } |
| |
| void ControllerSocketProtocol::writeGetFullResponse(io::BaseStream &stream) { |
| io::BufferStream resp; |
| auto op = static_cast<uint8_t>(Operation::describe); |
| resp.write(&op, 1); |
| std::unordered_set<std::string> full_connections; |
| if (auto controller_socket_reporter = controller_socket_reporter_.lock()) { |
| full_connections = controller_socket_reporter->getFullConnections(); |
| } |
| |
| const auto size = gsl::narrow<uint16_t>(full_connections.size()); |
| resp.write(size); |
| for (const auto &connection : full_connections) { |
| resp.write(connection, false); |
| } |
| stream.write(resp.getBuffer()); |
| } |
| |
| void ControllerSocketProtocol::writeManifestResponse(io::BaseStream &stream) { |
| io::BufferStream resp; |
| auto op = static_cast<uint8_t>(Operation::describe); |
| resp.write(&op, 1); |
| std::string manifest; |
| if (auto controller_socket_reporter = controller_socket_reporter_.lock()) { |
| manifest = controller_socket_reporter->getAgentManifest(); |
| } |
| resp.write(manifest, true); |
| stream.write(resp.getBuffer()); |
| } |
| |
| std::string ControllerSocketProtocol::getJstack() { |
| if (!update_sink_.isRunning()) { |
| return {}; |
| } |
| std::stringstream result; |
| const auto traces = update_sink_.getTraces(); |
| for (const auto& trace : traces) { |
| for (const auto& line : trace.getTraces()) { |
| result << trace.getName() << " -- " << line << "\n"; |
| } |
| } |
| return result.str(); |
| } |
| |
| void ControllerSocketProtocol::writeJstackResponse(io::BaseStream &stream) { |
| io::BufferStream resp; |
| auto op = static_cast<uint8_t>(Operation::describe); |
| resp.write(&op, 1); |
| std::string jstack_response; |
| if (auto controller_socket_reporter = controller_socket_reporter_.lock()) { |
| jstack_response = getJstack(); |
| } |
| resp.write(jstack_response, true); |
| stream.write(resp.getBuffer()); |
| } |
| |
| void ControllerSocketProtocol::writeFlowStatusResponse(io::BaseStream &stream) { |
| std::string query; |
| { |
| const auto size = stream.read(query); |
| if (io::isError(size)) { |
| logger_->log_error("Connection broke"); |
| return; |
| } |
| } |
| auto request_strings = utils::string::splitAndTrimRemovingEmpty(query, ";"); |
| std::vector<FlowStatusRequest> requests; |
| for (const auto& request_string : request_strings) { |
| try { |
| requests.push_back(FlowStatusRequest(request_string)); |
| } catch (const std::exception& e) { |
| logger_->log_error("Invalid flow status request: {}", e.what()); |
| } |
| } |
| io::BufferStream resp; |
| auto op = static_cast<uint8_t>(Operation::describe); |
| resp.write(&op, 1); |
| std::string flowstatus_response; |
| if (auto controller_socket_reporter = controller_socket_reporter_.lock()) { |
| flowstatus_response = controller_socket_reporter->getFlowStatus(requests); |
| } |
| resp.write(flowstatus_response, true); |
| stream.write(resp.getBuffer()); |
| } |
| |
| void ControllerSocketProtocol::handleDescribe(io::BaseStream &stream) { |
| std::string what; |
| const auto size = stream.read(what); |
| if (io::isError(size)) { |
| logger_->log_error("Connection broke"); |
| return; |
| } |
| if (what == "queue") { |
| writeQueueSizesResponse(stream); |
| } else if (what == "components") { |
| writeComponentsResponse(stream); |
| } else if (what == "connections") { |
| writeConnectionsResponse(stream); |
| } else if (what == "getfull") { |
| writeGetFullResponse(stream); |
| } else if (what == "manifest") { |
| writeManifestResponse(stream); |
| } else if (what == "jstack") { |
| writeJstackResponse(stream); |
| } else if (what == "flowstatus") { |
| writeFlowStatusResponse(stream); |
| } else { |
| logger_->log_error("Unknown C2 describe parameter: {}", what); |
| } |
| } |
| |
| void ControllerSocketProtocol::writeDebugBundleResponse(io::BaseStream &stream) { |
| auto files = update_sink_.getDebugInfo(); |
| auto bundle = createDebugBundleArchive(files); |
| io::BufferStream resp; |
| auto op = static_cast<uint8_t>(Operation::transfer); |
| resp.write(&op, 1); |
| if (!bundle) { |
| logger_->log_error("Creating debug bundle failed: {}", bundle.error()); |
| resp.write(static_cast<size_t>(0)); |
| stream.write(resp.getBuffer()); |
| return; |
| } |
| |
| size_t bundle_size = bundle.value()->size(); |
| resp.write(bundle_size); |
| static constexpr auto BUFFER_SIZE = utils::configuration::DEFAULT_BUFFER_SIZE; |
| std::array<std::byte, BUFFER_SIZE> out_buffer{}; |
| while (bundle_size > 0) { |
| const auto next_write_size = (std::min)(bundle_size, BUFFER_SIZE); |
| const auto size_read = bundle.value()->read(std::as_writable_bytes(std::span(out_buffer).subspan(0, next_write_size))); |
| resp.write(reinterpret_cast<const uint8_t*>(out_buffer.data()), size_read); |
| bundle_size -= size_read; |
| } |
| |
| stream.write(resp.getBuffer()); |
| } |
| |
| void ControllerSocketProtocol::handleTransfer(io::BaseStream &stream) { |
| std::string what; |
| const auto size = stream.read(what); |
| if (io::isError(size)) { |
| logger_->log_error("Connection broke"); |
| return; |
| } |
| if (what == "debug") { |
| writeDebugBundleResponse(stream); |
| } else { |
| logger_->log_error("Unknown C2 transfer parameter: {}", what); |
| } |
| } |
| |
| asio::awaitable<void> ControllerSocketProtocol::handleCommand(std::unique_ptr<io::BaseStream> stream) { |
| uint8_t head = 0; |
| if (stream->read(head) != 1) { |
| logger_->log_error("Connection broke"); |
| co_return; |
| } |
| |
| if (socket_restart_processor_.isSocketRestarting()) { |
| logger_->log_debug("Socket restarting, dropping command"); |
| co_return; |
| } |
| |
| auto op = static_cast<Operation>(head); |
| switch (op) { |
| case Operation::start: |
| handleStart(*stream); |
| break; |
| case Operation::stop: |
| handleStop(*stream); |
| break; |
| case Operation::clear: |
| handleClear(*stream); |
| break; |
| case Operation::update: |
| handleUpdate(*stream); |
| break; |
| case Operation::describe: |
| handleDescribe(*stream); |
| break; |
| case Operation::transfer: |
| handleTransfer(*stream); |
| break; |
| default: |
| logger_->log_error("Unhandled C2 operation: {}", head); |
| } |
| } |
| |
| } // namespace org::apache::nifi::minifi::c2 |