blob: 09a69134d138a1ccaa71bdb0d1bd9265c03b0e63 [file] [log] [blame]
/************************************************************
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*************************************************************/
#include "singa/comm/socket.h"
#include <glog/logging.h>
namespace singa {
#ifdef USE_ZMQ
Poller::Poller() {
poller_ = zpoller_new(nullptr);
}
Poller::Poller(SocketInterface* socket) {
poller_ = zpoller_new(nullptr);
Add(socket);
}
void Poller::Add(SocketInterface* socket) {
zsock_t* zsock = static_cast<zsock_t*>(socket->InternalID());
zpoller_add(poller_, zsock);
zsock2Socket_[zsock] = socket;
}
SocketInterface* Poller::Wait(int timeout) {
zsock_t* sock = static_cast<zsock_t*>(zpoller_wait(poller_, timeout));
if (sock != nullptr)
return zsock2Socket_[sock];
else
return nullptr;
}
bool Poller::Terminated() {
return zpoller_terminated(poller_);
}
Dealer::Dealer() : Dealer(-1) {}
Dealer::Dealer(int id) : id_(id) {
dealer_ = zsock_new(ZMQ_DEALER);
CHECK_NOTNULL(dealer_);
}
Dealer::~Dealer() {
zsock_destroy(&dealer_);
}
int Dealer::Connect(const std::string& endpoint) {
CHECK_GT(endpoint.length(), 0);
if (endpoint.length()) {
CHECK_EQ(zsock_connect(dealer_, "%s", endpoint.c_str()), 0);
return 1;
}
return 0;
}
int Dealer::Send(Msg** msg) {
zmsg_t* zmsg = (*msg)->DumpToZmsg();
zmsg_send(&zmsg, dealer_);
delete *msg;
*msg = nullptr;
return 1;
}
Msg* Dealer::Receive() {
zmsg_t* zmsg = zmsg_recv(dealer_);
if (zmsg == nullptr)
return nullptr;
Msg* msg = new Msg();
msg->ParseFromZmsg(zmsg);
return msg;
}
void* Dealer::InternalID() const {
return dealer_;
}
Router::Router() : Router(100) {}
Router::Router(int bufsize) {
nBufmsg_ = 0;
bufsize_ = bufsize;
router_ = zsock_new(ZMQ_ROUTER);
CHECK_NOTNULL(router_);
poller_ = zpoller_new(router_);
CHECK_NOTNULL(poller_);
}
Router::~Router() {
zsock_destroy(&router_);
for (auto it : id2addr_)
zframe_destroy(&it.second);
for (auto it : bufmsg_) {
for (auto *msg : it.second)
zmsg_destroy(&msg);
}
}
int Router::Bind(const std::string& endpoint) {
int port = -1;
if (endpoint.length()) {
port = zsock_bind(router_, "%s", endpoint.c_str());
}
CHECK_NE(port, -1) << endpoint;
LOG(INFO) << "bind successfully to " << zsock_endpoint(router_);
return port;
}
int Router::Send(Msg **msg) {
zmsg_t* zmsg = (*msg)->DumpToZmsg();
int dstid = (*msg)->dst();
if (id2addr_.find(dstid) != id2addr_.end()) {
// the connection has already been set up
zframe_t* addr = zframe_dup(id2addr_[dstid]);
zmsg_prepend(zmsg, &addr);
zmsg_send(&zmsg, router_);
} else {
// the connection is not ready, buffer the message
if (bufmsg_.size() == 0)
nBufmsg_ = 0;
bufmsg_[dstid].push_back(zmsg);
++nBufmsg_;
CHECK_LE(nBufmsg_, bufsize_);
}
delete *msg;
*msg = nullptr;
return 1;
}
Msg* Router::Receive() {
zmsg_t* zmsg = zmsg_recv(router_);
if (zmsg == nullptr) {
LOG(ERROR) << "Connection broken!";
exit(0);
}
zframe_t* dealer = zmsg_pop(zmsg);
Msg* msg = new Msg();
msg->ParseFromZmsg(zmsg);
if (id2addr_.find(msg->src()) == id2addr_.end()) {
// new connection, store the sender's identfier and send buffered messages
// for it
id2addr_[msg->src()] = dealer;
if (bufmsg_.find(msg->src()) != bufmsg_.end()) {
for (auto& it : bufmsg_.at(msg->src())) {
zframe_t* addr = zframe_dup(dealer);
zmsg_prepend(it, &addr);
zmsg_send(&it, router_);
}
bufmsg_.erase(msg->src());
}
} else {
zframe_destroy(&dealer);
}
return msg;
}
void* Router::InternalID() const {
return router_;
}
#endif
} // namespace singa