blob: b1cda1ac4fa1a4c69c54edc875b47e10a50a6e91 [file] [log] [blame]
/**
*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "io/ClientSocket.h"
#ifndef WIN32
#include <netinet/tcp.h>
#include <sys/types.h>
#include <netinet/in.h>
#include <ifaddrs.h>
#include <unistd.h>
#else
#include <WS2tcpip.h>
#pragma comment(lib, "Ws2_32.lib")
#endif /* !WIN32 */
#ifdef WIN32
#include <winsock2.h>
#else
#include <arpa/inet.h>
#endif
#include <memory>
#include <utility>
#include <vector>
#include <cerrno>
#include <string>
#include "Exception.h"
#include <system_error>
#include <cinttypes>
#include <Exception.h>
#include <utils/Deleters.h>
#include "io/validation.h"
#include "core/logging/LoggerConfiguration.h"
#include "utils/file/FileUtils.h"
#include "utils/GeneralUtils.h"
namespace util = org::apache::nifi::minifi::utils;
namespace mio = org::apache::nifi::minifi::io;
namespace {
std::string get_last_getaddrinfo_err_str(int getaddrinfo_result) {
#ifdef WIN32
(void)getaddrinfo_result; // against unused warnings on windows
return mio::get_last_socket_error_message();
#else
return gai_strerror(getaddrinfo_result);
#endif /* WIN32 */
}
std::string sockaddr_ntop(const sockaddr* const sa) {
std::string result;
if (sa->sa_family == AF_INET) {
sockaddr_in sa_in{};
std::memcpy(reinterpret_cast<void*>(&sa_in), sa, sizeof(sockaddr_in));
result.resize(INET_ADDRSTRLEN);
if (inet_ntop(AF_INET, &sa_in.sin_addr, &result[0], INET_ADDRSTRLEN) == nullptr) {
throw minifi::Exception{ minifi::ExceptionType::GENERAL_EXCEPTION, mio::get_last_socket_error_message() };
}
} else if (sa->sa_family == AF_INET6) {
sockaddr_in6 sa_in6{};
std::memcpy(reinterpret_cast<void*>(&sa_in6), sa, sizeof(sockaddr_in6));
result.resize(INET6_ADDRSTRLEN);
if (inet_ntop(AF_INET6, &sa_in6.sin6_addr, &result[0], INET6_ADDRSTRLEN) == nullptr) {
throw minifi::Exception{ minifi::ExceptionType::GENERAL_EXCEPTION, mio::get_last_socket_error_message() };
}
} else {
throw minifi::Exception{ minifi::ExceptionType::GENERAL_EXCEPTION, "sockaddr_ntop: unknown address family" };
}
result.resize(strlen(result.c_str())); // discard remaining null bytes at the end
return result;
}
template<typename T, typename Pred, typename Adv>
auto find_if_custom_linked_list(T* const list, const Adv advance_func, const Pred predicate) ->
typename std::enable_if<std::is_convertible<decltype(advance_func(std::declval<T*>())), T*>::value && std::is_convertible<decltype(predicate(std::declval<T*>())), bool>::value, T*>::type
{
for (T* it = list; it; it = advance_func(it)) {
if (predicate(it)) return it;
}
return nullptr;
}
#ifndef WIN32
std::error_code bind_to_local_network_interface(const minifi::io::SocketDescriptor fd, const minifi::io::NetworkInterface& interface) {
using ifaddrs_uniq_ptr = std::unique_ptr<ifaddrs, util::ifaddrs_deleter>;
const auto if_list_ptr = []() -> ifaddrs_uniq_ptr {
ifaddrs *list = nullptr;
const auto get_ifa_success = getifaddrs(&list) == 0;
assert(get_ifa_success || !list);
(void)get_ifa_success; // unused in release builds
return ifaddrs_uniq_ptr{ list };
}();
if (!if_list_ptr) { return { errno, std::generic_category() }; }
const auto advance_func = [](const ifaddrs *const p) { return p->ifa_next; };
const auto predicate = [&interface](const ifaddrs *const item) {
return item->ifa_addr && item->ifa_name && (item->ifa_addr->sa_family == AF_INET || item->ifa_addr->sa_family == AF_INET6)
&& item->ifa_name == interface.getInterface();
};
const auto *const itemFound = find_if_custom_linked_list(if_list_ptr.get(), advance_func, predicate);
if (itemFound == nullptr) { return std::make_error_code(std::errc::no_such_device_or_address); }
const socklen_t addrlen = itemFound->ifa_addr->sa_family == AF_INET ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6);
if (bind(fd, itemFound->ifa_addr, addrlen) != 0) { return { errno, std::generic_category() }; }
return {};
}
#endif /* !WIN32 */
std::error_code set_non_blocking(const minifi::io::SocketDescriptor fd) noexcept {
#ifndef WIN32
if (fcntl(fd, F_SETFL, O_NONBLOCK) < 0) {
return { errno, std::generic_category() };
}
#else
u_long iMode = 1;
if (ioctlsocket(fd, FIONBIO, &iMode) == SOCKET_ERROR) {
return { WSAGetLastError(), std::system_category() };
}
#endif /* !WIN32 */
return {};
}
} // namespace
namespace org {
namespace apache {
namespace nifi {
namespace minifi {
namespace io {
std::string get_last_socket_error_message() {
#ifdef WIN32
const auto error_code = WSAGetLastError();
#else
const auto error_code = errno;
#endif /* WIN32 */
return std::system_category().message(error_code);
}
bool valid_socket(const SocketDescriptor fd) noexcept {
#ifdef WIN32
return fd != INVALID_SOCKET && fd >= 0;
#else
return fd >= 0;
#endif /* WIN32 */
}
Socket::Socket(const std::shared_ptr<SocketContext>& /*context*/, std::string hostname, const uint16_t port, const uint16_t listeners)
: requested_hostname_(std::move(hostname)),
port_(port),
listeners_(listeners),
logger_(logging::LoggerFactory<Socket>::getLogger()) {
FD_ZERO(&total_list_);
FD_ZERO(&read_fds_);
initialize_socket();
}
Socket::Socket(const std::shared_ptr<SocketContext>& context, std::string hostname, const uint16_t port)
: Socket(context, std::move(hostname), port, 0) {
}
// total_list_ and read_fds_ have to use parentheses for initialization due to CWG 1467
// http://www.open-std.org/jtc1/sc22/wg21/docs/cwg_defects.html#1467
// Language defect fix was applied to GCC 5 and Clang 4, but at the time of writing this comment, we support GCC 4.8
Socket::Socket(Socket &&other) noexcept
: requested_hostname_{ std::move(other.requested_hostname_) },
canonical_hostname_{ std::move(other.canonical_hostname_) },
port_{ other.port_ },
is_loopback_only_{ other.is_loopback_only_ },
local_network_interface_{ std::move(other.local_network_interface_) },
socket_file_descriptor_{ other.socket_file_descriptor_ },
total_list_(other.total_list_),
read_fds_(other.read_fds_),
socket_max_{ other.socket_max_.load() },
total_written_{ other.total_written_.load() },
total_read_{ other.total_read_.load() },
listeners_{ other.listeners_ },
nonBlocking_{ other.nonBlocking_ },
logger_{ other.logger_ }
{
other = Socket{ {}, {}, {} };
}
Socket& Socket::operator=(Socket &&other) noexcept {
if (&other == this) return *this;
requested_hostname_ = util::exchange(other.requested_hostname_, "");
canonical_hostname_ = util::exchange(other.canonical_hostname_, "");
port_ = util::exchange(other.port_, 0);
is_loopback_only_ = util::exchange(other.is_loopback_only_, false);
local_network_interface_ = util::exchange(other.local_network_interface_, {});
socket_file_descriptor_ = util::exchange(other.socket_file_descriptor_, INVALID_SOCKET);
total_list_ = other.total_list_;
FD_ZERO(&other.total_list_);
read_fds_ = other.read_fds_;
FD_ZERO(&other.read_fds_);
socket_max_.exchange(other.socket_max_);
other.socket_max_.exchange(0);
total_written_.exchange(other.total_written_);
other.total_written_.exchange(0);
total_read_.exchange(other.total_read_);
other.total_read_.exchange(0);
listeners_ = util::exchange(other.listeners_, 0);
nonBlocking_ = util::exchange(other.nonBlocking_, false);
logger_ = other.logger_;
return *this;
}
Socket::~Socket() {
close();
}
void Socket::close() {
if (valid_socket(socket_file_descriptor_)) {
logging::LOG_DEBUG(logger_) << "Closing " << socket_file_descriptor_;
#ifdef WIN32
closesocket(socket_file_descriptor_);
#else
::close(socket_file_descriptor_);
#endif
socket_file_descriptor_ = INVALID_SOCKET;
}
if (total_written_ > 0) {
local_network_interface_.log_write(gsl::narrow<uint32_t>(total_written_.load()));
total_written_ = 0;
}
if (total_read_ > 0) {
local_network_interface_.log_read(gsl::narrow<uint32_t>(total_read_.load()));
total_read_ = 0;
}
}
void Socket::setNonBlocking() {
if (listeners_ <= 0) {
nonBlocking_ = true;
}
}
int8_t Socket::createConnection(const addrinfo* const destination_addresses) {
for (const auto *current_addr = destination_addresses; current_addr; current_addr = current_addr->ai_next) {
if (!valid_socket(socket_file_descriptor_ = socket(current_addr->ai_family, current_addr->ai_socktype, current_addr->ai_protocol))) {
logger_->log_warn("socket: %s", get_last_socket_error_message());
continue;
}
setSocketOptions(socket_file_descriptor_);
if (listeners_ > 0) {
// server socket
const auto bind_result = bind(socket_file_descriptor_, current_addr->ai_addr, current_addr->ai_addrlen);
if (bind_result == SOCKET_ERROR) {
logger_->log_warn("bind: %s", get_last_socket_error_message());
close();
continue;
}
const auto listen_result = listen(socket_file_descriptor_, listeners_);
if (listen_result == SOCKET_ERROR) {
logger_->log_warn("listen: %s", get_last_socket_error_message());
close();
continue;
}
logger_->log_info("Listening on %s:%" PRIu16 " with backlog %" PRIu16, sockaddr_ntop(current_addr->ai_addr), port_, listeners_);
} else {
// client socket
#ifndef WIN32
if (!local_network_interface_.getInterface().empty()) {
const auto err = bind_to_local_network_interface(socket_file_descriptor_, local_network_interface_);
if (err) logger_->log_info("Bind to interface %s failed %s", local_network_interface_.getInterface(), err.message());
else logger_->log_info("Bind to interface %s", local_network_interface_.getInterface());
}
#endif /* !WIN32 */
const auto connect_result = connect(socket_file_descriptor_, current_addr->ai_addr, current_addr->ai_addrlen);
if (connect_result == SOCKET_ERROR) {
logger_->log_warn("Couldn't connect to %s:%" PRIu16 ": %s", sockaddr_ntop(current_addr->ai_addr), port_, get_last_socket_error_message());
close();
continue;
}
logger_->log_info("Connected to %s:%" PRIu16, sockaddr_ntop(current_addr->ai_addr), port_);
}
FD_SET(socket_file_descriptor_, &total_list_);
socket_max_ = socket_file_descriptor_;
return 0;
}
return -1;
}
int8_t Socket::createConnection(const addrinfo *, ip4addr &addr) {
if (!valid_socket(socket_file_descriptor_ = socket(AF_INET, SOCK_STREAM, 0))) {
logger_->log_error("error while connecting to server socket");
return -1;
}
setSocketOptions(socket_file_descriptor_);
if (listeners_ > 0) {
// server socket
sockaddr_in sa{};
memset(&sa, 0, sizeof(struct sockaddr_in));
sa.sin_family = AF_INET;
sa.sin_port = htons(port_);
sa.sin_addr.s_addr = htonl(is_loopback_only_ ? INADDR_LOOPBACK : INADDR_ANY);
if (bind(socket_file_descriptor_, reinterpret_cast<const sockaddr*>(&sa), sizeof(struct sockaddr_in)) == SOCKET_ERROR) {
logger_->log_error("Could not bind to socket, reason %s", get_last_socket_error_message());
return -1;
}
if (listen(socket_file_descriptor_, listeners_) == -1) {
return -1;
}
logger_->log_debug("Created connection with %d listeners", listeners_);
} else {
// client socket
#ifndef WIN32
if (!local_network_interface_.getInterface().empty()) {
const auto err = bind_to_local_network_interface(socket_file_descriptor_, local_network_interface_);
if (err) logger_->log_info("Bind to interface %s failed %s", local_network_interface_.getInterface(), err.message());
else logger_->log_info("Bind to interface %s", local_network_interface_.getInterface());
}
#endif /* !WIN32 */
sockaddr_in sa_loc{};
memset(&sa_loc, 0x00, sizeof(sa_loc));
sa_loc.sin_family = AF_INET;
sa_loc.sin_port = htons(port_);
// use any address if you are connecting to the local machine for testing
// otherwise we must use the requested hostname
if (IsNullOrEmpty(requested_hostname_) || requested_hostname_ == "localhost") {
sa_loc.sin_addr.s_addr = htonl(is_loopback_only_ ? INADDR_LOOPBACK : INADDR_ANY);
} else {
#ifdef WIN32
sa_loc.sin_addr.s_addr = addr.s_addr;
}
if (connect(socket_file_descriptor_, reinterpret_cast<const sockaddr*>(&sa_loc), sizeof(sockaddr_in)) == SOCKET_ERROR) {
int err = WSAGetLastError();
if (err == WSAEADDRNOTAVAIL) {
logger_->log_error("invalid or unknown IP");
} else if (err == WSAECONNREFUSED) {
logger_->log_error("Connection refused");
} else {
logger_->log_error("Unknown error");
}
#else
sa_loc.sin_addr.s_addr = addr;
}
if (connect(socket_file_descriptor_, reinterpret_cast<const sockaddr *>(&sa_loc), sizeof(sockaddr_in)) < 0) {
#endif /* WIN32 */
close();
return -1;
}
}
// add the listener to the total set
FD_SET(socket_file_descriptor_, &total_list_);
socket_max_ = socket_file_descriptor_;
logger_->log_debug("Created connection with file descriptor %d", socket_file_descriptor_);
return 0;
}
int Socket::initialize() {
addrinfo hints{};
memset(&hints, 0, sizeof hints); // make sure the struct is empty
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
hints.ai_flags = AI_CANONNAME;
if (listeners_ > 0 && !is_loopback_only_)
hints.ai_flags = AI_PASSIVE;
hints.ai_protocol = 0; /* any protocol */
const char* const gai_node = [this]() -> const char* {
if (is_loopback_only_) return "localhost";
if (!is_loopback_only_ && listeners_ > 0) return nullptr; // all non-localhost server sockets listen on wildcard address
if (!requested_hostname_.empty()) return requested_hostname_.c_str();
return nullptr;
}();
const auto gai_service = std::to_string(port_);
addrinfo* getaddrinfo_result = nullptr;
const int errcode = getaddrinfo(gai_node, gai_service.c_str(), &hints, &getaddrinfo_result);
const std::unique_ptr<addrinfo, util::addrinfo_deleter> addr_info{ getaddrinfo_result };
getaddrinfo_result = nullptr;
if (errcode != 0) {
logger_->log_error("getaddrinfo: %s", get_last_getaddrinfo_err_str(errcode));
return -1;
}
socket_file_descriptor_ = INVALID_SOCKET;
// AI_CANONNAME always sets ai_canonname of the first addrinfo structure
canonical_hostname_ = !IsNullOrEmpty(addr_info->ai_canonname) ? addr_info->ai_canonname : requested_hostname_;
const auto conn_result = port_ > 0 ? createConnection(addr_info.get()) : -1;
if (conn_result == 0 && nonBlocking_) {
// Put the socket in non-blocking mode:
const auto err = set_non_blocking(socket_file_descriptor_);
if (err) logger_->log_info("Couldn't make socket non-blocking: %s", err.message());
else logger_->log_debug("Successfully applied O_NONBLOCK to fd");
}
return conn_result;
}
int16_t Socket::select_descriptor(const uint16_t msec) {
if (listeners_ == 0) {
return socket_file_descriptor_;
}
struct timeval tv{};
read_fds_ = total_list_;
tv.tv_sec = msec / 1000;
tv.tv_usec = (msec % 1000) * 1000;
std::lock_guard<std::recursive_mutex> guard(selection_mutex_);
if (msec > 0)
select(socket_max_ + 1, &read_fds_, nullptr, nullptr, &tv);
else
select(socket_max_ + 1, &read_fds_, nullptr, nullptr, nullptr);
for (int i = 0; i <= socket_max_; i++) {
if (FD_ISSET(i, &read_fds_)) {
if (i == socket_file_descriptor_) {
if (listeners_ > 0) {
struct sockaddr_storage remoteaddr; // client address
socklen_t addrlen = sizeof remoteaddr;
int newfd = accept(socket_file_descriptor_, (struct sockaddr *) &remoteaddr, &addrlen);
FD_SET(newfd, &total_list_); // add to master set
if (newfd > socket_max_) { // keep track of the max
socket_max_ = newfd;
}
return newfd;
} else {
return socket_file_descriptor_;
}
// we have a new connection
} else {
// data to be received on i
return i;
}
}
}
logger_->log_debug("Could not find a suitable file descriptor or select timed out");
return -1;
}
int16_t Socket::setSocketOptions(const SocketDescriptor sock) {
int opt = 1;
#ifndef WIN32
#ifndef __MACH__
if (setsockopt(sock, SOL_TCP, TCP_NODELAY, static_cast<void*>(&opt), sizeof(opt)) < 0) {
logger_->log_error("setsockopt() TCP_NODELAY failed");
::close(sock);
return -1;
}
if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<char*>(&opt), sizeof(opt)) < 0) {
logger_->log_error("setsockopt() SO_REUSEADDR failed");
::close(sock);
return -1;
}
int sndsize = 256 * 1024;
if (setsockopt(sock, SOL_SOCKET, SO_SNDBUF, reinterpret_cast<char *>(&sndsize), sizeof(sndsize)) < 0) {
logger_->log_error("setsockopt() SO_SNDBUF failed");
::close(sock);
return -1;
}
#else
if (listeners_ > 0) {
// lose the pesky "address already in use" error message
if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<char *>(&opt), sizeof(opt)) < 0) {
logger_->log_error("setsockopt() SO_REUSEADDR failed");
::close(sock);
return -1;
}
}
#endif /* !__MACH__ */
#endif /* !WIN32 */
return 0;
}
std::string Socket::getHostname() const {
return canonical_hostname_;
}
// data stream overrides
int Socket::write(const uint8_t *value, int size) {
gsl_Expects(size >= 0);
int ret = 0, bytes = 0;
int fd = select_descriptor(1000);
if (fd < 0) { return -1; }
while (bytes < size) {
ret = send(fd, reinterpret_cast<const char*>(value) + bytes, size - bytes, 0);
// check for errors
if (ret <= 0) {
utils::file::FileUtils::close(fd);
logger_->log_error("Could not send to %d, error: %s", fd, get_last_socket_error_message());
return ret;
}
bytes += ret;
}
if (ret)
logger_->log_trace("Send data size %d over socket %d", size, fd);
total_written_ += bytes;
return bytes;
}
int Socket::read(uint8_t *buf, int buflen, bool retrieve_all_bytes) {
gsl_Expects(buflen >= 0);
int32_t total_read = 0;
while (buflen) {
int16_t fd = select_descriptor(1000);
if (fd < 0) {
if (listeners_ <= 0) {
logger_->log_debug("fd %d close %i", fd, buflen);
utils::file::FileUtils::close(socket_file_descriptor_);
}
return -1;
}
int bytes_read = recv(fd, reinterpret_cast<char*>(buf), buflen, 0);
logger_->log_trace("Recv call %d", bytes_read);
if (bytes_read <= 0) {
if (bytes_read == 0) {
logger_->log_debug("Other side hung up on %d", fd);
} else {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
// continue
return -2;
}
logger_->log_error("Could not recv on %d ( port %d), error: %s", fd, port_, strerror(errno));
}
return -1;
}
buflen -= bytes_read;
buf += bytes_read;
total_read += bytes_read;
if (!retrieve_all_bytes) {
break;
}
}
total_read_ += total_read;
return total_read;
}
} /* namespace io */
} /* namespace minifi */
} /* namespace nifi */
} /* namespace apache */
} /* namespace org */