/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/**
* Copyright (c) 2015 by Contributors
* @file kvstore_dist.h
* @brief distributed implementation based on ps-lite
*/
#ifndef MXNET_KVSTORE_KVSTORE_DIST_H_
#define MXNET_KVSTORE_KVSTORE_DIST_H_
#include <atomic>
#include <mutex>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <algorithm>
#include <utility>
#include "./kvstore_local.h"
#include "mxnet/engine.h"
#include "ps/ps.h"
#include "./kvstore_dist_server.h"
namespace mxnet {
namespace kvstore {
/**
* \brief distributed kvstore
*
* It is the server node's job to maintain data consistency among all
* workers. See \ref ServerHandle::Start for details.
*/
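/*
 * Rough usage sketch from a worker process. KVStore::Create and the
 * "dist_sync" / "dist_async" type names come from mxnet/kvstore.h, not from
 * this file; the snippet only illustrates how this class is typically driven:
 *
 *   KVStore* kv = KVStore::Create("dist_sync");
 *   kv->Init({0}, {weight});    // rank-0 worker pushes the initial value
 *   kv->Push({0}, {grad});      // gradients are aggregated on the servers
 *   kv->Pull({0}, {&weight});   // pull the updated weight back
 *   delete kv;                  // finalizes ps-lite (see ~KVStoreDist)
 */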
class KVStoreDist : public KVStoreLocal {
public:
explicit KVStoreDist(bool use_device_comm)
: KVStoreLocal(use_device_comm), ps_worker_(nullptr), server_(nullptr) {
if (IsWorkerNode()) {
int new_customer_id = GetNewCustomerId();
ps_worker_ = new ps::KVWorker<char>(0, new_customer_id);
ps::StartAsync(new_customer_id, "mxnet\0");
if (!ps::Postoffice::Get()->is_recovery()) {
ps::Postoffice::Get()->Barrier(
new_customer_id,
ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler);
}
}
bigarray_bound_ = dmlc::GetEnv("MXNET_KVSTORE_BIGARRAY_BOUND", 1000 * 1000);
log_verbose_ = dmlc::GetEnv("MXNET_KVSTORE_DIST_ROW_SPARSE_VERBOSE", false);
}
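/*
 * Worker shutdown sequence: wait for all pending engine operations, run a
 * final barrier if requested, let the rank-0 worker (customer 0) tell the
 * servers to stop, then finalize ps-lite for this customer.
 */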
virtual ~KVStoreDist() {
Engine::Get()->WaitForAll();
customer_id_ = 0;
if (IsWorkerNode()) {
if (barrier_before_exit_) {
Barrier();
if (get_rank() == 0 && ps_worker_->get_customer()->customer_id() == 0) {
// stop the executor at servers
SendCommandToServers(static_cast<int>(CommandType::kStopServer), "");
}
}
ps::Finalize(ps_worker_->get_customer()->customer_id(), barrier_before_exit_);
delete ps_worker_;
}
}
void set_updater(const Updater& updater) override {
CHECK(updater) << "invalid updater";
if (IsServerNode()) {
CHECK_NOTNULL(server_)->set_updater(updater);
} else {
updater_ = updater;
}
}
void SetGradientCompression(const std::vector<std::pair<std::string, std::string> >
& kwargs) override {
KVStoreLocal::SetGradientCompression(kwargs);
if (get_rank() == 0) {
SendCommandToServers(static_cast<int>(CommandType::kSetGradientCompression),
gradient_compression_->EncodeParams());
}
}
void SetServerProfilerCommand(const KVStoreServerProfilerCommand type,
const std::string& params) override {
if (get_rank() == 0) {
SendCommandToServers(static_cast<int>(CommandType::kSetProfilerParams),
params + std::to_string(static_cast<int>(type)));
}
}
void Barrier() override {
ps::Postoffice::Get()->Barrier(ps_worker_->get_customer()->customer_id(), ps::kWorkerGroup);
}
void SendCommandToServers(int cmd_id,
const std::string& cmd_body) override {
CHECK_NOTNULL(ps_worker_);
ps_worker_->Wait(ps_worker_->Request(cmd_id, cmd_body, ps::kServerGroup));
}
int get_group_size() const override { return ps::NumWorkers(); }
int get_rank() const override { return ps::MyRank(); }
int get_num_dead_node(int node_id, int timeout) const override {
int number = 0;
auto dead_nodes = ps::Postoffice::Get()->GetDeadNodes(timeout);
const auto& watch_nodes = ps::Postoffice::Get()->GetNodeIDs(node_id);
std::unordered_set<int> watch_set(watch_nodes.begin(), watch_nodes.end());
for (int r : dead_nodes) {
if (watch_set.find(r) != watch_set.end()) number++;
}
return number;
}
void RunServer(const Controller& controller) override {
CHECK(!IsWorkerNode());
if (IsServerNode()) {
server_ = new KVStoreDistServer();
server_->set_controller(controller);
}
ps::StartAsync(0, "mxnet_server\0");
if (!ps::Postoffice::Get()->is_recovery()) {
ps::Postoffice::Get()->Barrier(0,
ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler);
}
if (server_) server_->Run();
ps::Finalize(0, true);
if (server_) {
delete server_;
}
server_ = nullptr;
}
private:
static std::atomic<int> customer_id_;
static int GetNewCustomerId() {
return customer_id_++;
}
/**
* \brief struct for ps keys and lens
*/
struct PSKV {
ps::SArray<ps::Key> keys; // n keys
ps::SArray<int> lens; // the length of the i-th value
int size;
};
struct ComprPSKV {
PSKV push;
PSKV pull;
};
/**
* \brief cache all key partitions
*
* `ps_kv_` is used for pushes and pulls without gradient compression
* `compr_ps_kv_` is used with gradient compression. It holds separate
* pskv for push and pull because the two sizes differ.
* Note: `ps_kv_[k]` for some key k may not be the same as `compr_ps_kv_[k].pull`,
* because sharding may produce slightly different divisions when the size is
* not perfectly divisible.
*/
std::unordered_map<int, PSKV> ps_kv_;
std::unordered_map<int, ComprPSKV> compr_ps_kv_;
/**
* \brief serialize access to ps_kv_ and compr_ps_kv_ while encoding keys
*/
std::mutex mu_;
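/*
 * Initialize keys: every worker registers key metadata with the local comm,
 * but only the rank-0 worker with customer id 0 pushes the initial values to
 * the servers; all other workers simply wait at the barrier below.
 */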
void InitImpl(const std::vector<int>& keys,
const std::vector<NDArray>& values) override {
CheckUnique(keys);
for (size_t i = 0; i < keys.size(); ++i) {
comm_->Init(keys[i], values[i].storage_type(), values[i].shape(), values[i].dtype());
}
if (get_rank() == 0 && this->ps_worker_->get_customer()->customer_id() == 0) {
Push_(keys, values, 0, false);
// wait until the push is finished
for (const int key : keys) {
comm_buf_[key].WaitToWrite();
compr_buf_[key].WaitToWrite();
}
} else {
// do nothing
}
if (!ps::Postoffice::Get()->is_recovery()) {
Barrier();
}
}
void PushImpl(const std::vector<int>& keys,
const std::vector<NDArray>& values,
int priority) override {
Push_(keys, values, priority, true);
}
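/*
 * Pull dense values: for each unique key, issue a zero-copy ZPull from the
 * servers into a pinned-memory receive buffer, then broadcast that buffer to
 * the per-device destination arrays.
 */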
void PullImpl(const std::vector<int>& keys,
const std::vector<NDArray*>& values,
int priority, bool ignore_sparse) override {
CHECK(ignore_sparse) << "dist kvstore pull doesn't support ignore_sparse=False";
std::vector<int> uniq_keys;
std::vector<std::vector<NDArray*> > grouped_vals;
GroupKVPairsPull(keys, values, &uniq_keys, &grouped_vals, true);
for (size_t i = 0; i < uniq_keys.size(); ++i) {
int key = uniq_keys[i];
// use the same array for merging to guarantee that pull always happens
// after the previous push on this key
auto& recv_buf = comm_buf_[key];
const auto storage_type = grouped_vals[i][0]->storage_type();
CHECK_EQ(storage_type, kDefaultStorage)
<< "Expected stype of value to be kDefaultStorage";
if (recv_buf.is_none()) {
// this may happen the first time a non-rank-0 worker pulls the weight.
recv_buf = NDArray(grouped_vals[i][0]->shape(), pinned_ctx_,
true, grouped_vals[i][0]->dtype());
}
auto pull_from_servers = [this, key, recv_buf](
RunContext rctx, Engine::CallbackOnComplete cb) {
// convert to ps keys
size_t size = recv_buf.shape().Size();
const int dtype = recv_buf.dtype();
const int num_bytes = mshadow::mshadow_sizeof(dtype);
PSKV& pskv = (gradient_compression_->get_type() == CompressionType::kNone) ?
EncodeDefaultKey(key, size, num_bytes) :
EncodeCompressedKey(key, size, false, num_bytes);
char* data = static_cast<char*> (recv_buf.data().dptr_);
// false means not to delete data when SArray is deleted
auto vals = new ps::SArray<char>(data, size * num_bytes, false);
// issue pull
RequestType mode = (gradient_compression_->get_type() != CompressionType::kNone) ?
RequestType::kCompressedPushPull : RequestType::kDefaultPushPull;
const int cmd = GetCommandType(mode, dtype);
CHECK_NOTNULL(ps_worker_)->ZPull(
pskv.keys, vals, &pskv.lens, cmd, [vals, cb](){ delete vals; cb(); });
};
CHECK_NOTNULL(Engine::Get())->PushAsync(
pull_from_servers,
pinned_ctx_,
{},
{recv_buf.var()},
FnProperty::kNormal,
priority,
"KVStoreDistDefaultStoragePull");
comm_->Broadcast(key, recv_buf, grouped_vals[i], priority);
}
}
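/*
 * Pull row-sparse values: compute the unique row ids requested for each key,
 * pull only those rows from the servers into the receive buffer, then
 * broadcast the result to the grouped destination arrays.
 */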
void PullRowSparseImpl(const std::vector<int>& keys,
const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
int priority = 0) override {
std::vector<int> uniq_keys;
std::vector<std::vector<std::pair<NDArray*, NDArray>>> grouped_val_rowids;
GroupKVPairsPullRsp(keys, val_rowids, &uniq_keys, &grouped_val_rowids, false);
for (size_t i = 0; i < uniq_keys.size(); ++i) {
int key = uniq_keys[i];
// use the same array for merging to guarantee that pull always happens
// after the previous push on this key
auto& recv_buf = comm_buf_[key];
auto& grouped_val_rowid = grouped_val_rowids[i];
const auto storage_type = grouped_val_rowid[0].first->storage_type();
CHECK_EQ(storage_type, kRowSparseStorage)
<< "expected kRowSparseStorage, but got " << storage_type;
if (recv_buf.is_none()) {
// this may happen the first time a non-rank-0 worker pulls the weight.
recv_buf = NDArray(storage_type, grouped_val_rowid[0].first->shape(),
pinned_ctx_, true, grouped_val_rowid[0].first->dtype());
}
auto &target_val_rowids = grouped_val_rowids[i];
const size_t num_vals = target_val_rowids.size();
for (size_t j = 0; j < num_vals; j++) {
auto &row_id = target_val_rowids[j].second;
target_val_rowids[j].second = Unique(row_id, pinned_ctx_, 0);
}
CHECK_EQ(num_vals, 1) << "RowSparsePull with multiple values is not supported yet";
NDArray& indices = target_val_rowids[0].second;
PullRowSparse_(key, recv_buf, indices, priority);
// The recv_buf contains values pulled from remote server with unique indices.
// Directly broadcast w/o rowids if num_vals == 1
auto get_val = [](const std::pair<NDArray*, NDArray>& p) { return p.first; };
std::vector<NDArray*> grouped_val(grouped_val_rowid.size());
std::transform(grouped_val_rowid.begin(), grouped_val_rowid.end(),
grouped_val.begin(), get_val);
comm_->Broadcast(key, recv_buf, grouped_val, priority);
}
}
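/*
 * Common push path: group values by key, reduce across devices when
 * do_merge is set, stage the result in a pinned-memory comm buffer, and
 * dispatch to the default, compressed, or row-sparse push routine depending
 * on the storage type and the gradient compression setting.
 */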
void Push_(const std::vector<int>& keys,
const std::vector<NDArray>& values,
int priority,
bool do_merge) {
// first aggregate the values over keys
std::vector<int> uniq_keys;
std::vector<std::vector<NDArray> > grouped_vals;
GroupKVPairsPush(keys, values, &uniq_keys, &grouped_vals, false);
for (size_t i = 0; i < uniq_keys.size(); ++i) {
// merge over devices
int key = uniq_keys[i];
const auto& vals = grouped_vals[i];
NDArray merged = do_merge ? comm_->Reduce(key, vals, priority) : vals[0];
const auto storage_type = merged.storage_type();
auto &comm_buf = comm_buf_[key];
if (merged.ctx().dev_mask() == cpu::kDevMask) {
// Start of a push doesn't guarantee that the previous pushes are completed.
// This shouldn't affect training of networks though because training involves
// a sequence of push, pull, then push. This imposes an ordering in which the
// second push happens after the first pull, and the pull happens after the first push.
comm_buf = merged; // avoid memory copy
} else {
if (comm_buf.is_none()) {
if (storage_type == kDefaultStorage) {
comm_buf = NDArray(merged.shape(), pinned_ctx_, true, merged.dtype());
} else {
comm_buf = NDArray(storage_type, merged.shape(), pinned_ctx_, true, merged.dtype());
}
}
CopyFromTo(merged, &comm_buf);
}
const int dtype = merged.dtype();
const int num_bytes = mshadow::mshadow_sizeof(dtype);
// push to servers
if (storage_type == kDefaultStorage) {
if (gradient_compression_->get_type() == CompressionType::kNone) {
PSKV& pskv = EncodeDefaultKey(key, comm_buf.shape().Size(), num_bytes);
PushDefault(key, comm_buf, pskv, priority);
} else {
CHECK_EQ(dtype, mshadow::kFloat32) << "Gradient compression is only supported for "
<< "float32 parameters";
// Note: gradient compression uses `do_merge` as proxy to
// detect whether the push is initialization of a key or not.
// is_active is false when push is initialization of key
bool is_active = do_merge;
PSKV &pskv = EncodeCompressedKey(key, comm_buf.shape().Size(), is_active, num_bytes);
// Returns push_pskv if active, else pull_pskv
// we want inactive gc to send uncompressed gradients,
// but sharded in the same way as later pushes would when gc becomes active
if (is_active) {
PushCompressed(key, comm_buf, pskv, priority);
} else {
PushDefault(key, comm_buf, pskv, priority);
}
}
} else if (storage_type == kRowSparseStorage) {
CHECK(gradient_compression_->get_type() == CompressionType::kNone)
<< "Gradient compression for row sparse storage type is not supported";
PushRowSparse(key, comm_buf, priority);
} else {
LOG(FATAL) << "unknown storage type";
}
}
}
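/*
 * Push with gradient compression: quantize comm_buf into a smaller buffer
 * (accumulating the quantization error in residual_), then ZPush the
 * compressed bytes using the push-side PSKV sharding.
 */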
void PushCompressed(int key, const NDArray& comm_buf, const PSKV& pskv, int priority) {
auto &small_buf = compr_buf_[key];
auto &res_buf = residual_[key];
const size_t original_size = comm_buf.shape().Size();
const int dtype = comm_buf.dtype();
// Init the small buffer and residual_ buffer for quantize
if (small_buf.is_none()) {
small_buf = NDArray(mxnet::TShape{pskv.size}, comm_buf.ctx(), false, dtype);
res_buf = NDArray(mxnet::TShape{static_cast<int64_t>(original_size)},
comm_buf.ctx(), false, dtype);
res_buf = 0;
}
gradient_compression_->Quantize(comm_buf, &small_buf, &res_buf, priority);
auto push_to_servers =
[this, key, dtype, pskv, small_buf](RunContext rctx, Engine::CallbackOnComplete cb) {
size_t size = small_buf.shape().Size() * mshadow::mshadow_sizeof(dtype);
char* data = static_cast<char *> (small_buf.data().dptr_);
// do push. false means no delete
ps::SArray<char> vals(data, size, false);
int cmd = GetCommandType(RequestType::kCompressedPushPull, dtype);
CHECK_NOTNULL(ps_worker_)->ZPush(pskv.keys, vals, pskv.lens, cmd, [cb]() { cb(); });
};
// acquire locks on both comm_buf and small_buf so that
// pull (which uses comm_buf) for the same key waits till push finishes
Engine::Get()->PushAsync(
push_to_servers,
pinned_ctx_,
{small_buf.var(), comm_buf.var()},
{},
FnProperty::kNormal,
priority,
"KVStoreDistCompressedPush");
}
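/*
 * Push an uncompressed dense value: ZPush the raw bytes of send_buf using
 * the default key encoding.
 */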
void PushDefault(int key, const NDArray &send_buf, const PSKV& pskv, int priority) {
auto push_to_servers =
[this, key, pskv, send_buf](RunContext rctx, Engine::CallbackOnComplete cb) {
const int dtype = send_buf.dtype();
// convert to ps keys
const size_t size = send_buf.shape().Size() * mshadow::mshadow_sizeof(dtype);
char* data = static_cast<char *>(send_buf.data().dptr_);
// do push. false means no delete
ps::SArray<char> vals(data, size, false);
int cmd = GetCommandType(RequestType::kDefaultPushPull, dtype);
CHECK_NOTNULL(ps_worker_)->ZPush(
pskv.keys, vals, pskv.lens,
cmd, [cb]() { cb(); });
};
Engine::Get()->PushAsync(
push_to_servers,
pinned_ctx_,
{send_buf.var()},
{},
FnProperty::kNormal,
priority,
"KVStoreDistDefaultPush");
}
// push row sparse gradient
void PushRowSparse(int key, const NDArray &send_buf, int priority) {
using namespace rowsparse;
auto push_to_servers = [this, key, send_buf]
(RunContext rctx, Engine::CallbackOnComplete cb) {
char* data = static_cast<char *>(send_buf.data().dptr_);
const int64_t num_rows = send_buf.aux_shape(kIdx)[0];
const auto offsets = send_buf.aux_data(kIdx).dptr<int64_t>();
const auto unit_len = send_buf.shape().ProdShape(1, send_buf.shape().ndim());
const int num_bytes = mshadow::mshadow_sizeof(send_buf.dtype());
const int64_t size = num_rows * unit_len;
// convert to ps keys in row sparse format
PSKV& pskv = EncodeRowSparseKey(key, size, num_rows, offsets,
unit_len, send_buf.shape()[0], num_bytes);
if (this->log_verbose_) {
LOG(INFO) << "worker " << get_rank() << " push lens: " << pskv.lens << " keys: "
<< pskv.keys << " size: " << size;
}
ps::SArray<char> vals(data, size * num_bytes, false);
const int cmd = GetCommandType(RequestType::kRowSparsePushPull, send_buf.dtype());
CHECK_NOTNULL(ps_worker_)->ZPush(pskv.keys, vals, pskv.lens, cmd, [cb]() { cb(); });
};
Engine::Get()->PushAsync(
push_to_servers,
pinned_ctx_,
{send_buf.var()},
{},
FnProperty::kNormal,
priority,
"KVStoreDistRowSparsePush");
}
// pull row sparse weight into `recv_buf` based on indices given by `indices`
void PullRowSparse_(const int key, const NDArray& recv_buf,
const NDArray& indices, int priority) {
using namespace rowsparse;
auto pull_from_servers = [this, key, recv_buf, indices]
(RunContext rctx, Engine::CallbackOnComplete cb) {
// allocate memory for the buffer
CHECK_EQ(indices.dtype(), mshadow::kInt64);
const TBlob idx_data = indices.data();
const size_t num_rows = idx_data.shape_.Size();
recv_buf.CheckAndAlloc({mshadow::Shape1(num_rows)});
const int dtype = recv_buf.dtype();
char* data = static_cast<char *>(recv_buf.data().dptr_);
const auto offsets = idx_data.dptr<int64_t>();
const auto unit_len = recv_buf.shape().ProdShape(1, recv_buf.shape().ndim());
const int64_t size = num_rows * unit_len;
const int num_bytes = mshadow::mshadow_sizeof(dtype);
// convert to ps keys in row sparse format
PSKV& pskv = EncodeRowSparseKey(key, size, num_rows, offsets,
unit_len, recv_buf.shape()[0],
num_bytes);
if (this->log_verbose_) {
LOG(INFO) << "worker " << get_rank() << " pull lens: " << pskv.lens << " keys: "
<< pskv.keys << " size: " << size;
}
auto vals = new ps::SArray<char>(data, size * num_bytes, false);
const int cmd = GetCommandType(RequestType::kRowSparsePushPull, recv_buf.dtype());
// copy indices to recv_buf. this needs to be done before ZPull
// because after pull is done, the callback function returns and locks are released.
// at this point, later functions may access the indices variable while copy happens
mshadow::Copy(recv_buf.aux_data(kIdx).FlatTo1D<cpu, int64_t>(),
idx_data.FlatTo1D<cpu, int64_t>());
CHECK_NOTNULL(ps_worker_)->ZPull(pskv.keys, vals, &pskv.lens,
cmd,
[vals, cb]() { delete vals; cb(); });
};
CHECK_NOTNULL(Engine::Get())->PushAsync(
pull_from_servers,
pinned_ctx_,
{indices.var()},
{recv_buf.var()},
FnProperty::kNormal,
priority,
"KVStoreDistRowSparsePull");
}
/**
* \brief check if the keys are all unique
*/
void CheckUnique(const std::vector<int>& keys) {
auto keys_copy = keys;
// sort first: std::unique only removes *consecutive* duplicates
std::sort(keys_copy.begin(), keys_copy.end());
auto last = std::unique(keys_copy.begin(), keys_copy.end());
CHECK_EQ(static_cast<size_t>(std::distance(keys_copy.begin(), last)),
static_cast<size_t>(keys.size()));
}
/**
* \brief convert to pskv for parameter server
* \param key
* \param num_arr_elems number of elements in the value for key
* \param num_bytes size of each element in number of bytes
* \return PSKV used for both push and pull
*/
inline PSKV& EncodeDefaultKey(const int key, const size_t num_arr_elems,
const int num_bytes) {
mu_.lock();
PSKV& pskv = ps_kv_[key];
mu_.unlock();
size_t pskv_size = num_arr_elems * num_bytes;
if (!pskv.keys.empty()) {
CHECK_EQ(static_cast<size_t>(pskv.size), pskv_size)
<< "The value size cannot be changed to " << pskv_size << " for key " << key;
} else {
auto krs = ps::Postoffice::Get()->GetServerKeyRanges();
const int num_servers = krs.size();
CHECK_GT(num_servers, 0);
// a simple heuristic for load balance
if (num_arr_elems < bigarray_bound_) {
// send it to a single randomly picked server
int server = (key * 9973) % num_servers;
ps::Key ps_key = krs[server].begin() + key;
CHECK_LT(ps_key, krs[server].end());
pskv.keys.push_back(ps_key);
const int total_bytes = num_arr_elems * num_bytes;
pskv.lens.push_back(total_bytes);
pskv.size = total_bytes;
} else {
// partition it across all servers
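// each server gets round(n/s*(i+1)) - round(n/s*i) elements; e.g. 10 elements
// over 3 servers splits as 3 / 4 / 3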
pskv.size = 0;
for (int i = 0; i < num_servers; ++i) {
size_t part_size =
static_cast<size_t>(round(static_cast<double>(num_arr_elems)/num_servers*(i+1))) -
static_cast<size_t>(round(static_cast<double>(num_arr_elems)/num_servers*i));
ps::Key ps_key = krs[i].begin() + key;
CHECK_LT(ps_key, krs[i].end());
pskv.keys.push_back(ps_key);
const int total_bytes = part_size * num_bytes;
pskv.lens.push_back(total_bytes);
pskv.size += total_bytes;
}
}
CHECK_EQ(static_cast<size_t>(pskv.size), pskv_size);
}
return pskv;
}
/**
* \brief Convert to PSKV for pushes and pulls when gradient compression is used.
* Divides original array into equal parts for each server.
* Populates both push and pull pskv on first call.
* \param key
* \param original_num_elem number of elements in the uncompressed value for key
* \param is_push whether this is push or pull
* \param num_bytes size of each element in number of bytes
* \return PSKV for the push if is_push is true, otherwise PSKV for the pull
*/
inline PSKV& EncodeCompressedKey(const int key, const size_t original_num_elem,
const bool is_push, const int num_bytes) {
auto krs = ps::Postoffice::Get()->GetServerKeyRanges();
const int num_servers = krs.size();
CHECK_GT(num_servers, 0);
// represents size of data to be sent
size_t compr_num_elem = gradient_compression_->GetCompressedSize(original_num_elem);
mu_.lock();
PSKV& pskv = (is_push) ? compr_ps_kv_[key].push : compr_ps_kv_[key].pull;
mu_.unlock();
if (!pskv.keys.empty()) {
const size_t num_elem = (is_push) ? compr_num_elem : original_num_elem;
CHECK_EQ(static_cast<size_t>(pskv.size), num_elem * num_bytes)
<< "The value size can't be changed for key " << key;
} else {
// populate both pull and push pskvs
// push pskv has sizes corresponding to compressed data
// pull pskv has decompressed sizes for parts in push_pskv
mu_.lock();
PSKV& pull_pskv = compr_ps_kv_[key].pull;
PSKV& push_pskv = compr_ps_kv_[key].push;
mu_.unlock();
if (original_num_elem < bigarray_bound_) {
// a simple heuristic for load balancing
// send it to a single randomly picked server
const int server = (key * 9973) % num_servers;
ps::Key ps_key = krs[server].begin() + key;
CHECK_LT(ps_key, krs[server].end());
// meta info
push_pskv.keys.push_back(krs[server].begin() + original_num_elem);
push_pskv.lens.push_back(0);
// data
push_pskv.keys.push_back(ps_key);
pull_pskv.keys.push_back(ps_key);
const int compr_size = compr_num_elem * num_bytes;
const int original_size = original_num_elem * num_bytes;
push_pskv.lens.push_back(compr_size);
pull_pskv.lens.push_back(original_size);
push_pskv.size = compr_size;
pull_pskv.size = original_size;
} else {
// partition it across all servers
push_pskv.size = 0;
pull_pskv.size = 0;
for (int i = 0; i < num_servers; ++i) {
size_t part_compr, part_orig;
if (i == num_servers-1) {
part_compr = compr_num_elem - push_pskv.size;
part_orig = original_num_elem - pull_pskv.size;
} else {
part_compr =
static_cast<size_t> (round(static_cast<double>(compr_num_elem)/num_servers*(i+1))) -
static_cast<size_t> (round(static_cast<double>(compr_num_elem)/num_servers*(i)));
part_orig = part_compr * gradient_compression_->GetCompressionFactor();
}
// meta info
ps::Key ps_key_dummy = krs[i].begin() + part_orig;
CHECK_LT(ps_key_dummy, krs[i].end());
push_pskv.keys.push_back(ps_key_dummy);
push_pskv.lens.push_back(0);
// data
ps::Key ps_key = krs[i].begin() + key;
CHECK_LT(ps_key, krs[i].end());
push_pskv.keys.push_back(ps_key);
pull_pskv.keys.push_back(ps_key);
push_pskv.lens.push_back(part_compr * num_bytes);
pull_pskv.lens.push_back(part_orig * num_bytes);
// sizes are accumulated as element counts so that the last server absorbs
// any round-off error; they are converted to bytes after the loop
push_pskv.size += part_compr;
pull_pskv.size += part_orig;
}
CHECK_EQ(static_cast<size_t>(push_pskv.size), compr_num_elem);
CHECK_EQ(static_cast<size_t>(pull_pskv.size), original_num_elem);
push_pskv.size *= num_bytes;
pull_pskv.size *= num_bytes;
CHECK_EQ(push_pskv.lens.size(), num_servers * 2);
}
}
return pskv;
}
// Note: this encoding method for row sparse keys doesn't allow cross-layer batching
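// Each server receives a zero-length "master" key (krs[i].begin() + key),
// followed by one key per stored row (the row offset added to the base key),
// each carrying unit_len * num_bytes bytes.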
inline PSKV& EncodeRowSparseKey(const int key, const int64_t num_elem, const int64_t num_rows,
const int64_t *offsets, const size_t unit_len,
const int64_t total_num_rows, const int num_bytes) {
using namespace common;
mu_.lock();
PSKV& pskv = ps_kv_[key];
mu_.unlock();
pskv.keys.clear();
pskv.lens.clear();
// TODO(haibin) cache this information
auto krs = ps::Postoffice::Get()->GetServerKeyRanges();
const int num_servers = krs.size();
CHECK_GT(num_servers, 0);
if (total_num_rows * unit_len >= bigarray_bound_) {
pskv.size = 0;
int64_t start_row = 0;
// partition it across all servers
for (int i = 0; i < num_servers; ++i) {
ps::Key master_key = krs[i].begin() + key;
pskv.keys.push_back(master_key);
pskv.lens.push_back(0);
if (offsets && num_elem > 0) {
// calculate partition ranges
int64_t part_num_rows =
llround(static_cast<double>(total_num_rows) / num_servers * (i + 1)) -
llround(static_cast<double>(total_num_rows) / num_servers * i);
auto end_row = start_row + part_num_rows;
// search for offsets in [start_row, end_row)
auto lb = std::lower_bound(offsets, offsets + num_rows, start_row);
auto ub = std::upper_bound(offsets, offsets + num_rows, end_row - 1);
for (auto offset = lb; offset < ub; offset++) {
ps::Key ps_key = krs[i].begin() + key + (*offset - start_row);
CHECK_LT(ps_key, krs[i].end());
pskv.keys.push_back(ps_key);
const int part_size = unit_len * num_bytes;
pskv.lens.push_back(part_size);
pskv.size += (part_size);
}
start_row = end_row;
}
}
CHECK_EQ(static_cast<size_t>(pskv.size), num_elem * num_bytes);
} else {
// send it to a single randomly picked server
const int server = (key * 9973) % num_servers;
ps::Key master_key = krs[server].begin() + key;
pskv.keys.push_back(master_key);
pskv.lens.push_back(0);
for (int64_t i = 0; i < num_rows; i++) {
ps::Key ps_key = krs[server].begin() + key + offsets[i];
CHECK_LT(ps_key, krs[server].end());
pskv.keys.push_back(ps_key);
pskv.lens.push_back(unit_len * num_bytes);
}
pskv.size = num_elem * num_bytes;
}
return pskv;
}
/**
* \brief for worker to push and pull data
*/
ps::KVWorker<char>* ps_worker_;
/**
* \brief the server handle
*/
KVStoreDistServer* server_;
/**
* \brief threshold for partition
*/
size_t bigarray_bound_;
/**
* \brief buffer for non-compressed data.
* When gradient compression is active, this holds the pulled data
* and the original (uncompressed) data during push
*/
std::unordered_map<int, NDArray> comm_buf_;
/**
* \brief buffer for compressed data
* Used when gradient compression is active and action
* is push
*/
std::unordered_map<int, NDArray> compr_buf_;
/**
* \brief residual buffer to accumulate quantization error
* during gradient compression
*/
std::unordered_map<int, NDArray> residual_;
bool log_verbose_;
};
} // namespace kvstore
} // namespace mxnet
#endif // MXNET_KVSTORE_KVSTORE_DIST_H_