| /** |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "TestUtils.h" |
| |
| #include <type_traits> |
| |
| #ifdef WIN32 |
| #include <windows.h> |
| #include <aclapi.h> |
| #endif |
| |
| #include "minifi-cpp/utils/gsl.h" |
| |
| #ifdef WIN32 |
| namespace { |
| |
| void setAclOnFileOrDirectory(std::string file_name, DWORD perms, ACCESS_MODE perm_options) { |
| PSECURITY_DESCRIPTOR security_descriptor = nullptr; |
| const auto security_descriptor_deleter = gsl::finally([&security_descriptor] { if (security_descriptor) { LocalFree((HLOCAL) security_descriptor); } }); |
| |
| PACL old_acl = nullptr; // GetNamedSecurityInfo will set this to a non-owning pointer to a field inside security_descriptor: no need to free it |
| if (GetNamedSecurityInfo(file_name.c_str(), SE_FILE_OBJECT, DACL_SECURITY_INFORMATION, NULL, NULL, &old_acl, NULL, &security_descriptor) != ERROR_SUCCESS) { |
| throw std::runtime_error("Could not get security info for file: " + file_name); |
| } |
| |
| char trustee_name[] = "Everyone"; |
| EXPLICIT_ACCESS explicit_access = { |
| .grfAccessPermissions = perms, |
| .grfAccessMode = perm_options, |
| .grfInheritance = CONTAINER_INHERIT_ACE | OBJECT_INHERIT_ACE, |
| .Trustee = { .TrusteeForm = TRUSTEE_IS_NAME, .ptstrName = trustee_name } |
| }; |
| |
| PACL new_acl = nullptr; |
| const auto new_acl_deleter = gsl::finally([&new_acl] { if (new_acl) { LocalFree((HLOCAL) new_acl); } }); |
| |
| if (SetEntriesInAcl(1, &explicit_access, old_acl, &new_acl) != ERROR_SUCCESS) { |
| throw std::runtime_error("Could not create new ACL for file: " + file_name); |
| } |
| |
| if (SetNamedSecurityInfo(file_name.data(), SE_FILE_OBJECT, DACL_SECURITY_INFORMATION, NULL, NULL, new_acl, NULL) != ERROR_SUCCESS) { |
| throw std::runtime_error("Could not set the new ACL for file: " + file_name); |
| } |
| } |
| |
| } // namespace |
| #endif |
| |
| namespace org::apache::nifi::minifi::test::utils { |
| |
// Writes `content` in binary mode to the file `file_name` inside `dir_path`
// and returns the full path of the written file.
std::filesystem::path putFileToDir(const std::filesystem::path& dir_path, const std::filesystem::path& file_name, const std::string& content) {
  std::filesystem::path destination = dir_path / file_name;
  std::ofstream output_stream{destination, std::ios::out | std::ios::binary};
  assert(output_stream.is_open());
  output_stream << content;
  return destination;
}
| |
// Reads the whole file `file_name` in binary mode and returns its bytes as a string.
std::string getFileContent(const std::filesystem::path& file_name) {
  std::ifstream input_stream{file_name, std::ios::in | std::ios::binary};
  assert(input_stream.is_open());
  using CharIterator = std::istreambuf_iterator<char>;
  // A default-constructed stream buffer iterator is the end-of-stream marker.
  return std::string(CharIterator{input_stream}, CharIterator{});
}
| |
// Revokes write access to `file_name`: on Windows by adding a deny-write ACL entry,
// elsewhere by clearing the owner-write permission bit.
void makeFileOrDirectoryNotWritable(const std::filesystem::path& file_name) {
#ifdef WIN32
  setAclOnFileOrDirectory(file_name.string(), FILE_GENERIC_WRITE, DENY_ACCESS);
#else
  namespace fs = std::filesystem;
  fs::permissions(file_name, fs::perms::owner_write, fs::perm_options::remove);
#endif
}
| |
// Grants write access to `file_name`: on Windows by adding a grant-write ACL entry,
// elsewhere by setting the owner-write permission bit.
void makeFileOrDirectoryWritable(const std::filesystem::path& file_name) {
#ifdef WIN32
  setAclOnFileOrDirectory(file_name.string(), FILE_GENERIC_WRITE, GRANT_ACCESS);
#else
  namespace fs = std::filesystem;
  fs::permissions(file_name, fs::perms::owner_write, fs::perm_options::add);
#endif
}
| |
| void ManualClock::advance(std::chrono::milliseconds elapsed_time) { |
| if (elapsed_time.count() < 0) { |
| throw std::logic_error("A steady clock can only be advanced forward"); |
| } |
| std::lock_guard lock(mtx_); |
| time_ += elapsed_time; |
| for (auto* cv : cvs_) { |
| cv->notify_all(); |
| } |
| } |
| |
| bool ManualClock::wait_until(std::condition_variable& cv, std::unique_lock<std::mutex>& lck, std::chrono::milliseconds time, const std::function<bool()>& pred) { |
| std::chrono::milliseconds now; |
| { |
| std::unique_lock lock(mtx_); |
| now = time_; |
| cvs_.insert(&cv); |
| } |
| cv.wait_for(lck, time - now, [&] { |
| now = timeSinceEpoch(); |
| return now >= time || pred(); |
| }); |
| { |
| std::unique_lock lock(mtx_); |
| cvs_.erase(&cv); |
| } |
| return pred(); |
| } |
| |
| void matchJSON(const internal::JsonContext& ctx, const rapidjson::Value& actual, const rapidjson::Value& expected, bool strict) { |
| if (expected.IsObject()) { |
| REQUIRE_WARN(actual.IsObject(), fmt::format("Expected object at {}", ctx.path())); |
| for (const auto& expected_member : expected.GetObject()) { |
| std::string_view name{expected_member.name.GetString(), expected_member.name.GetStringLength()}; |
| REQUIRE_WARN(actual.HasMember(expected_member.name), fmt::format("Expected member '{}' at {}", name, ctx.path())); |
| matchJSON(internal::JsonContext{.parent = &ctx, .member = name}, actual[expected_member.name], expected_member.value, strict); |
| } |
| if (strict) { |
| for (const auto& actual_member : actual.GetObject()) { |
| std::string_view name{actual_member.name.GetString(), actual_member.name.GetStringLength()}; |
| REQUIRE_WARN(expected.HasMember(actual_member.name), fmt::format("Did not expect member '{}' at {}", name, ctx.path())); |
| } |
| } |
| } else if (expected.IsArray()) { |
| REQUIRE_WARN(actual.IsArray(), fmt::format("Expected array at {}", ctx.path())); |
| REQUIRE_WARN(actual.Size() == expected.Size(), fmt::format("Expected array of length {}, got {} at {}", expected.Size(), actual.Size(), ctx.path())); |
| for (rapidjson::SizeType idx{0}; idx < expected.Size(); ++idx) { |
| matchJSON(internal::JsonContext{.parent = &ctx, .member = std::to_string(idx)}, actual[idx], expected[idx], strict); |
| } |
| } else { |
| REQUIRE_WARN(actual == expected, fmt::format("Values are not equal at {}", ctx.path())); |
| } |
| } |
| |
// Parses both strings as JSON, requires that neither has a parse error, and then
// compares the resulting documents with matchJSON starting from an empty context.
// `strict` is forwarded unchanged to matchJSON.
void verifyJSON(const std::string& actual_str, const std::string& expected_str, bool strict) {
  rapidjson::Document actual;
  rapidjson::Document expected;
  // Parsing goes through c_str(), so each input is read as a null-terminated string.
  REQUIRE_FALSE(actual.Parse(actual_str.c_str()).HasParseError());
  REQUIRE_FALSE(expected.Parse(expected_str.c_str()).HasParseError());

  matchJSON(internal::JsonContext{}, actual, expected, strict);
}
| |
| bool countLogOccurrencesUntil(const std::string& pattern, |
| const size_t occurrences, |
| const std::chrono::milliseconds max_duration, |
| const std::chrono::milliseconds wait_time) { |
| auto start_time = std::chrono::steady_clock::now(); |
| while (std::chrono::steady_clock::now() < start_time + max_duration) { |
| if (LogTestController::getInstance().countOccurrences(pattern) == occurrences) |
| return true; |
| std::this_thread::sleep_for(wait_time); |
| } |
| return false; |
| } |
| |
| std::error_code sendMessagesViaTCP(const std::vector<std::string_view>& contents, const asio::ip::tcp::endpoint& remote_endpoint, const std::optional<std::string_view> delimiter) { |
| asio::io_context io_context; |
| asio::ip::tcp::socket socket(io_context); |
| std::error_code err; |
| std::ignore = socket.connect(remote_endpoint, err); |
| if (err) |
| return err; |
| for (auto& content : contents) { |
| std::string tcp_message(content); |
| if (delimiter) |
| tcp_message += *delimiter; |
| asio::write(socket, asio::buffer(tcp_message, tcp_message.size()), err); |
| if (err) |
| return err; |
| } |
| return {}; |
| } |
| |
| std::error_code sendUdpDatagram(const asio::const_buffer content, const asio::ip::udp::endpoint& remote_endpoint) { |
| asio::io_context io_context; |
| asio::ip::udp::socket socket(io_context); |
| std::error_code err; |
| std::ignore = socket.open(remote_endpoint.protocol(), err); |
| if (err) |
| return err; |
| socket.send_to(content, remote_endpoint, 0, err); |
| return err; |
| } |
| |
| std::error_code sendUdpDatagram(const std::span<std::byte const> content, const asio::ip::udp::endpoint& remote_endpoint) { |
| return sendUdpDatagram(asio::const_buffer(content.data(), content.size()), remote_endpoint); |
| } |
| |
| std::error_code sendUdpDatagram(const std::string_view content, const asio::ip::udp::endpoint& remote_endpoint) { |
| return sendUdpDatagram(asio::buffer(content), remote_endpoint); |
| } |
| |
| bool isIPv6Disabled() { |
| asio::io_context io_context; |
| std::error_code error_code; |
| asio::ip::tcp::socket socket_tcp(io_context); |
| std::ignore = socket_tcp.connect(asio::ip::tcp::endpoint(asio::ip::address_v6::loopback(), 10), error_code); |
| return error_code.value() == EADDRNOTAVAIL; |
| } |
| |
| |
| std::error_code sendMessagesViaSSL(const std::vector<std::string_view>& contents, |
| const asio::ip::tcp::endpoint& remote_endpoint, |
| const std::filesystem::path& ca_cert_path, |
| const std::optional<minifi::utils::net::SslData>& ssl_data, |
| asio::ssl::context::method method) { |
| asio::ssl::context ctx(method); |
| ctx.load_verify_file(ca_cert_path.string()); |
| if (ssl_data) { |
| ctx.set_verify_mode(asio::ssl::verify_peer); |
| ctx.use_certificate_file(ssl_data->cert_loc.string(), asio::ssl::context::pem); |
| ctx.use_private_key_file(ssl_data->key_loc.string(), asio::ssl::context::pem); |
| ctx.set_password_callback([password = ssl_data->key_pw](std::size_t&, asio::ssl::context_base::password_purpose&) { return password; }); |
| } |
| asio::io_context io_context; |
| asio::ssl::stream<asio::ip::tcp::socket> socket(io_context, ctx); |
| auto shutdown_socket = gsl::finally([&] { |
| asio::error_code ec; |
| std::ignore = socket.lowest_layer().cancel(ec); |
| std::ignore = socket.shutdown(ec); |
| }); |
| asio::error_code err; |
| std::ignore = socket.lowest_layer().connect(remote_endpoint, err); |
| if (err) { |
| return err; |
| } |
| std::ignore = socket.handshake(asio::ssl::stream_base::client, err); |
| if (err) { |
| return err; |
| } |
| for (auto& content : contents) { |
| std::string tcp_message(content); |
| tcp_message += '\n'; |
| asio::write(socket, asio::buffer(tcp_message, tcp_message.size()), err); |
| if (err) { |
| return err; |
| } |
| } |
| return {}; |
| } |
| |
| } // namespace org::apache::nifi::minifi::test::utils |