| /** |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #pragma once |
| |
| #include <asio/read.hpp> |
| |
| #include "utils/net/AsioCoro.h" |
| #include "utils/net/AsioSocketUtils.h" |
| #include "ConnectionHandlerBase.h" |
| |
| namespace org::apache::nifi::minifi::utils::net { |
| |
| template<class SocketType> |
| class ConnectionHandler final : public ConnectionHandlerBase { |
| public: |
| ConnectionHandler(ConnectionId connection_id, |
| const std::chrono::milliseconds timeout, |
| std::shared_ptr<core::logging::Logger> logger, |
| const std::optional<size_t> max_size_of_socket_send_buffer, |
| asio::ssl::context* ssl_context) |
| : connection_id_(std::move(connection_id)), |
| timeout_duration_(timeout), |
| logger_(std::move(logger)), |
| max_size_of_socket_send_buffer_(max_size_of_socket_send_buffer), |
| ssl_context_(ssl_context) { |
| } |
| |
| ConnectionHandler(ConnectionHandler&&) = delete; |
| ConnectionHandler(const ConnectionHandler&) = delete; |
| ConnectionHandler& operator=(ConnectionHandler&&) = delete; |
| ConnectionHandler& operator=(const ConnectionHandler&) = delete; |
| |
| ~ConnectionHandler() override { |
| shutdownSocket(); |
| } |
| |
| |
| private: |
| [[nodiscard]] bool hasBeenUsedIn(std::chrono::milliseconds dur) const override { |
| return last_used_ && *last_used_ >= (std::chrono::steady_clock::now() - dur); |
| } |
| |
| void reset() override { |
| last_used_.reset(); |
| socket_.reset(); |
| } |
| |
| [[nodiscard]] bool hasBeenUsed() const override { return last_used_.has_value(); } |
| [[nodiscard]] asio::awaitable<std::error_code> setupUsableSocket(asio::io_context& io_context) override; |
| [[nodiscard]] bool hasUsableSocket() const { return socket_ && socket_->lowest_layer().is_open(); } |
| |
| asio::awaitable<std::error_code> establishNewConnection(const asio::ip::tcp::resolver::results_type& endpoints, asio::io_context& io_context_); |
| [[nodiscard]] asio::awaitable<std::tuple<std::error_code, size_t>> write(const asio::const_buffer& buffer) override; |
| [[nodiscard]] asio::awaitable<std::tuple<std::error_code, size_t>> read(asio::mutable_buffer& buffer) override; |
| |
| SocketType createNewSocket(asio::io_context& io_context_); |
| void shutdownSocket(); |
| |
| ConnectionId connection_id_; |
| std::optional<SocketType> socket_{}; |
| |
| std::optional<std::chrono::steady_clock::time_point> last_used_{}; |
| asio::steady_timer::duration timeout_duration_{}; |
| |
| std::shared_ptr<core::logging::Logger> logger_{}; |
| std::optional<size_t> max_size_of_socket_send_buffer_{}; |
| |
| asio::ssl::context* ssl_context_{}; |
| }; |
| |
| template<> |
| inline TcpSocket ConnectionHandler<TcpSocket>::createNewSocket(asio::io_context& io_context_) { |
| gsl_Expects(!ssl_context_); |
| return TcpSocket{io_context_}; |
| } |
| |
| template<> |
| inline SslSocket ConnectionHandler<SslSocket>::createNewSocket(asio::io_context& io_context_) { |
| gsl_Expects(ssl_context_); |
| return {io_context_, *ssl_context_}; |
| } |
| |
| template<> |
| inline void ConnectionHandler<TcpSocket>::shutdownSocket() { |
| } |
| |
| template<> |
| inline void ConnectionHandler<SslSocket>::shutdownSocket() { |
| gsl_Expects(ssl_context_); |
| if (socket_) { |
| asio::error_code ec; |
| socket_->lowest_layer().cancel(ec); |
| if (ec) { |
| logger_->log_error("Cancelling asynchronous operations of SSL socket failed with: {}", ec.message()); |
| } |
| socket_->shutdown(ec); |
| if (ec) { |
| logger_->log_error("Shutdown of SSL socket failed with: {}", ec.message()); |
| } |
| } |
| } |
| |
| template<class SocketType> |
| asio::awaitable<std::error_code> ConnectionHandler<SocketType>::establishNewConnection(const asio::ip::tcp::resolver::results_type& endpoints, asio::io_context& io_context) { |
| auto socket = createNewSocket(io_context); |
| std::error_code last_error; |
| for (const auto& endpoint : endpoints) { |
| auto [connection_error] = co_await asyncOperationWithTimeout(socket.lowest_layer().async_connect(endpoint, use_nothrow_awaitable), timeout_duration_); |
| if (connection_error) { |
| logger_->log_debug("Connecting to {} failed due to {}", endpoint.endpoint(), connection_error.message()); |
| last_error = connection_error; |
| continue; |
| } |
| auto [handshake_error] = co_await handshake(socket, timeout_duration_); |
| if (handshake_error) { |
| logger_->log_debug("Handshake with {} failed due to {}", endpoint.endpoint(), handshake_error.message()); |
| last_error = handshake_error; |
| continue; |
| } |
| if (max_size_of_socket_send_buffer_) |
| socket.lowest_layer().set_option(TcpSocket::send_buffer_size(gsl::narrow<int>(*max_size_of_socket_send_buffer_))); |
| socket_.emplace(std::move(socket)); |
| co_return std::error_code(); |
| } |
| co_return last_error; |
| } |
| |
| template<class SocketType> |
| [[nodiscard]] asio::awaitable<std::error_code> ConnectionHandler<SocketType>::setupUsableSocket(asio::io_context& io_context) { |
| if (hasUsableSocket()) |
| co_return std::error_code(); |
| asio::ip::tcp::resolver resolver(io_context); |
| auto [resolve_error, resolve_result] = co_await asyncOperationWithTimeout( |
| resolver.async_resolve(connection_id_.getHostname(), connection_id_.getService(), use_nothrow_awaitable), timeout_duration_); |
| if (resolve_error) |
| co_return resolve_error; |
| co_return co_await establishNewConnection(resolve_result, io_context); |
| } |
| |
| template<class SocketType> |
| asio::awaitable<std::tuple<std::error_code, size_t>> ConnectionHandler<SocketType>::write(const asio::const_buffer& buffer) { |
| auto result = co_await asyncOperationWithTimeout(asio::async_write(*socket_, buffer, use_nothrow_awaitable), timeout_duration_); |
| if (!std::get<std::error_code>(result)) { |
| last_used_ = std::chrono::steady_clock::now(); |
| } |
| co_return result; |
| } |
| |
| template<class SocketType> |
| asio::awaitable<std::tuple<std::error_code, size_t>> ConnectionHandler<SocketType>::read(asio::mutable_buffer& buffer) { |
| auto result = co_await asyncOperationWithTimeout(asio::async_read(*socket_, buffer, use_nothrow_awaitable), timeout_duration_); |
| if (!std::get<std::error_code>(result)) { |
| last_used_ = std::chrono::steady_clock::now(); |
| } |
| co_return result; |
| } |
| |
| } // namespace org::apache::nifi::minifi::utils::net |