/*!
 * Copyright (c) 2015 by Contributors
 * \file kvstore_dist_server.h
 * \brief implements the distributed KVStore server node
 */
#ifndef MXNET_KVSTORE_KVSTORE_DIST_SERVER_H_
#define MXNET_KVSTORE_KVSTORE_DIST_SERVER_H_
#include <queue>
#include <string>
#include <mutex>
#include <condition_variable>
#include <memory>
#include <functional>
#include <future>
#include <vector>
#include <unordered_map>
#include "ps/ps.h"
#include "mxnet/kvstore.h"
namespace mxnet {
namespace kvstore {
/*! \brief the command to stop the server */
static const int kStopServer = -1;
/*! \brief the command to switch the server into synchronous mode */
static const int kSyncMode = -2;
/**
 * \brief Executor runs submitted functions on the thread that called \ref Start
 */
class Executor {
public:
  /**
   * \brief start the event loop; blocks the calling thread until \ref Stop is called
   */
void Start() {
std::unique_lock<std::mutex> lk(mu_);
while (true) {
cond_.wait(lk, [this]{return !queue_.empty();});
Block blk = std::move(queue_.front());
queue_.pop();
lk.unlock();
      if (blk.f) {
        // run the function, then unblock the caller waiting in Exec()
        blk.f(); blk.p->set_value();
      } else {
        // an empty function is the stop signal: unblock the caller and exit
        blk.p->set_value(); break;
      }
lk.lock();
}
}
  /**
   * \brief the function type that the executor runs
   */
typedef std::function<void()> Func;
  /**
   * \brief run \a func on the thread that called \ref Start and block
   * until it finishes. Thread-safe.
   */
void Exec(const Func& func) {
Block blk(func);
auto fut = blk.p->get_future();
{
std::lock_guard<std::mutex> lk(mu_);
queue_.push(std::move(blk));
cond_.notify_one();
}
fut.wait();
}
  /**
   * \brief stop the event loop. Thread-safe.
   */
void Stop() {
Exec(Func());
}
private:
struct Block {
explicit Block(const Func& func) : f(func), p(std::make_shared<std::promise<void>>()) { }
Func f;
std::shared_ptr<std::promise<void>> p;
};
std::queue<Block> queue_;
std::mutex mu_;
std::condition_variable cond_;
};
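/*
 * A minimal usage sketch of Executor (illustrative only; the surrounding
 * thread and the printed message are assumptions, not part of this file):
 *
 *   Executor exec;
 *   std::thread loop([&exec]() { exec.Start(); });  // the event-loop thread
 *   exec.Exec([]() { printf("runs on the loop thread\n"); });  // blocks until done
 *   exec.Stop();   // posts the empty function; Start() returns
 *   loop.join();
 */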
class KVStoreDistServer {
public:
KVStoreDistServer() {
using namespace std::placeholders;
ps_server_ = new ps::KVServer<float>(0);
static_cast<ps::SimpleApp*>(ps_server_)->set_request_handle(
std::bind(&KVStoreDistServer::CommandHandle, this, _1, _2));
ps_server_->set_request_handle(
std::bind(&KVStoreDistServer::DataHandle, this, _1, _2, _3));
sync_mode_ = false;
}
~KVStoreDistServer() {
delete ps_server_;
}
void set_controller(const KVStore::Controller& controller) {
CHECK(controller);
controller_ = controller;
}
void set_updater(const KVStore::Updater& updater) {
CHECK(updater);
updater_ = updater;
}
  /**
   * \brief run the server; blocks until the command \a kStopServer is received
   */
void Run() {
exec_.Start();
}
private:
void CommandHandle(const ps::SimpleData& recved, ps::SimpleApp* app) {
if (recved.head == kStopServer) {
exec_.Stop();
} else if (recved.head == kSyncMode) {
sync_mode_ = true;
} else {
      // let the main thread execute the controller, which is necessary for Python
exec_.Exec([this, recved]() {
CHECK(controller_);
controller_(recved.head, recved.body);
});
}
app->Response(recved);
}
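  /**
   * \brief handle a data (push or pull) request from a worker
   *
   * For a push: initialize the stored array on first use; in sync mode,
   * aggregate pushes from all workers into a merge buffer and apply
   * \ref updater_ once every worker has pushed; in async mode, apply
   * \ref updater_ immediately. For a pull: reply with the stored array.
   */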
void DataHandle(const ps::KVMeta& req_meta,
const ps::KVPairs<real_t>& req_data,
ps::KVServer<real_t>* server) {
    // sanity checks: exactly one key per request; for a push, the value
    // length must match the declared length
CHECK_EQ(req_data.keys.size(), (size_t)1);
if (req_meta.push) {
CHECK_EQ(req_data.lens.size(), (size_t)1);
CHECK_EQ(req_data.vals.size(), (size_t)req_data.lens[0]);
}
int key = DecodeKey(req_data.keys[0]);
auto& stored = store_[key];
    // WaitToRead is used in several places below because \a recved's memory
    // may be deallocated once this function returns, so we must make sure
    // that all operators using that NDArray have actually finished
if (req_meta.push) {
size_t ds[] = {(size_t)req_data.lens[0]};
TShape dshape(ds, ds + 1);
TBlob recv_blob((real_t*)req_data.vals.data(), // NOLINT(*)
dshape, cpu::kDevMask);
NDArray recved = NDArray(recv_blob, 0);
if (stored.is_none()) {
// initialization
stored = NDArray(dshape, Context());
CopyFromTo(recved, &stored, 0);
server->Response(req_meta);
stored.WaitToRead();
} else if (sync_mode_) {
// synced push
auto& merged = merge_buf_[key];
if (merged.array.is_none()) {
merged.array = NDArray(dshape, Context());
}
if (merged.request.size() == 0) {
CopyFromTo(recved, &merged.array, 0);
} else {
merged.array += recved;
}
merged.request.push_back(req_meta);
if (merged.request.size() == (size_t)ps::NumWorkers()) {
          // let the main thread execute updater_, which is necessary for Python
if (updater_) {
exec_.Exec([this, key, &merged, &stored](){
CHECK(updater_);
updater_(key, merged.array, &stored);
});
} else {
// if no updater, just copy
CopyFromTo(merged.array, &stored);
}
for (const auto& req : merged.request) {
server->Response(req);
}
merged.request.clear();
stored.WaitToRead();
} else {
merged.array.WaitToRead();
}
} else {
// async push
exec_.Exec([this, key, &recved, &stored](){
CHECK(updater_);
updater_(key, recved, &stored);
});
server->Response(req_meta);
stored.WaitToRead();
}
} else {
// pull
ps::KVPairs<real_t> response;
CHECK(!stored.is_none()) << "init " << key << " first";
int len = stored.shape()[0];
response.keys = req_data.keys;
response.lens = {len};
// TODO(mli) try to remove this CopyFrom
response.vals.CopyFrom(static_cast<const float*>(stored.data().dptr_), len);
server->Response(req_meta, response);
}
}
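  /**
   * \brief recover the local key by subtracting the start of the key range
   * that the postoffice assigned to this server
   */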
int DecodeKey(ps::Key key) {
auto kr = ps::Postoffice::Get()->GetServerKeyRanges()[ps::MyRank()];
return key - kr.begin();
}
  /**
   * \brief whether to run in synchronous mode; set by the \a kSyncMode command
   */
bool sync_mode_;
KVStore::Controller controller_;
KVStore::Updater updater_;
std::unordered_map<int, NDArray> store_;
struct MergeBuf {
std::vector<ps::KVMeta> request;
NDArray array;
};
std::unordered_map<int, MergeBuf> merge_buf_;
Executor exec_;
ps::KVServer<float>* ps_server_;
};
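/*
 * A minimal sketch of how a server process might host this class
 * (illustrative; ps::Start/ps::IsServer/ps::Finalize and the updater body
 * are assumptions based on the ps-lite API, not part of this file):
 *
 *   ps::Start();
 *   if (ps::IsServer()) {
 *     auto server = new KVStoreDistServer();
 *     server->set_updater([](int key, const NDArray& recved, NDArray* stored) {
 *       *stored += recved;  // e.g. plain aggregation
 *     });
 *     server->Run();  // blocks until a kStopServer command arrives
 *     delete server;
 *   }
 *   ps::Finalize();
 */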
} // namespace kvstore
} // namespace mxnet
#endif // MXNET_KVSTORE_KVSTORE_DIST_SERVER_H_