blob: 5e1e8a35947ff8ef4dbf60b94bd6e1b69e862be0 [file]
/**
*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include <mutex>
#include <new>
#include <string>
#include "unit/SingleProcessorTestController.h"
#include "unit/Catch.h"
#include "PutTCP.h"
#include "controllers/SSLContextService.h"
#include "core/ProcessSession.h"
#include "utils/net/TcpServer.h"
#include "utils/net/AsioCoro.h"
#include "utils/expected.h"
#include "unit/TestUtils.h"
using namespace std::literals::chrono_literals;
using org::apache::nifi::minifi::test::utils::verifyLogLineVariantPresenceInPollTime;
using org::apache::nifi::minifi::test::utils::verifyEventHappenedInPollTime;
namespace org::apache::nifi::minifi::processors {
using controllers::SSLContextService;
namespace {
class CancellableTcpServer : public utils::net::TcpServer {
public:
using utils::net::TcpServer::TcpServer;
size_t getNumberOfSessions() const {
return cancellable_timers_.size();
}
void cancelEverything() {
for (auto& timer : cancellable_timers_)
asio::post(io_context_, [=]{timer->cancel();});
}
asio::awaitable<void> doReceive() override {
using asio::experimental::awaitable_operators::operator||;
asio::ip::tcp::acceptor acceptor(io_context_, asio::ip::tcp::endpoint(asio::ip::tcp::v6(), port_));
if (port_ == 0)
port_ = acceptor.local_endpoint().port();
while (true) {
auto [accept_error, socket] = co_await acceptor.async_accept(utils::net::use_nothrow_awaitable);
if (accept_error) {
logger_->log_error("Error during accepting new connection: {}", accept_error.message());
break;
}
std::error_code error;
auto remote_address = socket.lowest_layer().remote_endpoint(error).address();
auto remote_port = socket.lowest_layer().remote_endpoint(error).port();
auto cancellable_timer = std::make_shared<asio::steady_timer>(io_context_);
cancellable_timers_.push_back(cancellable_timer);
if (ssl_data_)
co_spawn(io_context_, secureSession(std::move(socket), std::move(remote_address), remote_port, port_) || wait_until_cancelled(cancellable_timer), asio::detached);
else
co_spawn(io_context_, insecureSession(std::move(socket), std::move(remote_address), remote_port, port_) || wait_until_cancelled(cancellable_timer), asio::detached);
}
}
private:
static asio::awaitable<void> wait_until_cancelled(std::shared_ptr<asio::steady_timer> timer) {
timer->expires_at(asio::steady_timer::time_point::max());
co_await utils::net::async_wait(*timer);
}
std::vector<std::shared_ptr<asio::steady_timer>> cancellable_timers_;
};
// Assembles the TLS material for the test servers: CA "A", the localhost certificate signed by it
// and the matching private key, all read from the "resources" directory next to the test executable.
utils::net::SslData createSslDataForServer() {
  const auto resources_dir = minifi::utils::file::FileUtils::get_executable_dir() / "resources";

  utils::net::SslData server_ssl;
  server_ssl.ca_loc = (resources_dir / "ca_A.crt").string();
  server_ssl.cert_loc = (resources_dir / "localhost_by_A.pem").string();
  server_ssl.key_loc = (resources_dir / "localhost.key").string();
  return server_ssl;
}
} // namespace
class PutTCPTestFixture {
public:
PutTCPTestFixture() :
controller_(minifi::test::utils::make_processor<PutTCP>("PutTCP")),
put_tcp_(controller_.getProcessor()) {
LogTestController::getInstance().setTrace<PutTCP>();
LogTestController::getInstance().setInfo<core::ProcessSession>();
LogTestController::getInstance().setTrace<utils::net::Server>();
REQUIRE(put_tcp_->setProperty(PutTCP::Hostname.name, "${literal('localhost')}"));
REQUIRE(put_tcp_->setProperty(PutTCP::Timeout.name, "200 ms"));
REQUIRE(put_tcp_->setProperty(PutTCP::OutgoingMessageDelimiter.name, "\n"));
}
PutTCPTestFixture(PutTCPTestFixture&&) = delete;
PutTCPTestFixture(const PutTCPTestFixture&) = delete;
PutTCPTestFixture& operator=(PutTCPTestFixture&&) = delete;
PutTCPTestFixture& operator=(const PutTCPTestFixture&) = delete;
~PutTCPTestFixture() {
stopServers();
}
void stopServers() {
for (auto& [port, server] : servers_) {
auto& cancellable_server = server.cancellable_server;
auto& server_thread = server.server_thread_;
if (cancellable_server)
cancellable_server->stop();
if (server_thread.joinable())
server_thread.join();
cancellable_server.reset();
}
}
size_t getNumberOfActiveSessions(std::optional<uint16_t> port = std::nullopt) {
if (auto cancellable_tcp_server = getServer(port)) {
return cancellable_tcp_server->getNumberOfSessions();
}
return -1;
}
void closeActiveConnections() {
for (auto& [port, server] : servers_) {
if (auto cancellable_tcp_server = getServer(port)) {
cancellable_tcp_server->cancelEverything();
}
}
std::this_thread::sleep_for(200ms);
}
auto trigger(std::string_view message, std::unordered_map<std::string, std::string> input_flow_file_attributes = {}) {
return controller_.trigger(message, std::move(input_flow_file_attributes));
}
auto getContent(const auto& flow_file) {
return controller_.plan->getContent(flow_file);
}
std::optional<utils::net::Message> tryDequeueReceivedMessage(std::optional<uint16_t> port = std::nullopt) {
auto timeout = 200ms;
auto interval = 10ms;
auto start_time = std::chrono::system_clock::now();
while (start_time + timeout > std::chrono::system_clock::now()) {
if (const auto result = getServer(port)->tryDequeue())
return result;
std::this_thread::sleep_for(interval);
}
return std::nullopt;
}
void addSSLContextToPutTCP(const std::filesystem::path& ca_cert, const std::optional<std::filesystem::path>& client_cert, const std::optional<std::filesystem::path>& client_cert_key) {
const std::filesystem::path ca_dir = minifi::utils::file::FileUtils::get_executable_dir() / "resources";
auto ssl_context_service_node = controller_.plan->addController("SSLContextService", "SSLContextService");
REQUIRE(controller_.plan->setProperty(ssl_context_service_node, SSLContextService::CACertificate, (ca_dir / ca_cert).string()));
if (client_cert) {
REQUIRE(controller_.plan->setProperty(ssl_context_service_node, SSLContextService::ClientCertificate, (ca_dir / *client_cert).string()));
}
if (client_cert_key) {
REQUIRE(controller_.plan->setProperty(ssl_context_service_node, SSLContextService::PrivateKey, (ca_dir / *client_cert_key).string()));
}
ssl_context_service_node->enable();
REQUIRE(put_tcp_->setProperty(PutTCP::SSLContextService.name, "SSLContextService"));
}
void setHostname(const std::string& hostname) {
REQUIRE(controller_.plan->setProperty(put_tcp_, PutTCP::Hostname, hostname));
}
void enableConnectionPerFlowFile() {
REQUIRE(controller_.plan->setProperty(put_tcp_, PutTCP::ConnectionPerFlowFile, "true"));
}
void setIdleConnectionExpiration(const std::string& idle_connection_expiration_str) {
REQUIRE(controller_.plan->setProperty(put_tcp_, PutTCP::IdleConnectionExpiration, idle_connection_expiration_str));
}
uint16_t addTCPServer() {
Server server;
uint16_t port = server.startTCPServer(std::nullopt);
servers_[port] = std::move(server);
return port;
}
uint16_t addSSLServer() {
auto ssl_server_options = utils::net::SslServerOptions{createSslDataForServer(), utils::net::ClientAuthOption::REQUIRED};
Server server;
uint16_t port = server.startTCPServer(ssl_server_options);
servers_[port] = std::move(server);
return port;
}
void setPutTCPPort(uint16_t port) {
CHECK(put_tcp_->setProperty(PutTCP::Port.name, utils::string::join_pack("${literal('", std::to_string(port), "')}")));
}
void setPutTCPPort(const std::string& port_str) {
CHECK(put_tcp_->setProperty(PutTCP::Port.name, port_str));
}
[[nodiscard]] uint16_t getSinglePort() const {
gsl_Expects(servers_.size() == 1);
return servers_.begin()->first;
}
private:
CancellableTcpServer* getServer(std::optional<uint16_t> port) {
if (!port)
port = getSinglePort();
return servers_.at(*port).cancellable_server.get();
}
test::SingleProcessorTestController controller_;
minifi::core::Processor* put_tcp_;
class Server {
public:
Server() = default;
uint16_t startTCPServer(std::optional<utils::net::SslServerOptions> ssl_server_options) {
gsl_Expects(!cancellable_server && !server_thread_.joinable());
cancellable_server = std::make_unique<CancellableTcpServer>(std::nullopt, 0, core::logging::LoggerFactory<utils::net::Server>::getLogger(), std::move(ssl_server_options), true, "\n");
server_thread_ = std::thread([this]() { cancellable_server->run(); });
REQUIRE(verifyEventHappenedInPollTime(250ms, [this] { return cancellable_server->getPort() != 0; }, 20ms));
return cancellable_server->getPort();
}
std::unique_ptr<CancellableTcpServer> cancellable_server;
std::thread server_thread_;
};
std::unordered_map<uint16_t, Server> servers_;
};
// Triggers PutTCP with `message` (plus optional flow file attributes) and checks that exactly one
// flow file was routed to Success, none to Failure, and that the routed content equals `message`.
void trigger_expect_success(PutTCPTestFixture& test_fixture, const std::string_view message, std::unordered_map<std::string, std::string> input_flow_file_attributes = {}) {
  const auto routed = test_fixture.trigger(message, std::move(input_flow_file_attributes));
  const auto& succeeded = routed.at(PutTCP::Success);
  CHECK(succeeded.size() == 1);
  CHECK(routed.at(PutTCP::Failure).empty());
  if (succeeded.empty())
    return;  // nothing to compare; the size check above has already recorded the failure
  CHECK(test_fixture.getContent(succeeded[0]) == message);
}
// Triggers PutTCP with `message` and checks that exactly one flow file was routed to Failure,
// none to Success, and that the routed content equals `message`.
void trigger_expect_failure(PutTCPTestFixture& test_fixture, const std::string_view message) {
  const auto routed = test_fixture.trigger(message);
  const auto& failed = routed.at(PutTCP::Failure);
  CHECK(failed.size() == 1);
  CHECK(routed.at(PutTCP::Success).empty());
  if (failed.empty())
    return;  // nothing to compare; the size check above has already recorded the failure
  CHECK(test_fixture.getContent(failed[0]) == message);
}
// Waits for the next message on the selected server (the only server when `port` is omitted) and
// checks its payload, that it arrived over TCP, and that a remote address was recorded.
void receive_success(PutTCPTestFixture& test_fixture, const std::string_view expected_message, std::optional<uint16_t> port = std::nullopt) {
  const auto incoming = test_fixture.tryDequeueReceivedMessage(port);
  CHECK(incoming);
  if (!incoming)
    return;  // timed out; the check above has already recorded the failure
  CHECK(incoming->message_data == expected_message);
  CHECK(incoming->protocol == utils::net::IpProtocol::TCP);
  CHECK(!incoming->remote_address.to_string().empty());
}
// Payloads shared by the test cases below; every message has a different length.
constexpr std::string_view first_message{"message 1"};
constexpr std::string_view second_message{"message 22"};
constexpr std::string_view third_message{"message 333"};
constexpr std::string_view fourth_message{"message 4444"};
constexpr std::string_view fifth_message{"message 55555"};
constexpr std::string_view sixth_message{"message 666666"};
// PutTCP must recover when the server drops a connection it is still reusing:
// the send is retried on a new connection and a warning is logged.
TEST_CASE("Server closes in-use socket", "[PutTCP]") {
  PutTCPTestFixture test_fixture;
  SECTION("No SSL") {
    test_fixture.setPutTCPPort(test_fixture.addTCPServer());
  }
  SECTION("SSL") {
    test_fixture.addSSLContextToPutTCP("ca_A.crt", "alice_by_A.pem", "alice.key");
    test_fixture.setPutTCPPort(test_fixture.addSSLServer());
  }

  // The first batch is expected to share a single connection.
  for (const auto& message : {first_message, second_message, third_message})
    trigger_expect_success(test_fixture, message);
  for (const auto& message : {first_message, second_message, third_message})
    receive_success(test_fixture, message);
  CHECK(1 == test_fixture.getNumberOfActiveSessions());

  // Have the server terminate that connection behind PutTCP's back.
  test_fixture.closeActiveConnections();

  for (const auto& message : {fourth_message, fifth_message, sixth_message})
    trigger_expect_success(test_fixture, message);

  // Result intentionally ignored: only the retry warning and the session count are verified below.
  test_fixture.tryDequeueReceivedMessage();
  CHECK(LogTestController::getInstance().matchesRegex("warning.*with reused connection, retrying"));
  CHECK(2 == test_fixture.getNumberOfActiveSessions());
}
// With "connection per flow file" enabled every flow file must arrive over its own connection,
// so six flow files are expected to produce six sessions on the server.
TEST_CASE("Connection per flow file", "[PutTCP]") {
  PutTCPTestFixture test_fixture;
  SECTION("No SSL") {
    test_fixture.setPutTCPPort(test_fixture.addTCPServer());
  }
  SECTION("SSL") {
    test_fixture.addSSLContextToPutTCP("ca_A.crt", "alice_by_A.pem", "alice.key");
    test_fixture.setPutTCPPort(test_fixture.addSSLServer());
  }
  test_fixture.enableConnectionPerFlowFile();

  // First batch: send all three, then confirm they arrived in order.
  for (const auto& message : {first_message, second_message, third_message})
    trigger_expect_success(test_fixture, message);
  for (const auto& message : {first_message, second_message, third_message})
    receive_success(test_fixture, message);

  // Second batch, same pattern.
  for (const auto& message : {fourth_message, fifth_message, sixth_message})
    trigger_expect_success(test_fixture, message);
  for (const auto& message : {fourth_message, fifth_message, sixth_message})
    receive_success(test_fixture, message);

  CHECK(6 == test_fixture.getNumberOfActiveSessions());
}
// Sending to the hostname "invalid_hostname" must route the flow file to Failure,
// both without and with an SSL context service configured on PutTCP.
TEST_CASE("PutTCP test invalid host", "[PutTCP]") {
PutTCPTestFixture test_fixture;
// The plain variant needs no extra setup, hence the empty section.
SECTION("No SSL") {
}
SECTION("SSL") {
test_fixture.addSSLContextToPutTCP("ca_A.crt", "alice_by_A.pem", "alice.key");
}
// No server is started by this test; the port value is an arbitrary fixed number.
test_fixture.setPutTCPPort(1235);
test_fixture.setHostname("invalid_hostname");
// Expects exactly one flow file in Failure, with its content unchanged.
trigger_expect_failure(test_fixture, "message for invalid host");
}
// Sending to localhost on a port where this test starts no server must route the flow file
// to Failure, both without and with an SSL context service configured on PutTCP.
TEST_CASE("PutTCP test invalid server", "[PutTCP]") {
PutTCPTestFixture test_fixture;
// The plain variant needs no extra setup, hence the empty section.
SECTION("No SSL") {
}
SECTION("SSL") {
test_fixture.addSSLContextToPutTCP("ca_A.crt", "alice_by_A.pem", "alice.key");
}
// NOTE(review): assumes nothing else on the machine is listening on port 1235 - confirm for CI hosts.
test_fixture.setPutTCPPort(1235);
test_fixture.setHostname("localhost");
// Expects exactly one flow file in Failure, with its content unchanged.
trigger_expect_failure(test_fixture, "message for invalid server");
}
// Sending to 192.168.255.255 must route the flow file to Failure, both without and with an
// SSL context service configured on PutTCP.
TEST_CASE("PutTCP test non-routable server", "[PutTCP]") {
PutTCPTestFixture test_fixture;
// The plain variant needs no extra setup, hence the empty section.
SECTION("No SSL") {
}
SECTION("SSL") {
test_fixture.addSSLContextToPutTCP("ca_A.crt", "alice_by_A.pem", "alice.key");
}
// NOTE(review): presumably this address is unreachable from the test host, so the attempt should
// end via the 200 ms Timeout set by the fixture - confirm.
test_fixture.setHostname("192.168.255.255");
test_fixture.setPutTCPPort(1235);
// Expects exactly one flow file in Failure, with its content unchanged.
trigger_expect_failure(test_fixture, "message for non-routable server");
}
// The client is configured with CA "B" (ca_B.crt / alice_by_B.pem) while the test server presents
// a certificate issued by CA "A" (see createSslDataForServer), so the send must fail.
TEST_CASE("PutTCP test invalid server cert", "[PutTCP]") {
PutTCPTestFixture test_fixture;
test_fixture.addSSLContextToPutTCP("ca_B.crt", "alice_by_B.pem", "alice.key");
test_fixture.setHostname("localhost");
auto port = test_fixture.addSSLServer();
test_fixture.setPutTCPPort(port);
// Expects exactly one flow file in Failure, with its content unchanged.
trigger_expect_failure(test_fixture, "message for invalid-cert server");
// NOTE(review): the 0ms argument is presumably the wait timeout, i.e. the handshake failure must
// already be in the log when the trigger returns - confirm against LogTestController.
CHECK(LogTestController::getInstance().matchesRegex("Handshake with .* failed", 0ms));
}
// The SSL test server requires a client certificate (ClientAuthOption::REQUIRED in addSSLServer),
// but here PutTCP is given only the CA certificate and no client certificate or key.
TEST_CASE("PutTCP test missing client cert", "[PutTCP]") {
PutTCPTestFixture test_fixture;
test_fixture.addSSLContextToPutTCP("ca_A.crt", std::nullopt, std::nullopt);
test_fixture.setHostname("localhost");
auto port = test_fixture.addSSLServer();
test_fixture.setPutTCPPort(port);
// The routing result is deliberately not asserted here; only the log output is verified.
test_fixture.trigger("message for invalid-cert server");
// NOTE(review): presumably either of the two log lines is accepted within 3 seconds - confirm
// against verifyLogLineVariantPresenceInPollTime.
CHECK(verifyLogLineVariantPresenceInPollTime(std::chrono::seconds(3), "peer did not return a certificate (SSL routines)", "failed due to asio.ssl error"));
}
// A connection left idle for longer than the configured expiration must not be reused:
// the second flow file is expected to arrive over a second session.
TEST_CASE("PutTCP test idle connection expiration", "[PutTCP]") {
  PutTCPTestFixture test_fixture;
  SECTION("No SSL") {
    test_fixture.setPutTCPPort(test_fixture.addTCPServer());
  }
  SECTION("SSL") {
    test_fixture.setPutTCPPort(test_fixture.addSSLServer());
    test_fixture.addSSLContextToPutTCP("ca_A.crt", "alice_by_A.pem", "alice.key");
  }
  test_fixture.setIdleConnectionExpiration("100ms");

  trigger_expect_success(test_fixture, first_message);
  // Stay idle slightly longer than the 100ms expiration before sending again.
  std::this_thread::sleep_for(110ms);
  trigger_expect_success(test_fixture, second_message);

  receive_success(test_fixture, first_message);
  receive_success(test_fixture, second_message);
  CHECK(2 == test_fixture.getNumberOfActiveSessions());
}
// A 3500-character flow file must be delivered to the server intact, over plain TCP and over TLS.
TEST_CASE("PutTCP test long flow file chunked sending", "[PutTCP]") {
  PutTCPTestFixture test_fixture;
  SECTION("No SSL") {
    test_fixture.setPutTCPPort(test_fixture.addTCPServer());
  }
  SECTION("SSL") {
    test_fixture.addSSLContextToPutTCP("ca_A.crt", "alice_by_A.pem", "alice.key");
    test_fixture.setPutTCPPort(test_fixture.addSSLServer());
  }

  const std::string long_message(3500, 'a');
  trigger_expect_success(test_fixture, long_message);
  receive_success(test_fixture, long_message);
}
// The target port comes from the "tcp_port" flow file attribute. Five servers are addressed in
// three rounds, and each server is expected to end up with exactly one session.
TEST_CASE("PutTCP test multiple servers", "[PutTCP]") {
  PutTCPTestFixture test_fixture;
  constexpr size_t server_count = 5;
  constexpr int round_count = 3;
  std::vector<uint16_t> ports;
  SECTION("No SSL") {
    while (ports.size() < server_count)
      ports.push_back(test_fixture.addTCPServer());
  }
  SECTION("SSL") {
    test_fixture.addSSLContextToPutTCP("ca_A.crt", "alice_by_A.pem", "alice.key");
    while (ports.size() < server_count)
      ports.push_back(test_fixture.addSSLServer());
  }
  test_fixture.setPutTCPPort("${tcp_port}");

  for (int round = 0; round < round_count; ++round) {
    for (const uint16_t port : ports) {
      const std::string port_str = std::to_string(port);
      const std::string message = "Test message " + port_str;
      trigger_expect_success(test_fixture, message, {{"tcp_port", port_str}});
      receive_success(test_fixture, message, port);
    }
  }

  for (const uint16_t port : ports) {
    CHECK(1 == test_fixture.getNumberOfActiveSessions(port));
  }
}
} // namespace org::apache::nifi::minifi::processors