blob: 5f668c19a388768a435c76e8c89005a7569114b7 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2015 by Contributors
* \file kvstore.h
* \brief Rcpp Parameter Store interface of MXNet
*/
#ifndef MXNET_RCPP_KVSTORE_H_
#define MXNET_RCPP_KVSTORE_H_
#include <Rcpp.h>
#include <mxnet/c_api.h>
#include <string>
#include <vector>
#include <map>
#include "./base.h"
namespace mxnet {
namespace R {
/*!
 * \brief MXNet's Parameter store interface, exposed to R through Rcpp.
 *
 * Owns an MXNet KVStoreHandle and releases it with MXKVStoreFree when the
 * object is destroyed. Instances are obtained through KVStore::Create (the
 * constructor is private). Copying is disabled: two copies would both free
 * the same handle in their destructors.
 */
class KVStore {
 public:
  /*!
   * \brief Initialize all the weights.
   * \param keys The keys of each weight.
   * \param weights The list of weight NDArrays.
   */
  void Init(const std::vector<int>& keys, const Rcpp::List& weights);
  /*!
   * \brief Push the weights to the KVStore.
   *
   * This operation first aggregates weight_lists, then pushes the result out:
   *
   *   sum_list[i] = sum(list[i] for list in weight_lists)
   *   push(keys[i], sum_list[i]) for each i.
   *
   * \param keys List of keys, one per location.
   * \param weight_lists List of Rcpp::List.
   * \param priority The priority of each key.
   */
  void Push(const std::vector<int>& keys,
            const Rcpp::List& weight_lists,
            const std::vector<int>& priority);
  /*!
   * \brief Pull the data back.
   *
   * This operation will MUTATE the content of out_lists: the pulled values
   * are written into it, and nothing is returned.
   *
   * \param keys List of keys, one per location.
   * \param out_lists List of Rcpp::List that receives the pulled values.
   * \param priority The priority of each key.
   */
  void Pull(const std::vector<int>& keys,
            const Rcpp::List& out_lists,
            const std::vector<int>& priority);
  /*! \return The type of KVStore. */
  std::string type() const;
  /*! \brief Whether to perform the update on the KVStore. */
  bool update_on_kvstore() const;
  /*!
   * \brief Set up the optimizer used for updates.
   * \param optimizer The optimizer object passed from R
   *        (NOTE(review): expected fields are defined in kvstore.cc — confirm).
   */
  void SetOptimizer(const Rcpp::List& optimizer);
  /*!
   * \brief Update function applying a gradient to a weight.
   * \param index Index of the weight being updated (presumably its key).
   * \param grad The gradient.
   * \param weight The weight; passed as a non-const pointer, presumably so
   *        it can be updated in place.
   */
  void Update(int index, const NDArray& grad, NDArray *weight);
  /*!
   * \brief Create a KVStore.
   * \param type The name of the KVStore type to create.
   * \return The created KVStore, wrapped as an R object.
   */
  static Rcpp::RObject Create(const char *type);
  /*! \brief Initialize the Rcpp module. */
  static void InitRcppModule();
  /*!
   * \brief Release the underlying MXNet KVStore handle.
   *
   * Destructors are implicitly noexcept, so an exception escaping from here
   * would call std::terminate. The handle is therefore freed without
   * MX_CALL (which reports failures by throwing — NOTE(review): confirm
   * against base.h), and the status code is deliberately ignored: there is
   * no safe way to report it from a destructor.
   */
  ~KVStore() {
    static_cast<void>(MXKVStoreFree(handle_));
  }
  // Non-copyable: this object uniquely owns handle_, and a copy would free
  // the same handle a second time.
  KVStore(const KVStore&) = delete;
  KVStore& operator=(const KVStore&) = delete;

 private:
  /*!
   * \brief Take ownership of an already-created handle.
   * \param handle The MXNet KVStore handle; freed in the destructor.
   */
  explicit KVStore(KVStoreHandle handle)
      : handle_(handle), optimizer_set_(false) {}
  /*!
   * \brief Internal callback to the KVStore that creates the state for a weight.
   * \param index Index of the weight.
   * \param weight The weight the state is created for.
   * \return The state; this might be NULL (presumably an R NULL value —
   *         confirm in kvstore.cc).
   */
  Rcpp::List CreateState(int index, const NDArray& weight) const;
  /*! \brief Internal KVStore handle, owned by this object. */
  KVStoreHandle handle_;
  /*! \brief Whether the optimizer has been set. */
  bool optimizer_set_;
  /*! \brief The internal state, keyed by index. */
  std::map<int, Rcpp::List> states_;
  /*! \brief Function to create state. */
  Rcpp::RObject fcreate_state_;
  /*! \brief Function to perform the update. */
  Rcpp::RObject fupdate_;
};
} // namespace R
} // namespace mxnet
#endif // MXNET_RCPP_KVSTORE_H_