blob: d9effcf82f3cbc9571a8e000933dc8729141f32c [file] [log] [blame]
/*!
* Copyright (c) 2016 by Contributors
* \file kvstore.hpp
* \brief implementation of kvstore
* \author Xin Li
*/
#include <algorithm>
#include <map>
#include <numeric>
#include <string>
#include <vector>
#include "mxnet-cpp/kvstore.h"
#include "mxnet-cpp/optimizer.h"
#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_KVSTORE_HPP_
#define CPP_PACKAGE_INCLUDE_MXNET_CPP_KVSTORE_HPP_
namespace mxnet {
namespace cpp {
inline void KVStore::Controller(int head, const char* body, void* controller_handle) {
if (head == 0) {
std::map<std::string, std::string> params;
std::istringstream sin(body);
std::string line;
while (getline(sin, line)) {
size_t n = line.find('=');
params.emplace(line.substr(0, n), line.substr(n+1));
}
std::unique_ptr<Optimizer> opt(OptimizerRegistry::Find(params.at("opt_type")));
params.erase("opt_type");
for (const auto& pair : params) {
opt->SetParam(pair.first, pair.second);
}
get_kvstore()->SetOptimizer(std::move(opt), true);
}
}
/*!
 * \brief Access the kvstore handle used by all C API calls in this file.
 * \return reference to a function-local static handle, initially nullptr
 */
inline KVStoreHandle& KVStore::get_handle() {
  static KVStoreHandle kv_handle = nullptr;
  return kv_handle;
}
/*!
 * \brief Access the slot that owns the locally installed optimizer.
 * \return reference to a function-local static unique_ptr, empty until
 *         something is assigned to it
 */
inline std::unique_ptr<Optimizer>& KVStore::get_optimizer() {
  static std::unique_ptr<Optimizer> owned_optimizer;
  return owned_optimizer;
}
/*!
 * \brief Access the process-wide KVStore instance.
 *
 * The instance is allocated with `new` the first time this function runs;
 * no matching delete appears in this file.
 *
 * \return reference to the function-local static instance pointer
 */
inline KVStore*& KVStore::get_kvstore() {
  static KVStore* instance = new KVStore;
  return instance;
}
/*! \brief Default constructor; performs no work of its own. */
inline KVStore::KVStore() = default;
inline void KVStore::SetType(const std::string& type) {
CHECK_EQ(MXKVStoreCreate(type.c_str(), &(get_kvstore()->get_handle())), 0);
}
inline void KVStore::RunServer() {
CHECK_NE(GetRole(), "worker");
CHECK_EQ(MXKVStoreRunServer(get_kvstore()->get_handle(), &Controller, 0), 0);
}
inline void KVStore::Init(int key, const NDArray& val) {
NDArrayHandle val_handle = val.GetHandle();
CHECK_EQ(MXKVStoreInit(get_kvstore()->get_handle(), 1, &key, &val_handle), 0);
}
inline void KVStore::Init(const std::vector<int>& keys, const std::vector<NDArray>& vals) {
CHECK_EQ(keys.size(), vals.size());
std::vector<NDArrayHandle> val_handles(vals.size());
std::transform(vals.cbegin(), vals.cend(), val_handles.begin(),
[](const NDArray& val) {
return val.GetHandle();
});
CHECK_EQ(MXKVStoreInit(get_kvstore()->get_handle(), keys.size(), keys.data(),
val_handles.data()), 0);
}
inline void KVStore::Push(int key, const NDArray& val, int priority) {
NDArrayHandle val_handle = val.GetHandle();
CHECK_EQ(MXKVStorePush(get_kvstore()->get_handle(), 1, &key, &val_handle, priority), 0);
}
inline void KVStore::Push(const std::vector<int>& keys,
const std::vector<NDArray>& vals,
int priority) {
CHECK_EQ(keys.size(), vals.size());
std::vector<NDArrayHandle> val_handles(vals.size());
std::transform(vals.cbegin(), vals.cend(), val_handles.begin(),
[](const NDArray& val) {
return val.GetHandle();
});
CHECK_EQ(MXKVStorePush(get_kvstore()->get_handle(), keys.size(), keys.data(),
val_handles.data(), priority), 0);
}
inline void KVStore::Pull(int key, NDArray* out, int priority) {
NDArrayHandle out_handle = out->GetHandle();
CHECK_EQ(MXKVStorePull(get_kvstore()->get_handle(), 1, &key, &out_handle, priority), 0);
}
inline void KVStore::Pull(const std::vector<int>& keys, std::vector<NDArray>* outs, int priority) {
CHECK_EQ(keys.size(), outs->size());
std::vector<NDArrayHandle> out_handles(keys.size());
std::transform(outs->cbegin(), outs->cend(), out_handles.begin(),
[](const NDArray& val) {
return val.GetHandle();
});
CHECK_EQ(MXKVStorePull(get_kvstore()->get_handle(), keys.size(), keys.data(),
out_handles.data(), priority), 0);
}
inline void KVStore::Updater(int key, NDArrayHandle recv, NDArrayHandle local,
void* handle_) {
Optimizer *opt = static_cast<Optimizer*>(handle_);
opt->Update(key, NDArray(local), NDArray(recv));
}
inline void KVStore::SetOptimizer(std::unique_ptr<Optimizer> optimizer, bool local) {
if (local) {
get_kvstore()->get_optimizer() = std::move(optimizer);
CHECK_EQ(MXKVStoreSetUpdater(get_kvstore()->get_handle(),
&Updater, get_kvstore()->get_optimizer().get()), 0);
} else {
CHECK_EQ(MXKVStoreSendCommmandToServers(get_kvstore()->get_handle(), 0,
(*optimizer).Serialize().c_str()), 0);
}
}
/*!
 * \brief Query the kvstore type string through MXKVStoreGetType.
 * \return a copy of the C string reported by the C API
 */
inline std::string KVStore::GetType() {
  const char* type_name = nullptr;
  CHECK_EQ(MXKVStoreGetType(get_kvstore()->get_handle(), &type_name), 0);
  return std::string(type_name);
}
/*!
 * \brief Query this node's rank through MXKVStoreGetRank.
 * \return the rank value written by the C API
 */
inline int KVStore::GetRank() {
  int node_rank = 0;
  CHECK_EQ(MXKVStoreGetRank(get_kvstore()->get_handle(), &node_rank), 0);
  return node_rank;
}
/*!
 * \brief Query the number of workers through MXKVStoreGetGroupSize.
 * \return the group size written by the C API
 */
inline int KVStore::GetNumWorkers() {
  int group_size = 0;
  CHECK_EQ(MXKVStoreGetGroupSize(get_kvstore()->get_handle(), &group_size), 0);
  return group_size;
}
inline void KVStore::Barrier() {
CHECK_EQ(MXKVStoreBarrier(get_kvstore()->get_handle()), 0);
}
/*!
 * \brief Report the role of the current node.
 *
 * Queries the C API in the order scheduler, server, worker and returns the
 * first role that is reported as set. Fails a CHECK if any query returns
 * nonzero or if none of the three roles is set.
 *
 * \return "scheduler", "server" or "worker"
 */
inline std::string KVStore::GetRole() {
  int is_scheduler = 0;
  CHECK_EQ(MXKVStoreIsSchedulerNode(&is_scheduler), 0);
  if (is_scheduler) {
    return std::string("scheduler");
  }

  int is_server = 0;
  CHECK_EQ(MXKVStoreIsServerNode(&is_server), 0);
  if (is_server) {
    return std::string("server");
  }

  int is_worker = 0;
  CHECK_EQ(MXKVStoreIsWorkerNode(&is_worker), 0);
  // Neither scheduler nor server, so the node has to be a worker.
  CHECK(is_worker);
  return std::string("worker");
}
} // namespace cpp
} // namespace mxnet
#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_KVSTORE_HPP_