blob: c06128c2a0ff6af8dea7b20dce3581c55e5aad8e [file] [log] [blame]
/************************************************************
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*************************************************************/
#include "singa/stub.h"
#include <glog/logging.h>
#include <unistd.h>
#include <map>
#include <thread>
#include <set>
#include "singa/proto/common.pb.h"
#include "singa/utils/cluster.h"
#include "singa/utils/common.h"
#include "singa/utils/tinydir.h"
#include "singa/utils/math_blob.h"
namespace singa {
using std::vector;
using std::string;
/***********************Stub****************************/
Stub::~Stub() {
delete router_;
}
void Stub::Setup() {
router_ = new Router();
router_->Bind(kInprocRouterEndpoint);
auto cluster = Cluster::Get();
const string hostip = cluster->hostip();
int port = router_->Bind("tcp://" + hostip + ":*");
endpoint_ = hostip + ":" + std::to_string(port);
}
/**
* Get a hash id for a Param object from a group.
*
* Simple multiple group_id with a large prime number 997 (assuming there are
* no more than 997 worker groups) and plus owner param id.
*/
inline int Hash(int grp_id, int param_id) {
return grp_id * 997 + param_id;
}
const std::unordered_map<int, ParamEntry*> CreateParamShard(
const vector<Worker*>& workers) {
std::unordered_map<int, ParamEntry*> shard;
// grp id -> net
std::unordered_map<int, NeuralNet*> grp2net;
// grp id -> worker id range
std::unordered_map<int, std::pair<int, int>> grp2workers;
for (auto worker : workers) {
int grp = worker->grp_id(), id = worker->id();
if (grp2net.find(grp) == grp2net.end()) {
grp2net[grp] = worker->train_net();
grp2workers[grp] = std::make_pair(id, id + 1);
} else {
CHECK_EQ(grp2net[grp], worker->train_net());
int start = grp2workers[grp].first, end = grp2workers[grp].second;
if (start > id) start = id;
if (end < id + 1) end = id + 1;
grp2workers[grp] = std::make_pair(start, end);
}
}
for (const auto entry : grp2net) {
int grp = entry.first;
int wstart = grp2workers[grp].first, wend = grp2workers[grp].second;
for (auto layer : entry.second->layers()) {
if (layer->unroll_index() > 0)
continue;
int partition = layer->partition_id();
bool local = partition >= wstart && partition < wend;
for (auto param : layer->GetParams()) {
int hash = Hash(grp, param->owner());
if (shard.find(hash) == shard.end())
shard[hash] = new ParamEntry();
shard[hash]->AddParam(local, param);
}
}
}
return shard;
}
void Stub::Run(const vector<int>& slice2server,
const vector<Worker*>& workers, const vector<Server*>& servers) {
slice2server_ = slice2server;
int nworkers = workers.size(), nservers = servers.size();
auto cluster = Cluster::Get();
int procs_id = cluster->procs_id();
LOG(INFO) << "Stub in process " << procs_id << " starts";
auto shard = CreateParamShard(workers);
std::map<int, Dealer*> inter_dealers; // for sending msg to other procs
std::queue<Msg*> msg_queue;
while (true) {
Msg* msg = nullptr;
if (msg_queue.empty()) {
msg = router_->Receive();
} else {
msg = msg_queue.front();
msg_queue.pop();
}
int type = msg->type(), dst = msg->dst(), flag = AddrType(dst);
if (flag == kStub && (AddrProc(dst) == procs_id || AddrGrp(dst) == -1)) {
// the following statements are ordered!
if (type == kConnect) {
DeleteMsg(&msg);
} else if (type == kStop) {
int src_flag = AddrType(msg->src());
if (src_flag == kServer) nservers--;
else if (src_flag == kWorkerParam) nworkers--;
DeleteMsg(&msg);
if (nworkers == 0 && nservers == 0) break;
} else {
int grp;
int paramid = ParamID(msg->trgt_val());
ParamEntry *entry = nullptr;
switch (type) {
case kUpdate:
grp = AddrGrp(msg->src());
entry = shard.at(Hash(grp, paramid));
for (auto update_msg : HandleUpdateRequest(entry, &msg))
msg_queue.push(update_msg);
break;
case kRUpdate:
grp = AddrGrp(msg->dst());
entry = shard.at(Hash(grp, paramid));
HandleUpdateResponse(entry, &msg);
break;
case kGet:
grp = AddrGrp(msg->src());
entry = shard.at(Hash(grp, paramid));
for (auto get_msg : HandleGetRequest(entry, &msg))
msg_queue.push(get_msg);
break;
case kRGet:
grp = AddrGrp(msg->dst());
entry = shard.at(Hash(grp, paramid));
HandleGetResponse(entry, &msg);
break;
case kPut:
grp = AddrGrp(msg->src());
entry = shard.at(Hash(grp, paramid));
for (auto put_msg : HandlePutRequest(entry, &msg))
msg_queue.push(put_msg);
break;
default:
LOG(ERROR) << "Unknow message type:" << type;
break;
}
}
} else {
int dst_procs = AddrProc(dst);
if (flag != kStub)
dst_procs = cluster->ProcsIDOf(AddrGrp(dst), AddrID(dst), flag);
if (dst_procs != procs_id) {
if (inter_dealers.find(dst_procs) == inter_dealers.end())
inter_dealers[dst_procs] = CreateInterProcsDealer(dst_procs);
inter_dealers[dst_procs]->Send(&msg);
} else {
router_->Send(&msg);
}
}
}
LOG(ERROR) << "Stub in process " << procs_id << " stops";
for (auto& entry : inter_dealers)
delete entry.second;
}
Dealer* Stub::CreateInterProcsDealer(int dst_procs) {
// forward to other procs
auto cluster = Cluster::Get();
auto dealer = new Dealer();
while (cluster->endpoint(dst_procs) == "") {
// kCollectSleepTime));
std::this_thread::sleep_for(std::chrono::milliseconds(3000));
LOG(ERROR) << "waiting for procs " << dst_procs << " to register";
}
dealer->Connect("tcp://"+cluster->endpoint(dst_procs));
return dealer;
}
void Stub::GenMsgs(int type, int version, ParamEntry* entry, Msg* msg,
vector<Msg*> *ret) {
int procs_id = Cluster::Get()->procs_id();
int src_grp = AddrGrp(msg->src());
int dst_grp = src_grp / Cluster::Get()->nworker_groups_per_server_group();
auto param = entry->shares.at(0);
for (int idx = 0 ; idx < param->num_slices(); idx++) {
int slice_id = param->slice_start() + idx;
int server = slice2server_[slice_id];
int dst_procs = Cluster::Get()->ProcsIDOf(dst_grp, server, kServer);
Msg* new_msg = nullptr;
if (type == kPut) {
CHECK_GT(entry->num_total, 0);
new_msg = param->GenPutMsg(dst_procs != procs_id, idx);
new_msg->AddFormatFrame("i", entry->num_total);
} else if (type == kGet) {
new_msg = param->GenGetMsg(dst_procs != procs_id, idx);
} else if (type == kUpdate) {
new_msg = param->GenUpdateMsg(dst_procs != procs_id, idx);
new_msg->AddFormatFrame("i", entry->num_local);
} else {
LOG(FATAL) << "Wrong type";
}
new_msg->set_trgt(ParamTrgt(param->owner(), slice_id), version);
new_msg->set_src(Addr(src_grp, procs_id, kStub));
new_msg->set_dst(Addr(dst_grp, server, kServer));
ret->push_back(new_msg);
}
}
const vector<Msg*> Stub::HandleGetRequest(ParamEntry* entry, Msg** msg) {
vector<Msg*> ret;
int version = (*msg)->trgt_version();
if (version > entry->next_version) {
entry->next_version = version;
GenMsgs(kGet, version, entry, *msg, &ret);
}
DeleteMsg(msg);
return ret;
}
const vector<Msg*> Stub::HandleUpdateRequest(ParamEntry *entry, Msg** msg) {
vector<Msg*> ret;
entry->num_update++;
if (entry->num_update >= entry->num_local) {
// average local gradient
if (entry->num_local > 1) {
auto it = entry->shares.begin();
auto sum = it;
for (++it; it != entry->shares.end(); it++) {
AXPY(1.0f, (*it)->grad(), (*sum)->mutable_grad());
}
}
int step = (*msg)->trgt_version();
GenMsgs(kUpdate, step, entry, *msg, &ret);
entry->num_update = 0;
}
DeleteMsg(msg);
return ret;
}
const vector<Msg*> Stub::HandlePutRequest(ParamEntry* entry, Msg** msg) {
vector<Msg*> ret;
int version = (*msg)->trgt_version();
GenMsgs(kPut, version, entry, *msg, &ret);
DeleteMsg(msg);
return ret;
}
void Stub::HandleGetResponse(ParamEntry* entry, Msg** msg) {
int version = (*msg)->trgt_version();
int sliceid = SliceID((*msg)->trgt_val());
auto param = entry->shares.at(0);
if (param->ParseGetResponseMsg(*msg, sliceid-param->slice_start()))
for (auto *p : entry->shares)
p->set_version(version);
DeleteMsg(msg);
}
void Stub::HandleUpdateResponse(ParamEntry* entry, Msg** msg) {
int version = (*msg)->trgt_version();
int sliceid = SliceID((*msg)->trgt_val());
auto param = entry->shares.at(0);
if (param->ParseUpdateResponseMsg(*msg, sliceid-param->slice_start()))
for (auto *p : entry->shares)
p->set_version(version);
DeleteMsg(msg);
}
} // namespace singa