blob: b6cecf59a4b83367a61a7f2bb79218901c7da5db [file] [log] [blame]
/**
*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <openssl/ssl.h>
#include <openssl/err.h>
#include <filesystem>
#include <string>
#include "io/tls/TLSSocket.h"
#ifdef WIN32
#include <winsock2.h>
#include <ws2tcpip.h>
#pragma comment(lib, "Ws2_32.lib")
using SocketDescriptor = SOCKET;
#else
using SocketDescriptor = int;
static constexpr SocketDescriptor INVALID_SOCKET = -1;
#endif /* WIN32 */
namespace minifi = org::apache::nifi::minifi;
class SimpleSSLTestServer {
struct SocketInitializer {
SocketInitializer() {
#ifdef WIN32
static WSADATA s_wsaData;
const int iWinSockInitResult = WSAStartup(MAKEWORD(2, 2), &s_wsaData);
if (0 != iWinSockInitResult) {
throw std::runtime_error("Cannot initialize socket");
}
#endif
}
};
public:
SimpleSSLTestServer(const SSL_METHOD* method, int port, const std::filesystem::path& key_dir)
: port_(port), had_connection_(false) {
static SocketInitializer socket_initializer{};
minifi::io::OpenSSLInitializer::getInstance();
ctx_ = SSL_CTX_new(method);
configureContext(key_dir);
socket_descriptor_ = createSocket(port_);
}
~SimpleSSLTestServer() {
SSL_shutdown(ssl_);
SSL_free(ssl_);
SSL_CTX_free(ctx_);
}
void waitForConnection() {
server_read_thread_ = std::thread([this]() -> void {
SocketDescriptor client = accept(socket_descriptor_, nullptr, nullptr);
if (client != INVALID_SOCKET) {
ssl_ = SSL_new(ctx_);
SSL_set_fd(ssl_, client);
had_connection_ = (SSL_accept(ssl_) == 1);
}
});
}
void shutdownServer() {
#ifdef WIN32
shutdown(socket_descriptor_, SD_BOTH);
closesocket(socket_descriptor_);
#else
shutdown(socket_descriptor_, SHUT_RDWR);
close(socket_descriptor_);
#endif
server_read_thread_.join();
}
bool hadConnection() const {
return had_connection_;
}
int getPort() const {
struct sockaddr_in addr;
socklen_t addr_len = sizeof(addr);
assert(getsockname(socket_descriptor_, (struct sockaddr*)&addr, &addr_len) == 0);
return ntohs(addr.sin_port);
}
private:
SSL_CTX *ctx_ = nullptr;
SSL* ssl_ = nullptr;
int port_;
SocketDescriptor socket_descriptor_;
bool had_connection_;
std::thread server_read_thread_;
void configureContext(const std::filesystem::path& key_dir) {
SSL_CTX_set_ecdh_auto(ctx_, 1);
/* Set the key and cert */
assert(SSL_CTX_use_certificate_file(ctx_, (key_dir / "cn.crt.pem").string().c_str(), SSL_FILETYPE_PEM) == 1);
assert(SSL_CTX_use_PrivateKey_file(ctx_, (key_dir / "cn.ckey.pem").string().c_str(), SSL_FILETYPE_PEM) == 1);
}
static SocketDescriptor createSocket(int port) {
struct sockaddr_in addr;
addr.sin_family = AF_INET;
addr.sin_port = htons(port);
addr.sin_addr.s_addr = htonl(INADDR_ANY);
SocketDescriptor socket_descriptor = socket(AF_INET, SOCK_STREAM, 0);
assert(socket_descriptor >= 0);
assert(bind(socket_descriptor, (struct sockaddr*)&addr, sizeof(addr)) >= 0);
assert(listen(socket_descriptor, 1) >= 0);
return socket_descriptor;
}
};