blob: c27dd5d4e1f180a6d152b57ffc4dd0de0bdc2629 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include "Catch.h"
#include "processors/GetTCP.h"
#include "SingleProcessorTestController.h"
#include "Utils.h"
#include "utils/net/AsioCoro.h"
#include "utils/net/AsioSocketUtils.h"
#include "controllers/SSLContextService.h"
#include "range/v3/algorithm/contains.hpp"
#include "utils/gsl.h"
using GetTCP = org::apache::nifi::minifi::processors::GetTCP;
using namespace std::literals::chrono_literals;
namespace org::apache::nifi::minifi::test {
// Checks that the flow file's source endpoint attribute is one of the loopback representations of `port`:
// IPv4, IPv4-mapped IPv6, or IPv6.
void check_for_attributes(core::FlowFile& flow_file, uint16_t port) {
  const std::string port_suffix = ":" + std::to_string(port);
  const auto accepted_endpoints = {"127.0.0.1" + port_suffix, "::ffff:127.0.0.1" + port_suffix, "::1" + port_suffix};
  const auto reported_endpoint = flow_file.getAttribute(GetTCP::SourceEndpoint.name);
  CHECK(ranges::contains(accepted_endpoints, reported_endpoint));
}
// Builds the server-side TLS material description used by these tests. It holds the CA certificate and the
// "localhost" certificate and key, all read from the "resources" directory next to the test executable.
minifi::utils::net::SslData createSslDataForServer() {
  const std::filesystem::path resources_dir = std::filesystem::path{minifi::utils::file::FileUtils::get_executable_dir()} / "resources";
  minifi::utils::net::SslData server_ssl_data;
  server_ssl_data.ca_loc = (resources_dir / "ca_A.crt").string();
  server_ssl_data.cert_loc = (resources_dir / "localhost_by_A.pem").string();
  server_ssl_data.key_loc = (resources_dir / "localhost.key").string();
  return server_ssl_data;
}
// Adds a controller service named "SSLContextService" to the test plan and enables it. It is configured with
// the "alice" client certificate and key and the "ca_A" CA certificate from the "resources" directory next to
// the test executable (the file names suggest alice's certificate is issued by CA "A").
// Side effect: also switches GetTCP logging to trace level.
void addSslContextServiceTo(SingleProcessorTestController& controller) {
  auto ssl_context_service = controller.plan->addController("SSLContextService", "SSLContextService");
  LogTestController::getInstance().setTrace<GetTCP>();

  const std::filesystem::path resources_dir = std::filesystem::path{minifi::utils::file::FileUtils::get_executable_dir()} / "resources";
  // Points one file-valued property of the service at a file in the resources directory.
  const auto set_file_property = [&](const auto& property, const char* file_name) {
    REQUIRE(controller.plan->setProperty(ssl_context_service, property, (resources_dir / file_name).string()));
  };
  set_file_property(controllers::SSLContextService::CACertificate, "ca_A.crt");
  set_file_property(controllers::SSLContextService::ClientCertificate, "alice_by_A.pem");
  set_file_property(controllers::SSLContextService::PrivateKey, "alice.key");

  ssl_context_service->enable();
}
// Minimal TCP server, optionally using TLS, that GetTCP connects to in these tests.
//
// Usage: optionally call enableSSL(), queue messages with queueMessage(), then call run(). The server binds an
// IPv6 acceptor to port 0, so the OS picks the port; getPort() returns it once bound. The server runs on a
// background thread. Every accepted connection writes out messages taken from a single shared queue until a
// write on that connection fails or the server is destroyed.
class TcpTestServer {
 public:
  // Spawns the accept loop and runs the io_context on a background thread.
  // getPort() returns 0 until the acceptor has been bound.
  void run() {
    server_thread_ = std::thread([this]() {
      asio::co_spawn(io_context_, listenAndSendMessages(), asio::detached);
      io_context_.run();
    });
  }

  // Adds a message to the queue shared by all sessions. Whichever session dequeues it first sends it.
  void queueMessage(std::string message) {
    messages_to_send_.enqueue(std::move(message));
  }

  // Switches accepted connections to TLS. The server certificate is localhost_by_A.pem with key localhost.key,
  // ca_A.crt is the trust anchor, and peer verification is required.
  // Call before run(): ssl_context_ is read from the server thread without synchronization.
  void enableSSL() {
    const std::filesystem::path executable_dir = minifi::utils::file::FileUtils::get_executable_dir();
    asio::ssl::context ssl_context(asio::ssl::context::tls_server);
    ssl_context.set_options(minifi::utils::net::MINIFI_SSL_OPTIONS);
    ssl_context.set_password_callback([key_pw = "Password12"](std::size_t&, asio::ssl::context_base::password_purpose&) { return key_pw; });
    ssl_context.use_certificate_file((executable_dir / "resources" / "localhost_by_A.pem").string(), asio::ssl::context::pem);
    ssl_context.use_private_key_file((executable_dir / "resources" / "localhost.key").string(), asio::ssl::context::pem);
    ssl_context.load_verify_file((executable_dir / "resources" / "ca_A.crt").string());
    ssl_context.set_verify_mode(asio::ssl::verify_peer);
    ssl_context_ = std::move(ssl_context);
  }

  // Listening port, or 0 while the acceptor has not been bound yet.
  uint16_t getPort() const {
    return port_;
  }

  // Stops the io_context and waits for the server thread to finish.
  ~TcpTestServer() {
    io_context_.stop();
    if (server_thread_.joinable())
      server_thread_.join();
  }

 private:
  // Writes queued messages to `socket`, polling every 10ms while the queue is empty.
  // Returns when a write fails (e.g. the peer went away). The message whose write failed is not re-queued.
  asio::awaitable<void> sendMessages(auto& socket) {
    while (true) {
      std::string message_to_send;
      if (!messages_to_send_.tryDequeue(message_to_send)) {
        co_await minifi::utils::net::async_wait(10ms);
        continue;
      }
      auto [write_error, bytes_written] = co_await asio::async_write(socket, asio::buffer(message_to_send), minifi::utils::net::use_nothrow_awaitable);
      if (write_error) {
        // Stop here. Otherwise this session keeps draining the shared queue into a dead connection, and other
        // sessions never see those messages. Returning also lets secureSession() reach its TLS shutdown.
        co_return;
      }
    }
  }

  // Performs the server-side TLS handshake on an accepted connection, streams messages over it, then shuts
  // the TLS session down.
  asio::awaitable<void> secureSession(asio::ip::tcp::socket socket) {
    gsl_Expects(ssl_context_);
    minifi::utils::net::SslSocket ssl_socket(std::move(socket), *ssl_context_);
    auto [handshake_error] = co_await ssl_socket.async_handshake(minifi::utils::net::HandshakeType::server, minifi::utils::net::use_nothrow_awaitable);
    if (handshake_error) {
      co_return;
    }
    co_await sendMessages(ssl_socket);
    asio::error_code ec;
    ssl_socket.lowest_layer().cancel(ec);
    co_await ssl_socket.async_shutdown(minifi::utils::net::use_nothrow_awaitable);
  }

  // Streams messages over a plain TCP connection.
  asio::awaitable<void> insecureSession(asio::ip::tcp::socket socket) {
    co_await sendMessages(socket);
  }

  // Binds the acceptor and publishes the OS-assigned port through port_ if none was set. It then spawns one
  // session coroutine per accepted connection: TLS if enableSSL() was called, plain TCP otherwise.
  // Returns on the first accept error.
  asio::awaitable<void> listenAndSendMessages() {
    asio::ip::tcp::acceptor acceptor(io_context_, asio::ip::tcp::endpoint(asio::ip::tcp::v6(), port_));
    if (port_ == 0)
      port_ = acceptor.local_endpoint().port();
    while (true) {
      auto [accept_error, socket] = co_await acceptor.async_accept(minifi::utils::net::use_nothrow_awaitable);
      if (accept_error) {
        co_return;
      }
      if (ssl_context_)
        co_spawn(io_context_, secureSession(std::move(socket)), asio::detached);
      else
        co_spawn(io_context_, insecureSession(std::move(socket)), asio::detached);
    }
  }

  std::optional<asio::ssl::context> ssl_context_;
  minifi::utils::ConcurrentQueue<std::string> messages_to_send_;
  std::atomic<uint16_t> port_ = 0;
  std::thread server_thread_;
  // Declared last so it is destroyed first. The coroutines it still owns refer to ssl_context_ and
  // messages_to_send_, which therefore outlive it.
  asio::io_context io_context_;
};
// A single message terminated by "\n" must come out as one Success flow file with the full content.
// MessageDelimiter is not set, so this relies on the processor's default delimiter (presumably "\n").
// The test runs twice via Catch sections: over plain TCP and over TLS.
TEST_CASE("GetTCP test with delimiter", "[GetTCP]") {
  const auto get_tcp = std::make_shared<GetTCP>("GetTCP");
  SingleProcessorTestController controller{get_tcp};
  LogTestController::getInstance().setTrace<GetTCP>();
  REQUIRE(get_tcp->setProperty(GetTCP::MaxBatchSize, "2"));

  TcpTestServer server;
  SECTION("No SSL") {}
  SECTION("SSL") {
    addSslContextServiceTo(controller);
    server.enableSSL();
    REQUIRE(get_tcp->setProperty(GetTCP::SSLContextService, "SSLContextService"));
  }

  server.queueMessage("Hello\n");
  server.run();
  REQUIRE(minifi::utils::verifyEventHappenedInPollTime(250ms, [&] { return server.getPort() != 0; }, 20ms));
  // The port is published once and never changes afterwards.
  const uint16_t port = server.getPort();
  REQUIRE(get_tcp->setProperty(GetTCP::EndpointList, fmt::format("localhost:{}", port)));
  controller.plan->scheduleProcessor(get_tcp);

  ProcessorTriggerResult result;
  REQUIRE(controller.triggerUntil({{GetTCP::Success, 1}}, result, 1s, 50ms));
  const auto& received = result.at(GetTCP::Success)[0];
  CHECK(controller.plan->getContent(received) == "Hello\n");
  check_for_attributes(*received, port);
}
// With MaxMessageSize 10 and "\r" as the delimiter, the 26-letter first message is expected to come out as
// three Partial flow files: two 10-byte chunks and the remainder including the delimiter. The following
// "Bye\r" is expected as a normal Success flow file. Runs over plain TCP and over TLS via Catch sections.
TEST_CASE("GetTCP test with too large message", "[GetTCP]") {
  const auto get_tcp = std::make_shared<GetTCP>("GetTCP");
  SingleProcessorTestController controller{get_tcp};
  LogTestController::getInstance().setTrace<GetTCP>();
  REQUIRE(get_tcp->setProperty(GetTCP::MaxBatchSize, "2"));
  REQUIRE(get_tcp->setProperty(GetTCP::MaxMessageSize, "10"));
  REQUIRE(get_tcp->setProperty(GetTCP::MessageDelimiter, "\r"));

  TcpTestServer server;
  SECTION("No SSL") {}
  SECTION("SSL") {
    addSslContextServiceTo(controller);
    server.enableSSL();
    REQUIRE(get_tcp->setProperty(GetTCP::SSLContextService, "SSLContextService"));
  }

  server.queueMessage("abcdefghijklmnopqrstuvwxyz\rBye\r");
  server.run();
  REQUIRE(minifi::utils::verifyEventHappenedInPollTime(250ms, [&] { return server.getPort() != 0; }, 20ms));
  const uint16_t port = server.getPort();
  REQUIRE(get_tcp->setProperty(GetTCP::EndpointList, fmt::format("localhost:{}", port)));
  controller.plan->scheduleProcessor(get_tcp);

  ProcessorTriggerResult result;
  REQUIRE(controller.triggerUntil({{GetTCP::Success, 1}}, result, 1s, 50ms));
  REQUIRE(result.at(GetTCP::Partial).size() == 3);
  REQUIRE(result.at(GetTCP::Success).size() == 1);

  const auto& partial_flow_files = result.at(GetTCP::Partial);
  const auto& success_flow_file = result.at(GetTCP::Success)[0];
  CHECK(controller.plan->getContent(partial_flow_files[0]) == "abcdefghij");
  CHECK(controller.plan->getContent(partial_flow_files[1]) == "klmnopqrst");
  CHECK(controller.plan->getContent(partial_flow_files[2]) == "uvwxyz\r");
  CHECK(controller.plan->getContent(success_flow_file) == "Bye\r");

  // Every emitted piece, partial or not, must carry the source endpoint attribute.
  for (const auto& partial_flow_file : partial_flow_files) {
    check_for_attributes(*partial_flow_file, port);
  }
  check_for_attributes(*success_flow_file, port);
}
// GetTCP pointed at two servers at once must collect the messages of both: four Success flow files in total,
// in no particular order. Runs over plain TCP and over TLS via Catch sections.
TEST_CASE("GetTCP test multiple endpoints", "[GetTCP]") {
  const auto get_tcp = std::make_shared<GetTCP>("GetTCP");
  SingleProcessorTestController controller{get_tcp};
  LogTestController::getInstance().setTrace<GetTCP>();
  REQUIRE(get_tcp->setProperty(GetTCP::MaxBatchSize, "2"));
  TcpTestServer server_1;
  TcpTestServer server_2;
  SECTION("No SSL") {}
  SECTION("SSL") {
    addSslContextServiceTo(controller);
    server_1.enableSSL();
    server_2.enableSSL();
    REQUIRE(get_tcp->setProperty(GetTCP::SSLContextService, "SSLContextService"));
  }
  server_1.queueMessage("abcdefghijklmnopqrstuvwxyz\nBye\n");
  server_1.run();
  server_2.queueMessage("012345678901234567890\nAuf Wiedersehen\n");
  server_2.run();
  REQUIRE(minifi::utils::verifyEventHappenedInPollTime(250ms, [&] { return server_1.getPort() != 0 && server_2.getPort() != 0; }, 20ms));
  REQUIRE(get_tcp->setProperty(GetTCP::EndpointList, fmt::format("localhost:{},localhost:{}", server_1.getPort(), server_2.getPort())));
  controller.plan->scheduleProcessor(get_tcp);
  ProcessorTriggerResult result;
  // REQUIRE, as in the other tests: if the flow files never arrived, the result must not be inspected below.
  // result.at() outside of an assertion might not even find the Success relationship.
  REQUIRE(controller.triggerUntil({{GetTCP::Success, 4}}, result, 1s, 50ms));
  CHECK(result.at(GetTCP::Success).size() == 4);
  std::vector<std::string> success_flow_file_contents;
  for (const auto& flow_file: result.at(GetTCP::Success)) {
    success_flow_file_contents.push_back(controller.plan->getContent(flow_file));
  }
  // The two connections are read concurrently, so only membership is checked, not order.
  CHECK(ranges::contains(success_flow_file_contents, "abcdefghijklmnopqrstuvwxyz\n"));
  CHECK(ranges::contains(success_flow_file_contents, "Bye\n"));
  CHECK(ranges::contains(success_flow_file_contents, "012345678901234567890\n"));
  CHECK(ranges::contains(success_flow_file_contents, "Auf Wiedersehen\n"));
}
// 100 messages are sent to a GetTCP whose internal queue holds at most 50 (MaxQueueSize). The other 50 must be
// dropped with a "Queue is full" warning. The 50 queued ones must then come out in batches of MaxBatchSize (10).
// Runs over plain TCP and over TLS via Catch sections.
TEST_CASE("GetTCP max queue and max batch size test", "[GetTCP]") {
  const auto get_tcp = std::make_shared<GetTCP>("GetTCP");
  SingleProcessorTestController controller{get_tcp};
  LogTestController::getInstance().setTrace<GetTCP>();
  REQUIRE(get_tcp->setProperty(GetTCP::MaxBatchSize, "10"));
  REQUIRE(get_tcp->setProperty(GetTCP::MaxQueueSize, "50"));

  TcpTestServer server;
  SECTION("No SSL") {}
  SECTION("SSL") {
    addSslContextServiceTo(controller);
    server.enableSSL();
    REQUIRE(get_tcp->setProperty(GetTCP::SSLContextService, "SSLContextService"));
  }
  // Only warnings matter from here on (presumably to keep trace noise for 100 messages out of the log);
  // the "Queue is full" warnings are what is counted below.
  LogTestController::getInstance().setWarn<GetTCP>();

  for (int sent = 0; sent < 100; ++sent) {
    server.queueMessage("some_message\n");
  }
  server.run();
  REQUIRE(minifi::utils::verifyEventHappenedInPollTime(250ms, [&] { return server.getPort() != 0; }, 20ms));
  const uint16_t port = server.getPort();
  REQUIRE(get_tcp->setProperty(GetTCP::EndpointList, fmt::format("localhost:{}", port)));
  controller.plan->scheduleProcessor(get_tcp);

  CHECK(utils::countLogOccurrencesUntil("Queue is full. TCP message ignored.", 50, 300ms, 50ms));
  // Five full batches drain the 50 queued messages; the sixth trigger finds nothing left.
  for (int batch = 0; batch < 5; ++batch) {
    CHECK(controller.trigger().at(GetTCP::Success).size() == 10);
  }
  CHECK(controller.trigger().at(GetTCP::Success).empty());
}
} // namespace org::apache::nifi::minifi::test