/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef SINGA_MODEL_UPDATER_H_
#define SINGA_MODEL_UPDATER_H_
#include "singa/model/optimizer.h"
#include "singa/core/device.h"
#include "singa/core/tensor.h"
#include "singa/utils/logging.h"
#include <memory>
#include <vector>
#include <mutex>
#include <condition_variable>
#include <string>
#include <utility>
#include <unordered_map>
#include <atomic>
namespace singa {
/// Basic Updater class that simply forwards all method calls
/// to the wrapped Optimizer.
class Updater {
public:
explicit Updater(Optimizer* opt) : opt_{opt} {}
virtual ~Updater() {}
/// Forward Setup() to Optimizer.
virtual void Setup(const OptimizerConf& conf);
/// Forward Register() to Optimizer.
virtual void Register(const string& name, const ParamSpec& specs);
/// Forward Apply() to Optimizer.
virtual void Apply(int step, const string& name, Tensor& grad, Tensor& value);
Optimizer* GetOptimizer() { return opt_; }
// No copy allowed.
Updater(const Updater&) = delete;
Updater& operator=(const Updater&) = delete;
protected:
Optimizer* opt_;
};
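// A minimal usage sketch (hypothetical; assumes the SGD optimizer declared in
// singa/model/optimizer.h and float tensors of matching shape):
//
//   SGD sgd;                      // concrete Optimizer
//   OptimizerConf conf;
//   Updater updater(&sgd);
//   updater.Setup(conf);          // forwarded to sgd
//   ParamSpec spec;
//   updater.Register("w", spec);  // forwarded to sgd
//   Tensor grad(Shape{2, 3}), value(Shape{2, 3});
//   updater.Apply(0, "w", grad, value);  // sgd updates value in place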
/// LocalUpdater aggregates partial gradients and updates the parameters by
/// calling the wrapped Optimizer on a specific device (i.e., CPU or GPU).
class LocalUpdater : public Updater {
public:
LocalUpdater(int total_num, Optimizer* opt,
std::shared_ptr<Device> dev = defaultDevice)
: Updater(opt), total_num_{total_num}, dev_(dev) {}
~LocalUpdater() override {}
/// Forward Register() to Optimizer.
virtual void Register(const string& name, const ParamSpec& specs) override;
/// Update the parameter value based on the given gradient by invoking the
/// optimizer algorithm. In synchronous training, a call from the training
/// net blocks until all partial gradients for the parameter are aggregated.
virtual void Apply(int step, const string& name, Tensor& grad,
Tensor& value) override;
private:
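/// Hash functor for the (slot index, parameter name) keys of grad_buffer_.
/// XOR-combining the element hashes is adequate here since the two pair
/// components have different types and distributions.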
template <typename T1, typename T2>
struct key_hasher {
size_t operator()(const std::pair<T1, T2>& p) const {
auto h1 = std::hash<T1>{}(p.first);
auto h2 = std::hash<T2>{}(p.second);
return h1 ^ h2;
}
};
int total_num_;  // number of partial gradients to aggregate per parameter
std::shared_ptr<Device> dev_;  // device on which aggregation and update run
/// Per-parameter atomic counter; assigns each incoming partial gradient a
/// slot in grad_buffer_.
std::unordered_map<std::string, std::atomic<int>> dev_index_;
/// Per-parameter count of partial gradients submitted in the current round.
std::unordered_map<std::string, int> to_updater_finished_;
/// Buffers partial gradients, keyed by (slot index, parameter name).
std::unordered_map<std::pair<int, std::string>, Tensor,
key_hasher<int, std::string>> grad_buffer_;
/// Per-parameter aggregated gradient (sum_) and parameter replica on dev_
/// (param_buffer_).
std::unordered_map<std::string, Tensor> sum_, param_buffer_;
std::unordered_map<std::string, std::mutex> mtx_;  // per-parameter lock
/// Wakes blocked Apply() callers once all partial gradients are aggregated.
std::unordered_map<std::string, std::condition_variable>
to_updater_all_finished_;
};
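// A minimal multi-threaded sketch (hypothetical names g0/v0/g1/v1; assumes
// the SGD optimizer from singa/model/optimizer.h and <thread>). Each worker
// submits its partial gradient via Apply(), which blocks until all
// total_num partial gradients for the parameter have arrived:
//
//   SGD sgd;
//   LocalUpdater updater(2, &sgd);  // expect 2 partial gradients per step
//   ParamSpec spec;
//   updater.Register("w", spec);
//   auto worker = [&](Tensor g, Tensor v) { updater.Apply(0, "w", g, v); };
//   std::thread t0(worker, g0, v0), t1(worker, g1, v1);
//   t0.join();
//   t1.join();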
} // namespace singa
#endif // SINGA_MODEL_UPDATER_H_