blob: fafff97b25226e360191bb6338226a8c852ab59d [file] [log] [blame]
/*!
* Copyright (c) 2017 by Contributors
* \file socket.h
* \brief this file aims to provide a wrapper of sockets
* \author Tianqi Chen
*/
#ifndef TVM_COMMON_SOCKET_H_
#define TVM_COMMON_SOCKET_H_
#if defined(_WIN32)
#include <winsock2.h>
#include <ws2tcpip.h>
using ssize_t = int;
#ifdef _MSC_VER
#pragma comment(lib, "Ws2_32.lib")
#endif
#else
#include <fcntl.h>
#include <netdb.h>
#include <errno.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/ioctl.h>
#endif
#include <dmlc/logging.h>
#include <string>
#include <cstring>
namespace tvm {
namespace common {
/*!
 * \brief Query the system for the name of the local host.
 *
 * The name is read with gethostname() into a 256 byte buffer; the call
 * is checked and a failure is reported through CHECK_NE.
 *
 * \return The host name, cut at the first NUL byte of the buffer.
 */
inline std::string GetHostName() {
  std::string name(256, '\0');
  CHECK_NE(gethostname(&name[0], 256), -1);
  // Drop the unused tail of the buffer so only the C-string part remains.
  name.resize(std::strlen(name.c_str()));
  return name;
}
/*!
 * \brief Common data structure for network address.
 *
 * Holds either an IPv4 (sockaddr_in) or an IPv6 (sockaddr_in6) address,
 * including the port, inside a sockaddr_storage.
 */
struct SockAddr {
  /*! \brief raw storage of the address; ss_family tells which layout is used */
  sockaddr_storage addr;
  /*!
   * \brief construct an empty address
   *  The storage is zero-filled so that port(), ss_family() and AsString()
   *  never read indeterminate memory on a default-constructed object.
   */
  SockAddr() {
    memset(&addr, 0, sizeof(addr));
  }
  /*!
   * \brief construct address by url and port
   * \param url The url of the address
   * \param port The port of the address.
   */
  SockAddr(const char *url, int port) {
    memset(&addr, 0, sizeof(addr));
    this->Set(url, port);
  }
  /*!
   * \brief set the address
   *  Resolves host with getaddrinfo() and stores the first result together
   *  with the given port. A failed lookup or a result that is neither
   *  AF_INET nor AF_INET6 is reported through CHECK.
   * \param host the url of the address
   * \param port the port of address
   */
  void Set(const char *host, int port) {
    addrinfo hints;
    memset(&hints, 0, sizeof(hints));
    hints.ai_family = PF_UNSPEC;
    hints.ai_flags = AI_PASSIVE;
    // SOCK_STREAM is a socket type, so it belongs in ai_socktype;
    // ai_protocol expects an IPPROTO_* value and stays 0 (any protocol).
    hints.ai_socktype = SOCK_STREAM;
    addrinfo *res = nullptr;
    int sig = getaddrinfo(host, nullptr, &hints, &res);
    CHECK(sig == 0 && res != nullptr)
        << "cannot obtain address of " << host;
    switch (res->ai_family) {
      case AF_INET: {
        sockaddr_in *addr4 = reinterpret_cast<sockaddr_in *>(&addr);
        memcpy(addr4, res->ai_addr, res->ai_addrlen);
        addr4->sin_port = htons(port);
        addr4->sin_family = AF_INET;
        break;
      }
      case AF_INET6: {
        sockaddr_in6 *addr6 = reinterpret_cast<sockaddr_in6 *>(&addr);
        memcpy(addr6, res->ai_addr, res->ai_addrlen);
        addr6->sin6_port = htons(port);
        addr6->sin6_family = AF_INET6;
        break;
      }
      default: {
        // Release the lookup result before reporting, so that it is not
        // leaked when the failed check leaves this function.
        freeaddrinfo(res);
        res = nullptr;
        CHECK(false) << "cannot decode address";
      }
    }
    if (res != nullptr) {
      freeaddrinfo(res);
    }
  }
  /*! \brief return port of the address */
  int port() const {
    return ntohs((addr.ss_family == AF_INET6) ?
                 reinterpret_cast<const sockaddr_in6 *>(&addr)->sin6_port :
                 reinterpret_cast<const sockaddr_in *>(&addr)->sin_port);
  }
  /*! \brief return the ip address family */
  int ss_family() const {
    return addr.ss_family;
  }
  /*!
   * \return a string representation of the address, formatted as
   *  "<ip>:<port>"; an address that is neither AF_INET nor AF_INET6
   *  is reported through CHECK.
   */
  std::string AsString() const {
    std::string buf; buf.resize(256);
    const void *sinx_addr = nullptr;
    if (addr.ss_family == AF_INET6) {
      const in6_addr& addr6 = reinterpret_cast<const sockaddr_in6 *>(&addr)->sin6_addr;
      sinx_addr = reinterpret_cast<const void *>(&addr6);
    } else if (addr.ss_family == AF_INET) {
      const in_addr& addr4 = reinterpret_cast<const sockaddr_in *>(&addr)->sin_addr;
      sinx_addr = reinterpret_cast<const void *>(&addr4);
    } else {
      CHECK(false) << "illegal address";
    }
    // The two platforms declare the buffer-length parameter with different types.
#ifdef _WIN32
    const char *s = inet_ntop(addr.ss_family, sinx_addr,
                              &buf[0], buf.length());
#else
    const char *s = inet_ntop(addr.ss_family, sinx_addr,
                              &buf[0], static_cast<socklen_t>(buf.length()));
#endif
    CHECK(s != nullptr) << "cannot decode address";
    // std::to_string comes from <string>, which this file includes;
    // this avoids relying on <sstream> being pulled in transitively.
    return std::string(s) + ":" + std::to_string(port());
  }
};
/*!
 * \brief base class containing common operations of TCP and UDP sockets
 *
 * The class stores a native socket handle and declares no destructor,
 * so the handle is only released when Close() is called.
 */
class Socket {
 public:
#if defined(_WIN32)
  using sock_size_t = int;
  using SockType = SOCKET;
#else
  using SockType = int;
  using sock_size_t = size_t;
  static constexpr int INVALID_SOCKET = -1;
#endif
  /*! \brief the file descriptor of socket */
  SockType sockfd;
  /*!
   * \brief set this socket to use non-blocking mode
   * \param non_block whether set it to be non-block, if it is false
   *    it will set it back to block mode
   */
  void SetNonBlock(bool non_block) {
#ifdef _WIN32
    u_long mode = non_block ? 1 : 0;
    if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) {
      Socket::Error("SetNonBlock");
    }
#else
    // Read the current status flags first so only O_NONBLOCK is changed.
    int flag = fcntl(sockfd, F_GETFL, 0);
    if (flag == -1) {
      Socket::Error("SetNonBlock-1");
    }
    if (non_block) {
      flag |= O_NONBLOCK;
    } else {
      flag &= ~O_NONBLOCK;
    }
    if (fcntl(sockfd, F_SETFL, flag) == -1) {
      Socket::Error("SetNonBlock-2");
    }
#endif
  }
  /*!
   * \brief bind the socket to an address
   *  A failure is reported through Socket::Error.
   * \param addr The address to be bound
   */
  void Bind(const SockAddr &addr) {
    if (bind(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
             (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) :
              sizeof(sockaddr_in))) == -1) {
      Socket::Error("Bind");
    }
  }
  /*!
   * \brief try bind the socket to host "0.0.0.0", trying every port in
   *  the half-open range [start_port, end_port)
   *  A port that is already in use is skipped; any other bind failure is
   *  reported through Socket::Error.
   * \param start_port starting port number to try
   * \param end_port ending port number to try (not tried itself)
   * \return the port successfully bind to, return -1 if failed to bind any port
   */
  inline int TryBindHost(int start_port, int end_port) {
    for (int port = start_port; port < end_port; ++port) {
      SockAddr addr("0.0.0.0", port);
      if (bind(sockfd, reinterpret_cast<sockaddr*>(&addr.addr),
               (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) :
                sizeof(sockaddr_in))) == 0) {
        return port;
      }
#if defined(_WIN32)
      if (WSAGetLastError() != WSAEADDRINUSE) {
        Socket::Error("TryBindHost");
      }
#else
      if (errno != EADDRINUSE) {
        Socket::Error("TryBindHost");
      }
#endif
    }
    return -1;
  }
  /*!
   * \brief get last error code if any
   * \return the value of the SO_ERROR socket option
   */
  int GetSockError() const {
    int error = 0;
    socklen_t len = sizeof(error);
    if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, reinterpret_cast<char*>(&error), &len) != 0) {
      Error("GetSockError");
    }
    return error;
  }
  /*!
   * \brief check if anything bad happens
   * \return true if the socket is closed or its pending error is
   *  EBADF or EINTR
   */
  bool BadSocket() const {
    if (IsClosed()) return true;
    int err = GetSockError();
    if (err == EBADF || err == EINTR) return true;
    return false;
  }
  /*! \brief check if socket is already closed */
  bool IsClosed() const {
    return sockfd == INVALID_SOCKET;
  }
  /*!
   * \brief close the socket
   *  Closing a socket that holds no valid handle is reported as an error.
   */
  void Close() {
    if (sockfd != INVALID_SOCKET) {
#ifdef _WIN32
      closesocket(sockfd);
#else
      close(sockfd);
#endif
      sockfd = INVALID_SOCKET;
    } else {
      Error("Socket::Close double close the socket or close without create");
    }
  }
  /*!
   * \return last error of socket operation
   */
  static int GetLastError() {
#ifdef _WIN32
    return WSAGetLastError();
#else
    return errno;
#endif
  }
  /*! \return whether last error was would block */
  static bool LastErrorWouldBlock() {
    int errsv = GetLastError();
#ifdef _WIN32
    return errsv == WSAEWOULDBLOCK;
#else
    return errsv == EAGAIN || errsv == EWOULDBLOCK;
#endif
  }
  /*!
   * \brief start up the socket module
   *  call this before using the sockets
   *  Only does work on Windows, where it initializes Winsock 2.2.
   */
  static void Startup() {
#ifdef _WIN32
    WSADATA wsa_data;
    // WSAStartup returns 0 on success and the error code itself on
    // failure (never -1), so the code is taken from the return value
    // instead of WSAGetLastError().
    int err = WSAStartup(MAKEWORD(2, 2), &wsa_data);
    if (err != 0) {
      LOG(FATAL) << "Socket " << "Startup" << " Error:WSAError-code=" << err;
    }
    if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
      WSACleanup();
      LOG(FATAL) << "Could not find a usable version of Winsock.dll";
    }
#endif
  }
  /*!
   * \brief shutdown the socket module after use, all sockets need to be closed
   */
  static void Finalize() {
#ifdef _WIN32
    WSACleanup();
#endif
  }
  /*!
   * \brief Report a socket error.
   *  Logs the message together with the last socket error at FATAL level.
   * \param msg The error message.
   */
  static void Error(const char *msg) {
    int errsv = GetLastError();
#ifdef _WIN32
    LOG(FATAL) << "Socket " << msg << " Error:WSAError-code=" << errsv;
#else
    LOG(FATAL) << "Socket " << msg << " Error:" << strerror(errsv);
#endif
  }

 protected:
  /*!
   * \brief wrap an existing handle; only reachable from derived classes
   * \param sockfd The descriptor to store.
   */
  explicit Socket(SockType sockfd) : sockfd(sockfd) {
  }
};
/*!
* \brief a wrapper of TCP socket that hopefully be cross platform
*/
class TCPSocket : public Socket {
public:
TCPSocket() : Socket(INVALID_SOCKET) {
}
/*!
* \brief construct a TCP socket from existing descriptor
* \param sockfd The descriptor
*/
explicit TCPSocket(SockType sockfd) : Socket(sockfd) {
}
/*!
* \brief enable/disable TCP keepalive
* \param keepalive whether to set the keep alive option on
*/
void SetKeepAlive(bool keepalive) {
int opt = static_cast<int>(keepalive);
if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE,
reinterpret_cast<char*>(&opt), sizeof(opt)) < 0) {
Socket::Error("SetKeepAlive");
}
}
/*!
* \brief create the socket, call this before using socket
* \param af domain
*/
void Create(int af = PF_INET) {
sockfd = socket(af, SOCK_STREAM, 0);
if (sockfd == INVALID_SOCKET) {
Socket::Error("Create");
}
}
/*!
* \brief perform listen of the socket
* \param backlog backlog parameter
*/
void Listen(int backlog = 16) {
listen(sockfd, backlog);
}
/*!
* \brief get a new connection
* \return The accepted socket connection.
*/
TCPSocket Accept() {
SockType newfd = accept(sockfd, NULL, NULL);
if (newfd == INVALID_SOCKET) {
Socket::Error("Accept");
}
return TCPSocket(newfd);
}
/*!
* \brief decide whether the socket is at OOB mark
* \return 1 if at mark, 0 if not, -1 if an error occured
*/
int AtMark() const {
#ifdef _WIN32
unsigned long atmark; // NOLINT(*)
if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1;
#else
int atmark;
if (ioctl(sockfd, SIOCATMARK, &atmark) == -1) return -1;
#endif
return static_cast<int>(atmark);
}
/*!
* \brief connect to an address
* \param addr the address to connect to
* \return whether connect is successful
*/
bool Connect(const SockAddr &addr) {
return connect(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
(addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) :
sizeof(sockaddr_in))) == 0;
}
/*!
* \brief send data using the socket
* \param buf_ the pointer to the buffer
* \param len the size of the buffer
* \param flag extra flags
* \return size of data actually sent
* return -1 if error occurs
*/
ssize_t Send(const void *buf_, size_t len, int flag = 0) {
const char *buf = reinterpret_cast<const char*>(buf_);
return send(sockfd, buf, static_cast<sock_size_t>(len), flag);
}
/*!
* \brief receive data using the socket
* \param buf_ the pointer to the buffer
* \param len the size of the buffer
* \param flags extra flags
* \return size of data actually received
* return -1 if error occurs
*/
ssize_t Recv(void *buf_, size_t len, int flags = 0) {
char *buf = reinterpret_cast<char*>(buf_);
return recv(sockfd, buf, static_cast<sock_size_t>(len), flags);
}
/*!
* \brief peform block write that will attempt to send all data out
* can still return smaller than request when error occurs
* \param buf_ the pointer to the buffer
* \param len the size of the buffer
* \return size of data actually sent
*/
size_t SendAll(const void *buf_, size_t len) {
const char *buf = reinterpret_cast<const char*>(buf_);
size_t ndone = 0;
while (ndone < len) {
ssize_t ret = send(sockfd, buf, static_cast<ssize_t>(len - ndone), 0);
if (ret == -1) {
if (LastErrorWouldBlock()) return ndone;
Socket::Error("SendAll");
}
buf += ret;
ndone += ret;
}
return ndone;
}
/*!
* \brief peform block read that will attempt to read all data
* can still return smaller than request when error occurs
* \param buf_ the buffer pointer
* \param len length of data to recv
* \return size of data actually sent
*/
size_t RecvAll(void *buf_, size_t len) {
char *buf = reinterpret_cast<char*>(buf_);
size_t ndone = 0;
while (ndone < len) {
ssize_t ret = recv(sockfd, buf,
static_cast<sock_size_t>(len - ndone), MSG_WAITALL);
if (ret == -1) {
if (LastErrorWouldBlock()) {
LOG(FATAL) << "would block";
return ndone;
}
Socket::Error("RecvAll");
}
if (ret == 0) return ndone;
buf += ret;
ndone += ret;
}
return ndone;
}
};
} // namespace common
} // namespace tvm
#endif // TVM_COMMON_SOCKET_H_