blob: dc5f7b78624452087fa408be4ebe4a24e57eb8f5 [file] [log] [blame]
/**
* Copyright (c) 2015 by Contributors
* @file kvstore_local.h
* @brief local (single-machine) implementation of the KVStore interface
*/
#ifndef MXNET_KVSTORE_KVSTORE_LOCAL_H_
#define MXNET_KVSTORE_KVSTORE_LOCAL_H_
#include <mxnet/kvstore.h>
#include <unordered_map>
#include <bitset>
#include <vector>
#include <string>
#include <utility>
#include <algorithm>
#include "./comm.h"
namespace mxnet {
namespace kvstore {
/**
 * \brief KVStore implementation that keeps all values on the local machine.
 *
 * Values pushed for a key are reduced through a Comm instance and then either
 * passed to the user-supplied updater or stored directly. Pulls broadcast the
 * stored value back into the caller's output arrays.
 */
class KVStoreLocal : public KVStore {
 public:
  /**
   * \brief construct a local kvstore
   * \param use_device_comm if true, reduce/broadcast through a CommDevice;
   *        otherwise through a CommCPU.
   */
  explicit KVStoreLocal(bool use_device_comm) : KVStore() {
    if (use_device_comm) {
      comm_ = new CommDevice();
    } else {
      comm_ = new CommCPU();
    }
    pinned_ctx_ = comm_->pinned_ctx();
  }

  // comm_ is an owning raw pointer released in the destructor; a member-wise
  // copy would delete it twice, so copying is disabled.
  KVStoreLocal(const KVStoreLocal&) = delete;
  KVStoreLocal& operator=(const KVStoreLocal&) = delete;

  virtual ~KVStoreLocal() {
    delete comm_;
  }

  /**
   * \brief initialize integer keys with their starting values
   *
   * Each value is copied to pinned_ctx_ and the key is registered with the
   * communicator. Initializing a key that already exists is an error.
   * \param keys the keys to create
   * \param values one initial value per key (same length as keys)
   */
  void Init(const std::vector<int>& keys,
            const std::vector<NDArray>& values) override {
    CHECK_EQ(keys.size(), values.size());
    for (size_t i = 0; i < keys.size(); ++i) {
      CHECK(local_.find(keys[i]) == local_.end())
          << "duplicate init of key " << keys[i];
      local_[keys[i]] = values[i].Copy(pinned_ctx_);
      comm_->Init(keys[i], values[i].shape(), values[i].dtype());
    }
  }

  /**
   * \brief initialize string keys with their starting values
   *
   * Every new string key is assigned the next free integer id, and the
   * integer overload of Init is then used to store the values.
   * \param str_keys the keys to create
   * \param values one initial value per key (same length as str_keys)
   */
  void Init(const std::vector<std::string>& str_keys,
            const std::vector<NDArray>& values) override {
    // Validate before touching str_key_dict_ so a size mismatch does not leave
    // half-registered string keys behind.
    CHECK_EQ(str_keys.size(), values.size());
    std::vector<int> keys(str_keys.size());
    for (size_t i = 0; i < str_keys.size(); ++i) {
      const std::string& str_key = str_keys[i];
      CHECK(str_key_dict_.find(str_key) == str_key_dict_.end())
          << "duplicate init of key " << str_key;
      int key = next_str_key_++;
      str_key_dict_[str_key] = key;
      keys[i] = key;
    }
    Init(keys, values);
  }

  /**
   * \brief push values for integer keys
   *
   * Values that share a key are reduced together. With an updater set, the
   * updater is applied to the stored value (which must already be inited);
   * without one, the reduced value replaces the stored value.
   */
  void Push(const std::vector<int>& keys,
            const std::vector<NDArray>& values,
            int priority) override {
    std::vector<int> uniq_keys;
    std::vector<std::vector<NDArray> > grouped_vals;
    GroupKVPairs(keys, values, &uniq_keys, &grouped_vals);
    for (size_t i = 0; i < uniq_keys.size(); ++i) {
      int key = uniq_keys[i];
      const NDArray& merged = comm_->Reduce(key, grouped_vals[i], priority);
      if (updater_ == nullptr) {
        local_[key] = merged;
        continue;
      }
      // Use find() rather than operator[]: operator[] would insert an empty
      // placeholder before the check fires, and a later Init of this key
      // would then fail with "duplicate init".
      auto it = local_.find(key);
      CHECK(it != local_.end() && !it->second.is_none())
          << "key " << key << " has not been inited";
      NDArray& local = it->second;
      // if merged is on gpu, we may need copy weight from cpu to gpu
      if (merged.ctx().dev_mask() != cpu::kDevMask &&
          local.ctx().dev_mask() == cpu::kDevMask) {
        local = local.Copy(merged.ctx());
      }
      updater_(key, merged, &local);
    }
  }

  /**
   * \brief pull the stored values for integer keys into the output arrays
   *
   * All outputs that request the same key receive one broadcast of it.
   * Every key must have been inited.
   */
  void Pull(const std::vector<int>& keys,
            const std::vector<NDArray*>& values,
            int priority) override {
    std::vector<int> uniq_keys;
    std::vector<std::vector<NDArray*> > grouped_vals;
    GroupKVPairs(keys, values, &uniq_keys, &grouped_vals);
    for (size_t i = 0; i < uniq_keys.size(); ++i) {
      int key = uniq_keys[i];
      // find() so that pulling an unknown key does not insert a placeholder.
      auto it = local_.find(key);
      CHECK(it != local_.end() && !it->second.is_none())
          << "key " << key << " has not been inited";
      const NDArray& local = it->second;
      comm_->Broadcast(key, local, grouped_vals[i], priority);
    }
  }

  /** \brief push values for string keys (translated via LookupKeys) */
  void Push(const std::vector<std::string>& str_keys,
            const std::vector<NDArray>& values,
            int priority) override {
    std::vector<int> keys(str_keys.size());
    LookupKeys(str_keys, &keys);
    Push(keys, values, priority);
  }

  /** \brief pull values for string keys (translated via LookupKeys) */
  void Pull(const std::vector<std::string>& str_keys,
            const std::vector<NDArray*>& values,
            int priority) override {
    std::vector<int> keys(str_keys.size());
    LookupKeys(str_keys, &keys);
    Pull(keys, values, priority);
  }

 protected:
  /**
   * \brief group values on keys
   *
   * Appends each distinct key once to uniq_keys, in ascending order, and
   * appends to grouped_vals the values for that key, in the order they
   * appeared in the input. Empty input produces no output.
   * \param keys the keys, one per value
   * \param values the values to group
   * \param uniq_keys output: the distinct keys, sorted
   * \param grouped_vals output: grouped_vals[j] holds the values of uniq_keys[j]
   */
  template <typename V>
  void GroupKVPairs(const std::vector<int>& keys,
                    const std::vector<V>& values,
                    std::vector<int>* uniq_keys,
                    std::vector<std::vector<V> >* grouped_vals) {
    CHECK_EQ(keys.size(), values.size());
    // TODO(mli) check if already sorted as an optimization
    using Idx = std::pair<int, size_t>;
    std::vector<Idx> idx(keys.size());
    for (size_t i = 0; i < keys.size(); ++i) {
      idx[i].first = keys[i];
      idx[i].second = i;
    }
    // Sorting on (key, original position) puts equal keys next to each other
    // and keeps each group in input order, so grouping is deterministic.
    std::sort(idx.begin(), idx.end());
    // Compare against the previous element instead of a sentinel key: this is
    // safe for empty input and for keys equal to INT_MIN.
    for (size_t j = 0; j < idx.size(); ++j) {
      const V& val = values[idx[j].second];
      if (j == 0 || idx[j].first != idx[j - 1].first) {
        uniq_keys->push_back(idx[j].first);
        grouped_vals->push_back({val});
      } else {
        grouped_vals->back().push_back(val);
      }
    }
  }

  /**
   * \brief translate string keys into their registered integer ids
   * \param str_keys keys previously registered by the string overload of Init
   * \param keys output; must already hold at least str_keys.size() elements
   */
  void LookupKeys(const std::vector<std::string>& str_keys,
                  std::vector<int> *keys) {
    for (size_t i = 0; i < str_keys.size(); ++i) {
      const std::string& str_key = str_keys[i];
      auto it = str_key_dict_.find(str_key);
      CHECK(it != str_key_dict_.end())
          << "key " << str_key << " doesn't exist. Did you init?";
      keys->at(i) = it->second;
    }
  }

  /// reducer and broadcaster (owned; deleted in the destructor)
  Comm* comm_;
  /// context that initial values are copied to, obtained from comm_
  Context pinned_ctx_;
  /// \brief buffer for storing local values
  std::unordered_map<int, NDArray> local_;
  /// key mapping for string -> integer
  std::unordered_map<std::string, int> str_key_dict_;
  /// the next available integer for string->int key mapping
  // NOTE(review): string-mapped ids start at 0 and share local_ with integer
  // keys, so mixing both key kinds in one store can collide.
  int next_str_key_ = 0;
};
} // namespace kvstore
} // namespace mxnet
#endif // MXNET_KVSTORE_KVSTORE_LOCAL_H_