blob: d65d68dd68af3eda806a1f74f09bd66b1086eb8f [file] [log] [blame]
/**
*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "io/ClientSocket.h"
#include <netinet/tcp.h>
#include <sys/types.h>
#include <netdb.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <sys/ioctl.h>
#include <net/if.h>
#include <ifaddrs.h>
#include <unistd.h>
#include <cstdio>
#include <memory>
#include <utility>
#include <vector>
#include <cerrno>
#include <iostream>
#include <string>
#include "Exception.h"
#include "io/validation.h"
#include "core/logging/LoggerConfiguration.h"
namespace org {
namespace apache {
namespace nifi {
namespace minifi {
namespace io {
/**
 * Records the target host, port and listener count without opening anything.
 *
 * The descriptor starts at -1, the byte counters at zero and both fd_sets are
 * cleared; no network call is made here.
 *
 * @param hostname  host name stored as the requested host.
 * @param port      port number stored for later use.
 * @param listeners listener count stored in listeners_.
 */
Socket::Socket(const std::shared_ptr<SocketContext>& /*context*/, const std::string &hostname, const uint16_t port, const uint16_t listeners = -1)
    : requested_hostname_(hostname),
      port_(port),
      addr_info_(nullptr),
      socket_file_descriptor_(-1),
      socket_max_(0),
      total_written_(0),
      total_read_(0),
      is_loopback_only_(false),
      listeners_(listeners),
      canonical_hostname_(),
      nonBlocking_(false),
      logger_(logging::LoggerFactory<Socket>::getLogger()) {
  // Start with empty descriptor sets; descriptors are added once a connection exists.
  FD_ZERO(&read_fds_);
  FD_ZERO(&total_list_);
}
/**
 * Convenience constructor: delegates to the four-argument constructor with a
 * listener count of 0.
 */
Socket::Socket(const std::shared_ptr<SocketContext>& context, const std::string &hostname, const uint16_t port)
: Socket(context, hostname, port, 0) {
}
/**
 * Builds a socket from the state of `other`.
 *
 * Host name, port, address list, descriptor, fd_sets, listener count, canonical
 * host name, logger and the byte counters are taken from `other`;
 * is_loopback_only_ and nonBlocking_ are reset to false instead of being copied.
 *
 * NOTE(review): `other` is a *const* rvalue reference, so std::move() on its
 * members yields const rvalues and class-type members are copied rather than
 * moved. `other` cannot be reset here either, which means both objects keep the
 * same addr_info_ pointer and socket_file_descriptor_. Since the destructor runs
 * closeStream(), this looks like it allows a double freeaddrinfo()/close() --
 * confirm, and consider a non-const `Socket&&` in the header.
 */
Socket::Socket(const Socket &&other)
: requested_hostname_(std::move(other.requested_hostname_)),
port_(std::move(other.port_)),
is_loopback_only_(false),
addr_info_(std::move(other.addr_info_)),
socket_file_descriptor_(other.socket_file_descriptor_),
socket_max_(other.socket_max_.load()),
listeners_(other.listeners_),
total_list_(other.total_list_),
read_fds_(other.read_fds_),
canonical_hostname_(std::move(other.canonical_hostname_)),
nonBlocking_(false),
logger_(std::move(other.logger_)) {
// The counters are assigned in the body from explicit load() snapshots.
total_written_ = other.total_written_.load();
total_read_ = other.total_read_.load();
}
/**
 * Releases whatever the socket still holds by delegating to closeStream().
 */
Socket::~Socket() {
closeStream();
}
/**
 * Tears the socket down: frees the resolved address list, closes the
 * descriptor, and hands the accumulated byte counters to the local network
 * interface before zeroing them. Each step is skipped when there is nothing to
 * release, so calling this repeatedly is harmless.
 */
void Socket::closeStream() {
  // Resolver results obtained through getaddrinfo().
  if (addr_info_ != nullptr) {
    freeaddrinfo(addr_info_);
    addr_info_ = nullptr;
  }

  // The descriptor itself; -1 marks "not open".
  if (!(socket_file_descriptor_ < 0)) {
    logging::LOG_DEBUG(logger_) << "Closing " << socket_file_descriptor_;
    close(socket_file_descriptor_);
    socket_file_descriptor_ = -1;
  }

  // Report traffic totals once, then reset them.
  if (total_written_ > 0) {
    local_network_interface_.log_write(total_written_);
    total_written_ = 0;
  }

  if (total_read_ > 0) {
    local_network_interface_.log_read(total_read_);
    total_read_ = 0;
  }
}
/**
 * Requests non-blocking mode for a socket that has no listeners.
 * Sockets configured with listeners are left unchanged.
 */
void Socket::setNonBlocking() {
  if (listeners_ > 0) {
    return;
  }
  nonBlocking_ = true;
}
int8_t Socket::createConnection(const addrinfo *p, in_addr_t &addr) {
if ((socket_file_descriptor_ = socket(p->ai_family, p->ai_socktype, p->ai_protocol)) == -1) {
logger_->log_error("error while connecting to server socket");
return -1;
}
setSocketOptions(socket_file_descriptor_);
if (listeners_ <= 0 && !local_network_interface_.getInterface().empty()) {
// bind to local network interface
ifaddrs* list = NULL;
ifaddrs* item = NULL;
ifaddrs* itemFound = NULL;
int result = getifaddrs(&list);
if (result == 0) {
item = list;
while (item) {
if ((item->ifa_addr != NULL) && (item->ifa_name != NULL) && (AF_INET == item->ifa_addr->sa_family)) {
if (strcmp(item->ifa_name, local_network_interface_.getInterface().c_str()) == 0) {
itemFound = item;
break;
}
}
item = item->ifa_next;
}
if (itemFound != NULL) {
result = bind(socket_file_descriptor_, itemFound->ifa_addr, sizeof(struct sockaddr_in));
if (result < 0)
logger_->log_info("Bind to interface %s failed %s", local_network_interface_.getInterface(), strerror(errno));
else
logger_->log_info("Bind to interface %s", local_network_interface_.getInterface());
}
freeifaddrs(list);
}
}
if (listeners_ > 0) {
struct sockaddr_in *sa_loc = (struct sockaddr_in*) p->ai_addr;
sa_loc->sin_family = AF_INET;
sa_loc->sin_port = htons(port_);
if (is_loopback_only_) {
sa_loc->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
} else {
sa_loc->sin_addr.s_addr = htonl(INADDR_ANY);
}
if (bind(socket_file_descriptor_, p->ai_addr, p->ai_addrlen) == -1) {
logger_->log_error("Could not bind to socket, reason %s", strerror(errno));
return -1;
}
}
{
if (listeners_ <= 0) {
struct sockaddr_in *sa_loc = (struct sockaddr_in*) p->ai_addr;
sa_loc->sin_family = AF_INET;
sa_loc->sin_port = htons(port_);
// use any address if you are connecting to the local machine for testing
// otherwise we must use the requested hostname
if (IsNullOrEmpty(requested_hostname_) || requested_hostname_ == "localhost") {
if (is_loopback_only_) {
sa_loc->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
} else {
sa_loc->sin_addr.s_addr = htonl(INADDR_ANY);
}
} else {
sa_loc->sin_addr.s_addr = addr;
}
if (connect(socket_file_descriptor_, p->ai_addr, p->ai_addrlen) == -1) {
close(socket_file_descriptor_);
socket_file_descriptor_ = -1;
return -1;
}
}
}
// listen
if (listeners_ > 0) {
if (listen(socket_file_descriptor_, listeners_) == -1) {
return -1;
} else {
logger_->log_debug("Created connection with %d listeners", listeners_);
}
}
// add the listener to the total set
FD_SET(socket_file_descriptor_, &total_list_);
socket_max_ = socket_file_descriptor_;
logger_->log_debug("Created connection with file descriptor %d", socket_file_descriptor_);
return 0;
}
/**
 * Resolves the requested host and opens the socket.
 *
 * The host is resolved with getaddrinfo() (stream sockets, canonical name
 * requested, passive when listeners are configured) and additionally with
 * gethostbyname()/gethostbyname_r() to obtain the address handed to
 * createConnection(). Each getaddrinfo candidate is tried in order until
 * createConnection() succeeds; the first canonical name seen is remembered.
 * If non-blocking mode was requested, O_NONBLOCK is added to the descriptor.
 *
 * @return 0 when a connection was created, -1 otherwise (resolution failure,
 *         port_ of 0, or no candidate could be connected/bound).
 */
int16_t Socket::initialize() {
  addrinfo hints;
  memset(&hints, 0, sizeof hints);  // make sure the struct is empty
  hints.ai_family = AF_UNSPEC;
  hints.ai_socktype = SOCK_STREAM;
  hints.ai_flags = AI_CANONNAME;
  if (listeners_ > 0)
    hints.ai_flags |= AI_PASSIVE;
  hints.ai_protocol = 0; /* any protocol */

  const int errcode = getaddrinfo(requested_hostname_.c_str(), nullptr, &hints, &addr_info_);
  if (errcode != 0) {
    // getaddrinfo reports failures through its return code, not errno.
    logger_->log_error("Saw error during getaddrinfo, error: %s", gai_strerror(errcode));
    return -1;
  }

  socket_file_descriptor_ = -1;

  struct hostent *h = nullptr;
#ifdef __MACH__
  h = gethostbyname(requested_hostname_.c_str());
#else
  char buf[1024];
  struct hostent he;
  int hh_errno = 0;
  if (gethostbyname_r(requested_hostname_.c_str(), &he, buf, sizeof(buf), &h, &hh_errno) != 0) {
    h = nullptr;
  }
#endif
  if (h == nullptr) {
    logger_->log_error("hostname not defined for %s", requested_hostname_);
    return -1;
  }

  // Only copy an address that fits into in_addr_t; a larger h_length would
  // write past the end of addr.
  in_addr_t addr = 0;
  if (h->h_addr_list[0] == nullptr || h->h_length < 0 || static_cast<size_t>(h->h_length) > sizeof(addr)) {
    logger_->log_error("hostname not defined for %s", requested_hostname_);
    return -1;
  }
  memcpy(reinterpret_cast<char*>(&addr), h->h_addr_list[0], h->h_length);

  for (auto p = addr_info_; p != nullptr; p = p->ai_next) {
    if (IsNullOrEmpty(canonical_hostname_) && !IsNullOrEmpty(p->ai_canonname)) {
      canonical_hostname_ = p->ai_canonname;
    }
    if (port_ > 0 && createConnection(p, addr) >= 0) {
      // we've successfully connected
      if (nonBlocking_) {
        // Add O_NONBLOCK to the current status flags instead of replacing them.
        const int flags = fcntl(socket_file_descriptor_, F_GETFL, 0);
        if (flags < 0 || fcntl(socket_file_descriptor_, F_SETFL, flags | O_NONBLOCK) < 0) {
          logger_->log_error("Could not create non blocking to socket, error: %s", strerror(errno));
        } else {
          logger_->log_debug("Successfully applied O_NONBLOCK to fd");
        }
      }
      logger_->log_debug("Successfully created connection");
      return 0;
    }
  }

  logger_->log_debug("Could not find device for our connection");
  return -1;
}
int16_t Socket::select_descriptor(const uint16_t msec) {
if (listeners_ == 0) {
return socket_file_descriptor_;
}
struct timeval tv;
read_fds_ = total_list_;
tv.tv_sec = msec / 1000;
tv.tv_usec = (msec % 1000) * 1000;
std::lock_guard<std::recursive_mutex> guard(selection_mutex_);
if (msec > 0)
select(socket_max_ + 1, &read_fds_, NULL, NULL, &tv);
else
select(socket_max_ + 1, &read_fds_, NULL, NULL, NULL);
for (int i = 0; i <= socket_max_; i++) {
if (FD_ISSET(i, &read_fds_)) {
if (i == socket_file_descriptor_) {
if (listeners_ > 0) {
struct sockaddr_storage remoteaddr; // client address
socklen_t addrlen = sizeof remoteaddr;
int newfd = accept(socket_file_descriptor_, (struct sockaddr *) &remoteaddr, &addrlen);
FD_SET(newfd, &total_list_); // add to master set
if (newfd > socket_max_) { // keep track of the max
socket_max_ = newfd;
}
return newfd;
} else {
return socket_file_descriptor_;
}
// we have a new connection
} else {
// data to be received on i
return i;
}
}
}
logger_->log_debug("Could not find a suitable file descriptor or select timed out");
return -1;
}
/**
 * Applies the socket options this class relies on.
 *
 * On __MACH__ builds only SO_REUSEADDR is set, and only when listeners are
 * configured. Elsewhere TCP_NODELAY, SO_REUSEADDR and a 256 KiB send buffer
 * are applied in that order.
 *
 * @param sock descriptor to configure; it is closed here if an option fails.
 * @return 0 on success, -1 as soon as one option could not be set.
 */
int16_t Socket::setSocketOptions(const int sock) {
  int enabled = 1;
#ifdef __MACH__
  if (listeners_ > 0) {
    // lose the pesky "address already in use" error message
    if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enabled, sizeof(enabled)) < 0) {
      logger_->log_error("setsockopt() SO_REUSEADDR failed");
      close(sock);
      return -1;
    }
  }
#else
  // Send small writes immediately.
  if (setsockopt(sock, SOL_TCP, TCP_NODELAY, &enabled, sizeof(enabled)) < 0) {
    logger_->log_error("setsockopt() TCP_NODELAY failed");
    close(sock);
    return -1;
  }

  // Allow reuse of the local address.
  if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enabled, sizeof(enabled)) < 0) {
    logger_->log_error("setsockopt() SO_REUSEADDR failed");
    close(sock);
    return -1;
  }

  // Request a 256 KiB send buffer.
  int send_buffer_bytes = 256 * 1024;
  if (setsockopt(sock, SOL_SOCKET, SO_SNDBUF, &send_buffer_bytes, sizeof(send_buffer_bytes)) < 0) {
    logger_->log_error("setsockopt() SO_SNDBUF failed");
    close(sock);
    return -1;
  }
#endif
  return 0;
}
/**
 * Returns a copy of the canonical host name. The constructor leaves it empty;
 * initialize() fills it from the first resolved address that carries one.
 */
std::string Socket::getHostname() const {
return canonical_hostname_;
}
/**
 * Sends the first buflen bytes of buf.
 *
 * @param buf    source bytes.
 * @param buflen number of bytes to send.
 * @return -1 when buf holds fewer than buflen bytes, otherwise the result of
 *         the raw-pointer writeData overload.
 * @throws minifi::Exception when buflen is negative.
 */
int Socket::writeData(std::vector<uint8_t> &buf, int buflen) {
  if (buflen < 0) {
    throw minifi::Exception{ExceptionType::GENERAL_EXCEPTION, "negative buflen"};
  }

  const auto requested = static_cast<size_t>(buflen);
  if (requested > buf.size()) {
    return -1;
  }

  return writeData(buf.data(), buflen);
}
// data stream overrides
int Socket::writeData(uint8_t *value, int size) {
int ret = 0, bytes = 0;
int fd = select_descriptor(1000);
while (bytes < size) {
ret = send(fd, value + bytes, size - bytes, 0);
// check for errors
if (ret <= 0) {
close(fd);
logger_->log_error("Could not send to %d, error: %s", fd, strerror(errno));
return ret;
}
bytes += ret;
}
if (ret)
logger_->log_trace("Send data size %d over socket %d", size, fd);
total_written_ += bytes;
return bytes;
}
/**
 * Reads sizeof(T) bytes from the socket into a freshly allocated,
 * zero-filled buffer and returns it. The result of readData() is not
 * inspected, so bytes that could not be read remain zero.
 *
 * @param t used only for its size.
 */
template<typename T>
inline std::vector<uint8_t> Socket::readBuffer(const T& t) {
  std::vector<uint8_t> bytes(sizeof t);
  readData(bytes.data(), sizeof(t));
  return bytes;
}
/**
 * Writes a 64-bit value by forwarding it, this socket and the byte-order flag
 * to Serializable::write; returns whatever that call returns.
 */
int Socket::write(uint64_t base_value, bool is_little_endian) {
return Serializable::write(base_value, this, is_little_endian);
}
/**
 * Writes a 32-bit value by forwarding it, this socket and the byte-order flag
 * to Serializable::write; returns whatever that call returns.
 */
int Socket::write(uint32_t base_value, bool is_little_endian) {
return Serializable::write(base_value, this, is_little_endian);
}
/**
 * Writes a 16-bit value by forwarding it, this socket and the byte-order flag
 * to Serializable::write; returns whatever that call returns.
 */
int Socket::write(uint16_t base_value, bool is_little_endian) {
return Serializable::write(base_value, this, is_little_endian);
}
int Socket::read(uint64_t &value, bool is_little_endian) {
auto buf = readBuffer(value);
if (is_little_endian) {
value = ((uint64_t) buf[0] << 56) | ((uint64_t) (buf[1] & 255) << 48) | ((uint64_t) (buf[2] & 255) << 40) | ((uint64_t) (buf[3] & 255) << 32) | ((uint64_t) (buf[4] & 255) << 24)
| ((uint64_t) (buf[5] & 255) << 16) | ((uint64_t) (buf[6] & 255) << 8) | ((uint64_t) (buf[7] & 255) << 0);
} else {
value = ((uint64_t) buf[0] << 0) | ((uint64_t) (buf[1] & 255) << 8) | ((uint64_t) (buf[2] & 255) << 16) | ((uint64_t) (buf[3] & 255) << 24) | ((uint64_t) (buf[4] & 255) << 32)
| ((uint64_t) (buf[5] & 255) << 40) | ((uint64_t) (buf[6] & 255) << 48) | ((uint64_t) (buf[7] & 255) << 56);
}
return sizeof(value);
}
int Socket::read(uint32_t &value, bool is_little_endian) {
auto buf = readBuffer(value);
if (is_little_endian) {
value = (buf[0] << 24) | (buf[1] << 16) | (buf[2] << 8) | buf[3];
} else {
value = buf[0] | buf[1] << 8 | buf[2] << 16 | buf[3] << 24;
}
return sizeof(value);
}
int Socket::read(uint16_t &value, bool is_little_endian) {
auto buf = readBuffer(value);
if (is_little_endian) {
value = (buf[0] << 8) | buf[1];
} else {
value = buf[0] | buf[1] << 8;
}
return sizeof(value);
}
/**
 * Reads up to buflen bytes into buf, growing buf first when it is too small.
 *
 * @param buf                destination; never shrunk.
 * @param buflen             number of bytes requested.
 * @param retrieve_all_bytes forwarded unchanged to the raw-pointer overload.
 * @return the result of the raw-pointer readData overload.
 * @throws minifi::Exception when buflen is negative.
 */
int Socket::readData(std::vector<uint8_t> &buf, int buflen, bool retrieve_all_bytes) {
  if (buflen < 0) {
    throw minifi::Exception{ExceptionType::GENERAL_EXCEPTION, "negative buflen"};
  }

  const auto wanted = static_cast<size_t>(buflen);
  if (wanted > buf.size()) {
    buf.resize(wanted);
  }

  return readData(buf.data(), buflen, retrieve_all_bytes);
}
int Socket::readData(uint8_t *buf, int buflen, bool retrieve_all_bytes) {
int32_t total_read = 0;
while (buflen) {
int16_t fd = select_descriptor(1000);
if (fd < 0) {
if (listeners_ <= 0) {
logger_->log_debug("fd %d close %i", fd, buflen);
close(socket_file_descriptor_);
}
return -1;
}
int bytes_read = recv(fd, buf, buflen, 0);
logger_->log_trace("Recv call %d", bytes_read);
if (bytes_read <= 0) {
if (bytes_read == 0) {
logger_->log_debug("Other side hung up on %d", fd);
} else {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
// continue
return -2;
}
logger_->log_error("Could not recv on %d ( port %d), error: %s", fd, port_, strerror(errno));
}
return -1;
}
buflen -= bytes_read;
buf += bytes_read;
total_read += bytes_read;
if (!retrieve_all_bytes) {
break;
}
}
total_read_ += total_read;
return total_read;
}
} /* namespace io */
} /* namespace minifi */
} /* namespace nifi */
} /* namespace apache */
} /* namespace org */