blob: 1c51fa26572a703e4d22438ebf60f418a7ca9c39 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "utils/net/TcpServer.h"
#include "utils/net/AsioCoro.h"
#include "utils/net/AsioSocketUtils.h"
using namespace std::literals::chrono_literals;
namespace org::apache::nifi::minifi::utils::net {
/**
 * Accept loop of the TCP server.
 *
 * Opens an acceptor on an IPv6 TCP endpoint using port_. When port_ is 0 the
 * port actually assigned to the acceptor is written back into port_.
 * Every accepted connection is handed to a detached session coroutine
 * (secureSession when ssl_data_ is set, insecureSession otherwise).
 * The loop has no exit condition of its own.
 */
asio::awaitable<void> TcpServer::doReceive() {
  asio::ip::tcp::acceptor acceptor(io_context_, asio::ip::tcp::endpoint(asio::ip::tcp::v6(), port_));
  if (port_ == 0) {
    // Port 0 was requested: remember the port the acceptor ended up with.
    port_ = acceptor.local_endpoint().port();
  }

  for (;;) {
    auto [accept_error, socket] = co_await acceptor.async_accept(use_nothrow_awaitable);
    if (accept_error) {
      logger_->log_error("Error during accepting new connection: {}", accept_error.message());
      // Wait a little before the next accept attempt.
      co_await utils::net::async_wait(1s);
      continue;
    }

    // Endpoint lookups are best effort: a failure is only logged and the
    // session is still started with whatever value the lookup returned.
    std::error_code endpoint_error;
    auto peer_address = socket.lowest_layer().remote_endpoint(endpoint_error).address();
    if (endpoint_error) {
      logger_->log_warn("Error during fetching remote endpoint: {}", endpoint_error.message());
    }
    auto server_port = socket.lowest_layer().local_endpoint(endpoint_error).port();
    if (endpoint_error) {
      logger_->log_warn("Error during fetching local endpoint: {}", endpoint_error.message());
    }

    if (!ssl_data_) {
      co_spawn(io_context_, insecureSession(std::move(socket), std::move(peer_address), server_port), asio::detached);
    } else {
      co_spawn(io_context_, secureSession(std::move(socket), std::move(peer_address), server_port), asio::detached);
    }
  }
}
/**
 * Reads delimiter_-terminated messages from the given stream until it fails or ends.
 *
 * Each message is enqueued into concurrent_queue_ together with the remote address
 * and the local port, unless max_queue_size_ is set and the queue has already
 * reached it, in which case the message is dropped with a warning.
 *
 * @param socket          stream to read from (plain or SSL)
 * @param remote_address  address attached to every enqueued Message
 * @param local_port      port attached to every enqueued Message
 */
asio::awaitable<void> TcpServer::readLoop(auto& socket, const auto& remote_address, const auto& local_port) {
  std::string pending_data;
  for (;;) {
    auto [read_error, bytes_read] = co_await asio::async_read_until(socket, asio::dynamic_buffer(pending_data), delimiter_, use_nothrow_awaitable);  // NOLINT
    if (read_error) {
      // End of stream is the regular way for a session to finish, so it is not reported.
      if (read_error != asio::error::eof) {
        logger_->log_error("Error during reading from socket: {}", read_error.message());
      }
      co_return;
    }
    if (bytes_read == 0) {
      logger_->log_debug("No more bytes were read from socket");
      co_return;
    }

    const bool queue_has_room = !max_queue_size_ || max_queue_size_ > concurrent_queue_.size();
    if (!queue_has_room) {
      logger_->log_warn("Queue is full. TCP message ignored.");
    } else {
      // bytes_read spans the message including its delimiter; strip the delimiter only when requested.
      const auto payload_length = bytes_read - (consume_delimiter_ ? delimiter_.size() : 0);
      auto payload = pending_data.substr(0, payload_length);
      concurrent_queue_.enqueue(Message(std::move(payload), IpProtocol::TCP, remote_address, local_port));
    }

    // Drop the processed message (delimiter included); anything read past it stays for the next round.
    pending_data.erase(0, bytes_read);
  }
}
/**
 * Session coroutine for a plain (non-SSL) connection.
 *
 * Takes ownership of the socket and runs readLoop on it; the socket lives
 * in this coroutine's frame for as long as readLoop is running.
 */
asio::awaitable<void> TcpServer::insecureSession(asio::ip::tcp::socket socket, asio::ip::address remote_address, asio::ip::port_type local_port) {
  co_await readLoop(socket, remote_address, local_port);  // NOLINT
}
namespace {
/**
 * Builds a TLS server context from the given options.
 *
 * Applies MINIFI_SSL_OPTIONS, installs a password callback returning the configured
 * key password, loads the PEM certificate and private key, loads the CA file when one
 * is configured, and sets the peer verification mode for the REQUIRED and WANT
 * client auth options (no verify mode is set for any other option).
 *
 * @param ssl_data  SSL options to read the certificate data and client auth option from
 * @return the configured asio::ssl::context
 */
asio::ssl::context setupSslContext(SslServerOptions& ssl_data) {
  const auto& cert_data = ssl_data.cert_data;

  asio::ssl::context context(asio::ssl::context::tls_server);
  context.set_options(minifi::utils::net::MINIFI_SSL_OPTIONS);

  // The password callback is installed before the private key is loaded.
  context.set_password_callback([password = cert_data.key_pw](std::size_t&, asio::ssl::context_base::password_purpose&) { return password; });
  context.use_certificate_file(cert_data.cert_loc.string(), asio::ssl::context::pem);
  context.use_private_key_file(cert_data.key_loc.string(), asio::ssl::context::pem);

  if (!cert_data.ca_loc.empty()) {
    context.load_verify_file(cert_data.ca_loc.string());
  }

  if (ssl_data.client_auth_option == ClientAuthOption::WANT) {
    context.set_verify_mode(asio::ssl::verify_peer);
  } else if (ssl_data.client_auth_option == ClientAuthOption::REQUIRED) {
    context.set_verify_mode(asio::ssl::verify_peer | asio::ssl::verify_fail_if_no_peer_cert);
  }

  return context;
}
}  // namespace
/**
 * Session coroutine for an SSL connection.
 *
 * Requires ssl_data_ to be set. Creates an SSL context for this session, wraps the
 * accepted socket, performs the server side handshake and then runs readLoop on the
 * SSL stream. Once readLoop returns, outstanding operations on the underlying socket
 * are cancelled and an SSL shutdown is attempted; failures of either step are only logged.
 */
asio::awaitable<void> TcpServer::secureSession(asio::ip::tcp::socket socket, asio::ip::address remote_address, asio::ip::port_type local_port) {
  gsl_Expects(ssl_data_);

  // The context is declared before the SSL socket, so it is destroyed after it.
  auto context = setupSslContext(*ssl_data_);
  SslSocket secure_socket(std::move(socket), context);

  auto [handshake_error] = co_await secure_socket.async_handshake(HandshakeType::server, use_nothrow_awaitable);
  if (handshake_error) {
    logger_->log_warn("Handshake with {} failed due to {}", remote_address, handshake_error.message());
    co_return;
  }

  co_await readLoop(secure_socket, remote_address, local_port);  // NOLINT

  // Cancel outstanding operations on the underlying socket before shutting the SSL stream down.
  asio::error_code cancel_error;
  std::ignore = secure_socket.lowest_layer().cancel(cancel_error);
  if (cancel_error) {
    logger_->log_error("Cancelling asynchronous operations of SSL socket failed with: {}", cancel_error.message());
  }

  auto [shutdown_error] = co_await secure_socket.async_shutdown(use_nothrow_awaitable);
  if (shutdown_error) {
    logger_->log_warn("Shutdown of {} failed with {}", remote_address, shutdown_error.message());
  }
}
} // namespace org::apache::nifi::minifi::utils::net