| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file socket.h |
| * \brief this file aims to provide a wrapper of sockets |
| * \author Tianqi Chen |
| */ |
| #ifndef TVM_SUPPORT_SOCKET_H_ |
| #define TVM_SUPPORT_SOCKET_H_ |
| |
| #if defined(_WIN32) |
| #define NOMINMAX |
| #include <winsock2.h> |
| #include <ws2tcpip.h> |
| #undef NOMINMAX |
| using ssize_t = int; |
| #ifdef _MSC_VER |
| #pragma comment(lib, "Ws2_32.lib") |
| #endif |
| #else |
| #include <arpa/inet.h> |
| #include <errno.h> |
| #include <fcntl.h> |
| #include <netdb.h> |
| #include <netinet/in.h> |
| #include <sys/ioctl.h> |
| #include <sys/select.h> |
| #include <sys/socket.h> |
| #include <unistd.h> |
| #endif |
| #include <dmlc/logging.h> |
| |
| #include <cstring> |
| #include <string> |
| #include <unordered_map> |
| #include <vector> |
| |
| #include "../support/util.h" |
| |
| #if defined(_WIN32) |
| static inline int poll(struct pollfd* pfd, int nfds, int timeout) { |
| return WSAPoll(pfd, nfds, timeout); |
| } |
| #else |
| #include <sys/poll.h> |
| #endif // defined(_WIN32) |
| |
| namespace tvm { |
| namespace support { |
| /*! |
| * \brief Get current host name. |
| * \return The hostname. |
| */ |
| inline std::string GetHostName() { |
| std::string buf; |
| buf.resize(256); |
| CHECK_NE(gethostname(&buf[0], 256), -1); |
| return std::string(buf.c_str()); |
| } |
| |
/*!
 * \brief ValidateIP validates an ip address.
 * \param ip The ip address in string format: "localhost", an IPv4 address (x.x.x.x)
 *           or an IPv6 address.
 * \return true if ip is "localhost" or parses as an IPv4 / IPv6 address, false otherwise.
 */
inline bool ValidateIP(std::string ip) {
  if (ip == "localhost") {
    return true;
  }
  // inet_pton returns 1 on success, 0 when the text is not a valid address and -1 when
  // the address family is not supported. Only 1 means "valid"; storing the raw result
  // in a bool would turn the -1 error into true.
  struct sockaddr_in sa_ipv4;
  if (inet_pton(AF_INET, ip.c_str(), &(sa_ipv4.sin_addr)) == 1) {
    return true;
  }
  struct sockaddr_in6 sa_ipv6;
  return inet_pton(AF_INET6, ip.c_str(), &(sa_ipv6.sin6_addr)) == 1;
}
| |
| /*! |
| * \brief Common data structure for network address. |
| */ |
| struct SockAddr { |
| sockaddr_storage addr; |
| SockAddr() {} |
| /*! |
| * \brief construct address by url and port |
| * \param url The url of the address |
| * \param port The port of the address. |
| */ |
| SockAddr(const char* url, int port) { this->Set(url, port); } |
| |
| /*! |
| * \brief SockAddr Get the socket address from tracker. |
| * \param tracker The url containing the ip and port number. Format is ('192.169.1.100', 9090) |
| * \return SockAddr parsed from url. |
| */ |
| explicit SockAddr(const std::string& url) { |
| size_t sep = url.find(","); |
| std::string host = url.substr(2, sep - 3); |
| std::string port = url.substr(sep + 1, url.length() - 1); |
| CHECK(ValidateIP(host)) << "Url address is not valid " << url; |
| if (host == "localhost") { |
| host = "127.0.0.1"; |
| } |
| this->Set(host.c_str(), std::stoi(port)); |
| } |
| |
| /*! |
| * \brief set the address |
| * \param host the url of the address |
| * \param port the port of address |
| */ |
| void Set(const char* host, int port) { |
| addrinfo hints; |
| memset(&hints, 0, sizeof(hints)); |
| hints.ai_family = PF_UNSPEC; |
| hints.ai_flags = AI_PASSIVE; |
| hints.ai_socktype = SOCK_STREAM; |
| addrinfo* res = nullptr; |
| int sig = getaddrinfo(host, nullptr, &hints, &res); |
| CHECK(sig == 0 && res != nullptr) << "cannot obtain address of " << host; |
| switch (res->ai_family) { |
| case AF_INET: { |
| sockaddr_in* addr4 = reinterpret_cast<sockaddr_in*>(&addr); |
| memcpy(addr4, res->ai_addr, res->ai_addrlen); |
| addr4->sin_port = htons(port); |
| addr4->sin_family = AF_INET; |
| } break; |
| case AF_INET6: { |
| sockaddr_in6* addr6 = reinterpret_cast<sockaddr_in6*>(&addr); |
| memcpy(addr6, res->ai_addr, res->ai_addrlen); |
| addr6->sin6_port = htons(port); |
| addr6->sin6_family = AF_INET6; |
| } break; |
| default: |
| CHECK(false) << "cannot decode address"; |
| } |
| freeaddrinfo(res); |
| } |
| /*! \brief return port of the address */ |
| int port() const { |
| return ntohs((addr.ss_family == AF_INET6) |
| ? reinterpret_cast<const sockaddr_in6*>(&addr)->sin6_port |
| : reinterpret_cast<const sockaddr_in*>(&addr)->sin_port); |
| } |
| /*! \brief return the ip address family */ |
| int ss_family() const { return addr.ss_family; } |
| /*! \return a string representation of the address */ |
| std::string AsString() const { |
| std::string buf; |
| buf.resize(256); |
| |
| const void* sinx_addr = nullptr; |
| if (addr.ss_family == AF_INET6) { |
| const in6_addr& addr6 = reinterpret_cast<const sockaddr_in6*>(&addr)->sin6_addr; |
| sinx_addr = reinterpret_cast<const void*>(&addr6); |
| } else if (addr.ss_family == AF_INET) { |
| const in_addr& addr4 = reinterpret_cast<const sockaddr_in*>(&addr)->sin_addr; |
| sinx_addr = reinterpret_cast<const void*>(&addr4); |
| } else { |
| CHECK(false) << "illegal address"; |
| } |
| |
| #ifdef _WIN32 |
| const char* s = inet_ntop(addr.ss_family, (PVOID)sinx_addr, // NOLINT(*) |
| &buf[0], buf.length()); |
| #else |
| const char* s = |
| inet_ntop(addr.ss_family, sinx_addr, &buf[0], static_cast<socklen_t>(buf.length())); |
| #endif |
| CHECK(s != nullptr) << "cannot decode address"; |
| std::ostringstream os; |
| os << s << ":" << port(); |
| return os.str(); |
| } |
| }; |
| /*! |
| * \brief base class containing common operations of TCP and UDP sockets |
| */ |
| class Socket { |
| public: |
| #if defined(_WIN32) |
| using sock_size_t = int; |
| using SockType = SOCKET; |
| #else |
| using SockType = int; |
| using sock_size_t = size_t; |
| static constexpr int INVALID_SOCKET = -1; |
| #endif |
| /*! \brief the file descriptor of socket */ |
| SockType sockfd; |
| /*! |
| * \brief set this socket to use non-blocking mode |
| * \param non_block whether set it to be non-block, if it is false |
| * it will set it back to block mode |
| */ |
| void SetNonBlock(bool non_block) { |
| #ifdef _WIN32 |
| u_long mode = non_block ? 1 : 0; |
| if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) { |
| Socket::Error("SetNonBlock"); |
| } |
| #else |
| int flag = fcntl(sockfd, F_GETFL, 0); |
| if (flag == -1) { |
| Socket::Error("SetNonBlock-1"); |
| } |
| if (non_block) { |
| flag |= O_NONBLOCK; |
| } else { |
| flag &= ~O_NONBLOCK; |
| } |
| if (fcntl(sockfd, F_SETFL, flag) == -1) { |
| Socket::Error("SetNonBlock-2"); |
| } |
| #endif |
| } |
| /*! |
| * \brief bind the socket to an address |
| * \param addr The address to be binded |
| */ |
| void Bind(const SockAddr& addr) { |
| if (bind(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr), |
| (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : sizeof(sockaddr_in))) == |
| -1) { |
| Socket::Error("Bind"); |
| } |
| } |
| /*! |
| * \brief try bind the socket to host, from start_port to end_port |
| * \param host host address to bind the socket |
| * \param start_port starting port number to try |
| * \param end_port ending port number to try |
| * \return the port successfully bind to, return -1 if failed to bind any port |
| */ |
| inline int TryBindHost(std::string host, int start_port, int end_port) { |
| for (int port = start_port; port < end_port; ++port) { |
| SockAddr addr(host.c_str(), port); |
| if (bind(sockfd, reinterpret_cast<sockaddr*>(&addr.addr), |
| (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : sizeof(sockaddr_in))) == |
| 0) { |
| return port; |
| } else { |
| LOG(WARNING) << "Bind failed to " << host << ":" << port; |
| } |
| #if defined(_WIN32) |
| if (WSAGetLastError() != WSAEADDRINUSE) { |
| Socket::Error("TryBindHost"); |
| } |
| #else |
| if (errno != EADDRINUSE) { |
| Socket::Error("TryBindHost"); |
| } |
| #endif |
| } |
| return -1; |
| } |
| /*! \brief get last error code if any */ |
| int GetSockError() const { |
| int error = 0; |
| socklen_t len = sizeof(error); |
| if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, reinterpret_cast<char*>(&error), &len) != 0) { |
| Error("GetSockError"); |
| } |
| return error; |
| } |
| /*! \brief check if anything bad happens */ |
| bool BadSocket() const { |
| if (IsClosed()) return true; |
| int err = GetSockError(); |
| if (err == EBADF || err == EINTR) return true; |
| return false; |
| } |
| /*! \brief check if socket is already closed */ |
| bool IsClosed() const { return sockfd == INVALID_SOCKET; } |
| /*! \brief close the socket */ |
| void Close() { |
| if (sockfd != INVALID_SOCKET) { |
| #ifdef _WIN32 |
| closesocket(sockfd); |
| #else |
| close(sockfd); |
| #endif |
| sockfd = INVALID_SOCKET; |
| } else { |
| Error("Socket::Close double close the socket or close without create"); |
| } |
| } |
| /*! |
| * \return last error of socket 2operation |
| */ |
| static int GetLastError() { |
| #ifdef _WIN32 |
| return WSAGetLastError(); |
| #else |
| return errno; |
| #endif |
| } |
| /*! \return whether last error was would block */ |
| static bool LastErrorWouldBlock() { |
| int errsv = GetLastError(); |
| #ifdef _WIN32 |
| return errsv == WSAEWOULDBLOCK; |
| #else |
| return errsv == EAGAIN || errsv == EWOULDBLOCK; |
| #endif |
| } |
| /*! |
| * \brief start up the socket module |
| * call this before using the sockets |
| */ |
| static void Startup() { |
| #ifdef _WIN32 |
| WSADATA wsa_data; |
| if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) { |
| Socket::Error("Startup"); |
| } |
| if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) { |
| WSACleanup(); |
| LOG(FATAL) << "Could not find a usable version of Winsock.dll"; |
| } |
| #endif |
| } |
| /*! |
| * \brief shutdown the socket module after use, all sockets need to be closed |
| */ |
| static void Finalize() { |
| #ifdef _WIN32 |
| WSACleanup(); |
| #endif |
| } |
| /*! |
| * \brief Report an socket error. |
| * \param msg The error message. |
| */ |
| static void Error(const char* msg) { |
| int errsv = GetLastError(); |
| #ifdef _WIN32 |
| LOG(FATAL) << "Socket " << msg << " Error:WSAError-code=" << errsv; |
| #else |
| LOG(FATAL) << "Socket " << msg << " Error:" << strerror(errsv); |
| #endif |
| } |
| |
| protected: |
| explicit Socket(SockType sockfd) : sockfd(sockfd) {} |
| }; |
| |
| /*! |
| * \brief a wrapper of TCP socket that hopefully be cross platform |
| */ |
| class TCPSocket : public Socket { |
| public: |
| TCPSocket() : Socket(INVALID_SOCKET) {} |
| /*! |
| * \brief construct a TCP socket from existing descriptor |
| * \param sockfd The descriptor |
| */ |
| explicit TCPSocket(SockType sockfd) : Socket(sockfd) {} |
| /*! |
| * \brief enable/disable TCP keepalive |
| * \param keepalive whether to set the keep alive option on |
| */ |
| void SetKeepAlive(bool keepalive) { |
| int opt = static_cast<int>(keepalive); |
| if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast<char*>(&opt), sizeof(opt)) < |
| 0) { |
| Socket::Error("SetKeepAlive"); |
| } |
| } |
| /*! |
| * \brief create the socket, call this before using socket |
| * \param af domain |
| */ |
| void Create(int af = PF_INET) { |
| sockfd = socket(af, SOCK_STREAM, 0); |
| if (sockfd == INVALID_SOCKET) { |
| Socket::Error("Create"); |
| } |
| } |
| /*! |
| * \brief perform listen of the socket |
| * \param backlog backlog parameter |
| */ |
| void Listen(int backlog = 16) { listen(sockfd, backlog); } |
| /*! |
| * \brief get a new connection |
| * \return The accepted socket connection. |
| */ |
| TCPSocket Accept() { |
| SockType newfd = accept(sockfd, nullptr, nullptr); |
| if (newfd == INVALID_SOCKET) { |
| Socket::Error("Accept"); |
| } |
| return TCPSocket(newfd); |
| } |
| /*! |
| * \brief get a new connection |
| * \param addr client address from which connection accepted |
| * \return The accepted socket connection. |
| */ |
| TCPSocket Accept(SockAddr* addr) { |
| socklen_t addrlen = sizeof(addr->addr); |
| SockType newfd = accept(sockfd, reinterpret_cast<sockaddr*>(&addr->addr), &addrlen); |
| if (newfd == INVALID_SOCKET) { |
| Socket::Error("Accept"); |
| } |
| return TCPSocket(newfd); |
| } |
| /*! |
| * \brief decide whether the socket is at OOB mark |
| * \return 1 if at mark, 0 if not, -1 if an error occurred |
| */ |
| int AtMark() const { |
| #ifdef _WIN32 |
| unsigned long atmark; // NOLINT(*) |
| if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1; |
| #else |
| int atmark; |
| if (ioctl(sockfd, SIOCATMARK, &atmark) == -1) return -1; |
| #endif |
| return static_cast<int>(atmark); |
| } |
| /*! |
| * \brief connect to an address |
| * \param addr the address to connect to |
| * \return whether connect is successful |
| */ |
| bool Connect(const SockAddr& addr) { |
| return connect( |
| sockfd, reinterpret_cast<const sockaddr*>(&addr.addr), |
| (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : sizeof(sockaddr_in))) == 0; |
| } |
| /*! |
| * \brief send data using the socket |
| * \param buf_ the pointer to the buffer |
| * \param len the size of the buffer |
| * \param flag extra flags |
| * \return size of data actually sent |
| * return -1 if error occurs |
| */ |
| ssize_t Send(const void* buf_, size_t len, int flag = 0) { |
| const char* buf = reinterpret_cast<const char*>(buf_); |
| return send(sockfd, buf, static_cast<sock_size_t>(len), flag); |
| } |
| /*! |
| * \brief receive data using the socket |
| * \param buf_ the pointer to the buffer |
| * \param len the size of the buffer |
| * \param flags extra flags |
| * \return size of data actually received |
| * return -1 if error occurs |
| */ |
| ssize_t Recv(void* buf_, size_t len, int flags = 0) { |
| char* buf = reinterpret_cast<char*>(buf_); |
| return recv(sockfd, buf, static_cast<sock_size_t>(len), flags); |
| } |
| /*! |
| * \brief peform block write that will attempt to send all data out |
| * can still return smaller than request when error occurs |
| * \param buf_ the pointer to the buffer |
| * \param len the size of the buffer |
| * \return size of data actually sent |
| */ |
| size_t SendAll(const void* buf_, size_t len) { |
| const char* buf = reinterpret_cast<const char*>(buf_); |
| size_t ndone = 0; |
| while (ndone < len) { |
| ssize_t ret = send(sockfd, buf, static_cast<ssize_t>(len - ndone), 0); |
| if (ret == -1) { |
| if (LastErrorWouldBlock()) return ndone; |
| Socket::Error("SendAll"); |
| } |
| buf += ret; |
| ndone += ret; |
| } |
| return ndone; |
| } |
| /*! |
| * \brief peform block read that will attempt to read all data |
| * can still return smaller than request when error occurs |
| * \param buf_ the buffer pointer |
| * \param len length of data to recv |
| * \return size of data actually sent |
| */ |
| size_t RecvAll(void* buf_, size_t len) { |
| char* buf = reinterpret_cast<char*>(buf_); |
| size_t ndone = 0; |
| while (ndone < len) { |
| ssize_t ret = recv(sockfd, buf, static_cast<sock_size_t>(len - ndone), MSG_WAITALL); |
| if (ret == -1) { |
| if (LastErrorWouldBlock()) { |
| LOG(FATAL) << "would block"; |
| return ndone; |
| } |
| Socket::Error("RecvAll"); |
| } |
| if (ret == 0) return ndone; |
| buf += ret; |
| ndone += ret; |
| } |
| return ndone; |
| } |
| /*! |
| * \brief Send the data to remote. |
| * \param data The data to be sent. |
| */ |
| void SendBytes(std::string data) { |
| int datalen = data.length(); |
| CHECK_EQ(SendAll(&datalen, sizeof(datalen)), sizeof(datalen)); |
| CHECK_EQ(SendAll(data.c_str(), datalen), datalen); |
| } |
| /*! |
| * \brief Receive the data to remote. |
| * \return The data received. |
| */ |
| std::string RecvBytes() { |
| int datalen = 0; |
| CHECK_EQ(RecvAll(&datalen, sizeof(datalen)), sizeof(datalen)); |
| std::string data; |
| data.resize(datalen); |
| CHECK_EQ(RecvAll(&data[0], datalen), datalen); |
| return data; |
| } |
| }; |
| |
| /*! \brief helper data structure to perform poll */ |
| struct PollHelper { |
| public: |
| /*! |
| * \brief add file descriptor to watch for read |
| * \param fd file descriptor to be watched |
| */ |
| inline void WatchRead(TCPSocket::SockType fd) { |
| auto& pfd = fds[fd]; |
| pfd.fd = fd; |
| pfd.events |= POLLIN; |
| } |
| /*! |
| * \brief add file descriptor to watch for write |
| * \param fd file descriptor to be watched |
| */ |
| inline void WatchWrite(TCPSocket::SockType fd) { |
| auto& pfd = fds[fd]; |
| pfd.fd = fd; |
| pfd.events |= POLLOUT; |
| } |
| /*! |
| * \brief add file descriptor to watch for exception |
| * \param fd file descriptor to be watched |
| */ |
| inline void WatchException(TCPSocket::SockType fd) { |
| auto& pfd = fds[fd]; |
| pfd.fd = fd; |
| pfd.events |= POLLPRI; |
| } |
| /*! |
| * \brief Check if the descriptor is ready for read |
| * \param fd file descriptor to check status |
| */ |
| inline bool CheckRead(TCPSocket::SockType fd) const { |
| const auto& pfd = fds.find(fd); |
| return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0); |
| } |
| /*! |
| * \brief Check if the descriptor is ready for write |
| * \param fd file descriptor to check status |
| */ |
| inline bool CheckWrite(TCPSocket::SockType fd) const { |
| const auto& pfd = fds.find(fd); |
| return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0); |
| } |
| /*! |
| * \brief Check if the descriptor has any exception |
| * \param fd file descriptor to check status |
| */ |
| inline bool CheckExcept(TCPSocket::SockType fd) const { |
| const auto& pfd = fds.find(fd); |
| return pfd != fds.end() && ((pfd->second.events & POLLPRI) != 0); |
| } |
| /*! |
| * \brief wait for exception event on a single descriptor |
| * \param fd the file descriptor to wait the event for |
| * \param timeout the timeout counter, can be negative, which means wait until the event happen |
| * \return 1 if success, 0 if timeout, and -1 if error occurs |
| */ |
| inline static int WaitExcept(TCPSocket::SockType fd, long timeout = -1) { // NOLINT(*) |
| pollfd pfd; |
| pfd.fd = fd; |
| pfd.events = POLLPRI; |
| return poll(&pfd, 1, timeout); |
| } |
| |
| /*! |
| * \brief peform poll on the set defined, read, write, exception |
| * \param timeout specify timeout in milliseconds(ms) if negative, means poll will block |
| * \return |
| */ |
| inline void Poll(long timeout = -1) { // NOLINT(*) |
| std::vector<pollfd> fdset; |
| fdset.reserve(fds.size()); |
| for (auto kv : fds) { |
| fdset.push_back(kv.second); |
| } |
| int ret = poll(fdset.data(), fdset.size(), timeout); |
| if (ret == -1) { |
| Socket::Error("Poll"); |
| } else { |
| for (auto& pfd : fdset) { |
| auto revents = pfd.revents & pfd.events; |
| if (!revents) { |
| fds.erase(pfd.fd); |
| } else { |
| fds[pfd.fd].events = revents; |
| } |
| } |
| } |
| } |
| |
| std::unordered_map<TCPSocket::SockType, pollfd> fds; |
| }; |
| |
| } // namespace support |
| } // namespace tvm |
| #endif // TVM_SUPPORT_SOCKET_H_ |