MINIFICPP-1118 - MiNiFi C++ on Windows stops running in a secure env when NiFi becomes unreachable
Signed-off-by: Arpad Boda <aboda@apache.org>
Approved by bakaid on GH
This closes #708
diff --git a/libminifi/include/io/tls/TLSSocket.h b/libminifi/include/io/tls/TLSSocket.h
index 7806d5a..74794ea 100644
--- a/libminifi/include/io/tls/TLSSocket.h
+++ b/libminifi/include/io/tls/TLSSocket.h
@@ -33,6 +33,7 @@
namespace minifi {
namespace io {
+#define TLS_GOOD 0
#define TLS_ERROR_CONTEXT 1
#define TLS_ERROR_PEM_MISSING 2
#define TLS_ERROR_CERT_MISSING 3
@@ -171,6 +172,8 @@
*/
int writeData(uint8_t *value, int size);
+ void closeStream(); // override
+
protected:
int writeData(uint8_t *value, int size, int fd);
@@ -191,9 +194,6 @@
SSL* ssl_;
std::mutex ssl_mutex_;
std::map<int, SSL*> ssl_map_;
-
- private:
- std::shared_ptr<logging::Logger> logger_;
};
} /* namespace io */
diff --git a/libminifi/opsys/posix/io/ClientSocket.h b/libminifi/opsys/posix/io/ClientSocket.h
index 739f7f9..cba975d 100644
--- a/libminifi/opsys/posix/io/ClientSocket.h
+++ b/libminifi/opsys/posix/io/ClientSocket.h
@@ -279,13 +279,13 @@
bool nonBlocking_;
- protected:
+ std::shared_ptr<logging::Logger> logger_;
+
void setPort(uint16_t port) {
port_ = port;
}
private:
- std::shared_ptr<logging::Logger> logger_;
static std::string init_hostname() {
char hostname[1024];
gethostname(hostname, 1024);
diff --git a/libminifi/opsys/win/io/ClientSocket.h b/libminifi/opsys/win/io/ClientSocket.h
index d5d279f..4e58c86 100644
--- a/libminifi/opsys/win/io/ClientSocket.h
+++ b/libminifi/opsys/win/io/ClientSocket.h
@@ -305,6 +305,9 @@
uint16_t listeners_;
bool nonBlocking_;
+
+ std::shared_ptr<logging::Logger> logger_;
+
private:
class SocketInitializer
@@ -328,9 +331,6 @@
static void initialize_socket() {
static SocketInitializer initialized;
}
-
-
- std::shared_ptr<logging::Logger> logger_;
static std::string init_hostname() {
diff --git a/libminifi/src/io/tls/TLSSocket.cpp b/libminifi/src/io/tls/TLSSocket.cpp
index 8b577eb..3138fc2 100644
--- a/libminifi/src/io/tls/TLSSocket.cpp
+++ b/libminifi/src/io/tls/TLSSocket.cpp
@@ -24,6 +24,7 @@
#include "io/tls/TLSSocket.h"
#include "io/tls/TLSUtils.h"
#include "properties/Configure.h"
+#include "utils/ScopeGuard.h"
#include "utils/StringUtils.h"
#include "core/Property.h"
#include "core/logging/LoggerConfiguration.h"
@@ -39,8 +40,8 @@
TLSContext::TLSContext(const std::shared_ptr<Configure> &configure, const std::shared_ptr<minifi::controllers::SSLContextService> &ssl_service)
: SocketContext(configure),
- error_value(0),
- ctx(0),
+ error_value(TLS_GOOD),
+ ctx(nullptr),
logger_(logging::LoggerFactory<TLSContext>::getLogger()),
configure_(configure),
ssl_service_(ssl_service) {
@@ -73,6 +74,12 @@
error_value = TLS_ERROR_CONTEXT;
return error_value;
}
+
+ utils::ScopeGuard ctxGuard([this]() {
+ SSL_CTX_free(ctx);
+ ctx = nullptr;
+ });
+
if (needClientCert) {
std::string certificate;
std::string privatekey;
@@ -84,6 +91,8 @@
error_value = TLS_ERROR_CERT_ERROR;
return error_value;
}
+ ctxGuard.disable();
+ error_value = TLS_GOOD;
return 0;
}
@@ -135,16 +144,23 @@
logger_->log_debug("Load/Verify Client Certificate OK. for %X and %X", this, ctx);
}
+ ctxGuard.disable();
+ error_value = TLS_GOOD;
return 0;
}
TLSSocket::~TLSSocket() {
+ closeStream();
+}
+
+void TLSSocket::closeStream() {
if (ssl_ != 0) {
SSL_free(ssl_);
ssl_ = nullptr;
}
- closeStream();
+ Socket::closeStream();
}
+
/**
* Constructor that accepts host name, port and listeners. With this
* contructor we will be creating a server socket
@@ -154,22 +170,21 @@
*/
TLSSocket::TLSSocket(const std::shared_ptr<TLSContext> &context, const std::string &hostname, const uint16_t port, const uint16_t listeners)
: Socket(context, hostname, port, listeners),
- ssl_(0),
- logger_(logging::LoggerFactory<TLSSocket>::getLogger()) {
+ ssl_(0) {
+ logger_ = logging::LoggerFactory<TLSSocket>::getLogger();
context_ = context;
}
TLSSocket::TLSSocket(const std::shared_ptr<TLSContext> &context, const std::string &hostname, const uint16_t port)
: Socket(context, hostname, port, 0),
- ssl_(0),
- logger_(logging::LoggerFactory<TLSSocket>::getLogger()) {
+ ssl_(0) {
+ logger_ = logging::LoggerFactory<TLSSocket>::getLogger();
context_ = context;
}
TLSSocket::TLSSocket(const TLSSocket &&d)
: Socket(std::move(d)),
- ssl_(0),
- logger_(std::move(d.logger_)) {
+ ssl_(0) {
context_ = d.context_;
}
@@ -182,9 +197,19 @@
setNonBlocking();
logger_->log_trace("Initializing TLSSocket %d", is_server);
int16_t ret = context_->initialize(is_server);
- Socket::initialize();
- if (!ret && listeners_ == 0) {
+ if (ret != 0) {
+ logger_->log_warn("Failed to initialize SSL context!");
+ return -1;
+ }
+
+ ret = Socket::initialize();
+ if (ret != 0) {
+ logger_->log_warn("Failed to initialise basic socket for TLS socket");
+ return -1;
+ }
+
+ if (listeners_ == 0) {
// we have s2s secure config
ssl_ = SSL_new(context_->getContext());
SSL_set_fd(ssl_, socket_file_descriptor_);
@@ -200,17 +225,10 @@
logger_->log_trace("want read");
return 0;
} else {
+ logger_->log_error("SSL socket connect failed to %s %d", requested_hostname_, port_);
+ closeStream();
return -1;
}
- logger_->log_error("SSL socket connect failed to %s %d", requested_hostname_, port_);
- SSL_free(ssl_);
- ssl_ = NULL;
-#ifdef WIN32
- closesocket(socket_file_descriptor_);
-#else
- close(socket_file_descriptor_);
-#endif
- return -1;
} else {
connected_ = true;
logger_->log_debug("SSL socket connect success to %s %d, on fd %d", requested_hostname_, port_, socket_file_descriptor_);
@@ -229,11 +247,7 @@
if (nullptr != fd_ssl) {
SSL_free(fd_ssl);
ssl_map_[fd] = nullptr;
-#ifdef WIN32
- closesocket(fd);
-#else
- close(fd);
-#endif
+ closeStream();
}
}
}
@@ -295,17 +309,10 @@
logger_->log_trace("want read");
return socket_file_descriptor_;
} else {
+ logger_->log_error("SSL socket connect failed to %s %d", requested_hostname_, port_);
+ closeStream();
return -1;
}
- logger_->log_error("SSL socket connect failed to %s %d", requested_hostname_, port_);
- SSL_free(ssl_);
- ssl_ = NULL;
-#ifdef WIN32
- closesocket(socket_file_descriptor_);
-#else
- close(socket_file_descriptor_);
-#endif
- return -1;
} else {
connected_ = true;
logger_->log_debug("SSL socket connect success to %s %d, on fd %d", requested_hostname_, port_, socket_file_descriptor_);
@@ -452,11 +459,7 @@
while (buflen) {
int16_t fd = select_descriptor(1000);
if (fd <= 0) {
-#ifdef WIN32
- closesocket(socket_file_descriptor_);
-#else
- close(socket_file_descriptor_);
-#endif
+ closeStream();
return -1;
}