/************************************************************
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*************************************************************/
#include "singa/server.h"
#include <thread>
#include <chrono>
#include "mshadow/tensor.h"
#include "singa/proto/common.pb.h"
#include "singa/utils/param.h"
#include "singa/utils/singleton.h"
#include "singa/utils/factory.h"
#include "singa/utils/cluster.h"
namespace singa {
using namespace mshadow;
using std::vector;
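// slice2group maps each param slice to the server group that maintains its
// master copy; slice2server maps each slice to the server offset within that
// group (both are used below to address inter-group sync requests).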
Server::Server(int group_id, int server_id,
               const JobProto& job_conf,
               const vector<int>& slice2group,
               const vector<int>& slice2server) {
  grp_id_ = group_id;
  id_ = server_id;
  updater_ = Updater::Create(job_conf.updater());
  slice2group_ = slice2group;
  slice2server_ = slice2server;
}
Server::~Server() {
  delete updater_;
  // free Params (i.e., slices) in server shard
  for (auto entry : shard_)
    for (auto param : entry.second->shares)
      delete param;
}
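// Callback registered with the cluster runtime via WatchSGroup; it clears the
// flag that keeps the recv loop in Server::Run() going.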
void Stop(void* running) {
  *static_cast<bool *>(running) = false;
}
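// Main loop: connect to the stub through the in-process router endpoint, then
// poll for messages and dispatch them to the Handle* methods until the runtime
// invokes Stop; finally notify the stub with a kStop message before exiting.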
void Server::Run() {
  LOG(ERROR) << "Server (group = " << grp_id_ << ", id = " << id_ << ") start";
  auto cluster = Cluster::Get();
  if (cluster->nserver_groups()) {
    CHECK_GT(slice2group_.size(), 0);
    if (cluster->nservers_per_group()) {
      CHECK_GT(slice2server_.size(), 0);
    }
  }
  n_updates_.resize(slice2group_.size(), 0);
  n_pending_sync_.resize(slice2group_.size(), 0);
  last_sync_.resize(slice2group_.size());
  // TODO(wangsh): give each dealer a unique id
  auto dealer = new Dealer(0);
  CHECK(dealer->Connect(kInprocRouterEndpoint));
  Msg* ping = new Msg(Addr(grp_id_, id_, kServer), Addr(-1, -1, kStub));
  ping->set_type(kConnect);
  dealer->Send(&ping);
  bool running = true;
  CHECK(cluster->runtime()->WatchSGroup(grp_id_, id_, Stop, &running));
  Poller poll(dealer);
  // start recv loop and process requests
  while (running) {
    // must use poller here; otherwise Receive() gets stuck after workers stop.
    auto* sock = poll.Wait(cluster->poll_time());
    if (poll.Terminated()) {
      LOG(ERROR) << "Connection broken!";
      exit(0);
    } else if (sock == nullptr) {
      continue;
    }
    Msg* msg = dealer->Receive();
    if (msg == nullptr) break;  // interrupted
    Msg* response = nullptr;
    int type = msg->type();
    int slice_id = SliceID(msg->trgt_val());
    if (type == kPut) {
      response = HandlePut(&msg);
    } else if (shard_.find(slice_id) == shard_.end()) {
      // TODO(wangsh): buffer the msg instead, and process it after the
      // corresponding put request is done.
      // For now, delay the processing by re-queueing the msg; may sleep for a while?
      response = msg;
    } else {
      switch (type) {
        case kGet:
          response = HandleGet(&msg);
          break;
        case kUpdate:
          for (auto reply : HandleUpdate(&msg))
            dealer->Send(&reply);
          break;
        case kSyncRequest:
          response = HandleSyncRequest(&msg);
          break;
        case kSyncResponse:
          HandleSyncResponse(&msg);
          break;
        default:
          LOG(ERROR) << "Unknown message type: " << type;
          break;
      }
    }
    if (response != nullptr)
      dealer->Send(&response);
  }
  // send stop msg to stub
  Msg* msg = new Msg(Addr(grp_id_, id_, kServer), Addr(-1, -1, kStub));
  msg->set_type(kStop);
  dealer->Send(&msg);
  std::this_thread::sleep_for(std::chrono::milliseconds(1000));
  LOG(ERROR) << "Server (group = " << grp_id_ << ", id = " << id_ << ") stops";
  delete dealer;
}
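// Create the Param object for a new slice and register it in the server shard.
// If this group does not master the slice, also snapshot its initial values in
// last_sync_ as the baseline for inter-group synchronization.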
Msg* Server::HandlePut(Msg **msg) {
  int version = (*msg)->trgt_version();
  int slice_id = SliceID((*msg)->trgt_val());
  if (shard_.find(slice_id) != shard_.end())
    LOG(FATAL) << "Param (" << slice_id << ") is put more than once";
  // TODO(wangwei) replace hard coded param type 0
  auto param = Singleton<Factory<Param>>::Instance()->Create(0);
  auto response = param->HandlePutMsg(msg, true);
  // parse num of shares of this param from a worker group
  int num_shares = 1;
  if ((*msg)->NextFrame())
    (*msg)->ParseFormatFrame("i", &num_shares);
  DeleteMsg(msg);
  shard_[slice_id] = new ParamEntry(num_shares, param);
  // must set version after HandlePutMsg, which allocates the memory
  param->set_version(version);
  param->set_last_version(version);
  param->set_id(slice_id);
  // allocate blob for param sync between groups
  if (slice2group_[slice_id] != grp_id_) {
    last_sync_[slice_id].ReshapeLike(param->data());
    last_sync_[slice_id].CopyFrom(param->data());
  }
  LOG(INFO) << "server (group = " << grp_id_ << ", id = " << id_
            << ") put slice=" << slice_id << " size=" << param->size();
  return response;
}
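// Reply with the requested slice, or re-queue the request (by returning the
// original msg) if the slice has not yet reached the requested version.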
Msg* Server::HandleGet(Msg **msg) {
  int val = (*msg)->trgt_val();
  auto param = shard_.at(SliceID(val))->shares.at(0);
  // re-queue the request if the param has not been updated to the required version
  if (param->version() < (*msg)->trgt_version()) {
    return *msg;
  } else {
    // LOG(ERROR) << "get " << slice << " from " << (*msg)->src_first();
    auto reply = param->HandleGetMsg(msg, false);
    reply->set_trgt(val, param->version());
    return reply;
  }
}
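// Buffer update msgs until gradients have arrived from all shares of the slice,
// then aggregate them, run the updater once, and generate a response for every
// share; periodically ship the accumulated delta to the slice's master group.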
const vector<Msg*> Server::HandleUpdate(Msg **msg) {
  vector<Msg*> ret;
  int sliceid = SliceID((*msg)->trgt_val());
  auto entry = shard_.at(sliceid);
  buffer_requests_[sliceid].push_back(*msg);
  int num_update;
  (*msg)->LastFrame();
  (*msg)->ParseFormatFrame("i", &num_update);
  (*msg)->FirstFrame();
  entry->num_update += num_update;
  // LOG(ERROR) << "update " << sliceid << " from " << AddrGrp((*msg)->src())
  //            << ", " << num_update << " total " << entry->num_total;
  // do the update only after gradients are received from all shares of this param/slice
  if (entry->num_update >= entry->num_total) {
    CHECK_EQ(entry->num_update, entry->num_total);
    auto& request = buffer_requests_.at(sliceid);
    int step = (*msg)->trgt_version();
    int trgt_val = (*msg)->trgt_val();
    auto param = entry->shares.at(0);
    // extract and aggregate gradients
    param->ParseUpdateMsgs(request);
    // DLOG(ERROR) << "update param " << param->id() << " @ step " << step;
    updater_->Update(step, param, 1.0f / entry->num_total);
    param->set_version(param->version() + 1);
    // respond to all shares of this param
    for (auto response : param->GenUpdateResponseMsgs(&request, false)) {
      response->set_trgt(trgt_val, param->version());
      ret.push_back(response);
    }
    entry->num_update = 0;
    n_updates_[sliceid]++;
    // sync with the master group after at least sync_freq local updates;
    // the last condition bounds the number of pending syncs, which avoids
    // sending msgs to stopped servers.
    // Could also send the number of update steps on this server since the
    // last sync, i.e., version - last_version.
    if (slice2group_[sliceid] != grp_id_
        && n_updates_[sliceid] >= Cluster::Get()->sync_freq()
        && n_pending_sync_[sliceid] <= Cluster::Get()->sync_freq()) {
      auto shape = Shape1(param->size());
      Tensor<cpu, 1> tmp(last_sync_[sliceid].mutable_cpu_data(), shape);
      Tensor<cpu, 1> cur(param->mutable_cpu_data(), shape);
      tmp = cur - tmp;  // tmp (backed by last_sync_) now holds the delta since the last sync
      int addr = Addr(slice2group_[sliceid], slice2server_[sliceid], kServer);
      Msg* sync = new Msg(Addr(grp_id_, id_, kServer), addr);
      sync->set_type(kSyncRequest);
      sync->set_trgt(trgt_val, param->version());
      sync->AddFrame(tmp.dptr, param->size() * sizeof(float));
      Copy(tmp, cur);  // record the current values as the new sync baseline
      ret.push_back(sync);
      n_updates_[sliceid] = 0;
      n_pending_sync_[sliceid]++;
    }
  }
  // the msg has already been pushed into the buffer; just reset the pointer
  *msg = nullptr;
  return ret;
}
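// Handle a sync request for a slice mastered by this server: add the received
// delta to the fresh copy and send the updated values back as a kSyncResponse.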
Msg* Server::HandleSyncRequest(Msg **msg) {
  Msg* msgg = *msg;
  int slice = SliceID(msgg->trgt_val());
  auto param = shard_.at(slice)->shares.at(0);
  auto shape = Shape1(param->size());
  CHECK_EQ(msgg->FrameSize(), param->size() * sizeof(float));
  Tensor<cpu, 1> inc(static_cast<float*>(msgg->FrameData()), shape);
  Tensor<cpu, 1> cur(param->mutable_cpu_data(), shape);
  // received a sync msg for a slice maintained (mastered) by this server
  cur += inc;
  msgg->SwapAddr();
  msgg->set_type(kSyncResponse);
  // copy the fresh param values into the response msg
  Copy(inc, cur);
  return msgg;
}
// received a sync response for a slice mastered by another group: merge the
// master's fresh values with the local updates made since the sync request
void Server::HandleSyncResponse(Msg **msg) {
  Msg* msgg = *msg;
  int slice = SliceID(msgg->trgt_val());
  auto param = shard_.at(slice)->shares.at(0);
  auto shape = Shape1(param->size());
  Tensor<cpu, 1> prev(last_sync_[param->id()].mutable_cpu_data(), shape);
  Tensor<cpu, 1> cur(param->mutable_cpu_data(), shape);
  Tensor<cpu, 1> master(static_cast<float*>(msgg->FrameData()), shape);
  cur += master - prev;  // i.e., cur = master + (cur - prev)
  Copy(prev, cur);  // record the merged values as the new sync baseline
  DeleteMsg(msg);
  n_pending_sync_[slice]--;
}
} // namespace singa