blob: fb5a33a8c72aa187b2d70c18aca420cca732927b [file] [log] [blame]
/************************************************************
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*************************************************************/
#ifndef SINGA_SINGA_DRIVER_H_
#define SINGA_SINGA_DRIVER_H_
#include <vector>
#include "singa/proto/job.pb.h"
#include "singa/proto/singa.pb.h"
#include "singa/utils/factory.h"
#include "singa/utils/param.h"
#include "singa/utils/singleton.h"
#include "singa/utils/updater.h"
#include "singa/neuralnet/layer.h"
#include "singa/worker.h"
#include "singa/server.h"
namespace singa {
using std::vector;
class Driver {
public:
/**
* Init SINGA
* - init glog
* - parse job id and job conf from cmd line
* - register built-in layer, worker, updater, param subclasses.
*
* May be used for MPI init if it is used for message passing.
*/
void Init(int argc, char** argv);
/**
* Init SINGA LOG
* Used for python binding. Users can also directly call it as a C++ API.
* - init glog with given parameters
*
*/
void InitLog(char *arg);
/**
* Update job configuration and call Train(const JobProto&) to start the
* training.
*
* It sets up the logging path and checkpoing files (if resume), and checks
* the existence of the workspace folder .
*
* @param[in] resume if true resume the training from the latest checkpoint
* files.
* @param[in] job_conf job configuration.
*/
void Train(bool resume, const JobProto& job_conf);
/**
* Used for python binding. Users can also directly call it as a C++ API.
*
* It completes the functions as defined above but accept serialized string
* parameters.
*
* @param[in] resume if true resume the training from the latest checkpoint
* files.
* @param[in] str serialized string recorded job configuration.
*/
void Train(bool resume, const std::string str);
/**
* Create workers and servers to conduct the training.
*
* @param[in] job_conf job configuration with all necessary fields set (e.g.,
* by Train(bool, const JobProto&).
*/
void Train(const JobProto& job_conf);
/**
* Test the pre-trained model by loading parameters from checkpoint files.
*
* It can be used for both computing accuracy of test data, and extracting
* features (predicting label) of new data.
* @param[in] job_conf job configuration, which should include the checkpoint
* files and test settings (e.g., test steps). To extract features, the output
* layers should be added.
*/
void Test(const JobProto& job_conf);
/**
* Used for python binding. Users can also directly call it as a C++ API.
*
* It completes the functions as defined above but accept serialized string
* parameters.
*
* @param[in] str serialized string recorded job configuration.
*/
void Test(const std::string str);
/**
* Setting the checkpoint field of the job configuration to resume training.
*
* The checkpoint folder will be searched to get the files for the latest
* checkpoint, which will be added into the checkpoint field. The workers
* would then load the values of params from the checkpoint files.
*
* @param job_conf job configuration
*/
void SetupForResume(JobProto* job_conf);
/**
* Create server instances.
*
* @param[in] job_conf job configuration.
* @param[in] net training neural network.
* @return server instances
*/
const vector<Server*> CreateServers(const JobProto& job_conf, NeuralNet* net);
/**
* Create workers instances.
* @param[in] job_conf job configuration.
* @param[in] net training neural network.
* @return worker instances
*/
const vector<Worker*> CreateWorkers(const JobProto& job_conf, NeuralNet* net);
/*********** Subclasses registers *************************/
/**
* Register a Layer subclass.
*
* @param type layer type ID. If called to register built-in subclasses,
* it is from LayerType; if called to register user-defined
* subclass, it is a string;
* @return 0 if success; otherwise -1.
*/
template<typename Subclass, typename Type>
int RegisterLayer(const Type& type);
/**
* Register an Updater subclass.
*
* @param type ID of the subclass. If called to register built-in subclasses,
* it is from UpdaterType; if called to register user-defined
* subclass, it is a string;
* @return 0 if success; otherwise -1.
*/
template<typename Subclass, typename Type>
int RegisterUpdater(const Type& type);
/**
* Register a learning rate generator subclasses.
*
* @param type ID of the subclass. If called to register built-in subclasses,
* it is from ChangeMethod; if called to register user-defined
* subclass, it is a string;
* @return 0 if success; otherwise -1.
*/
template<typename Subclass, typename Type>
int RegisterLRGenerator(const Type& type);
/**
* Register a Worker subclass.
*
* @param type ID of the subclass. If called to register built-in subclasses,
* it is from TrainOneBatchAlg; if called to register user-defined
* subclass, it is a string;
* @return 0 if success; otherwise -1.
*/
template<typename Subclass, typename Type>
int RegisterWorker(const Type& type);
/**
* Register a Param subclass.
* @param type ID of the subclass. If called to register built-in subclasses,
* it is from ParamType; if called to register user-defined
* subclass, it is a string;
*
* @return 0 if success; otherwise -1.
*/
template<typename Subclass, typename Type>
int RegisterParam(const Type& type);
/**
* Register ParamGenerator subclasses for initalizing Param objects.
*
* @param type ID of the subclass. If called to register built-in subclasses,
* it is from InitMethod; if called to register user-defined
* subclass, it is a string;
* @return 0 if success; otherwise -1.
*/
template<typename Subclass, typename Type>
int RegisterParamGenerator(const Type& type);
/****************** Access function ********************/
/**
* @return job ID which is generated by zookeeper and passed in by the
* launching script.
*/
inline int job_id() const { return job_id_; }
/**
* @return job conf path which is passed by users at the command line. It
* should at least contains the cluster configuration.
*/
inline JobProto job_conf() const { return job_conf_; }
private:
int job_id_;
JobProto job_conf_;
SingaProto singa_conf_;
};
/************* Implementation of template functions*************************
* Must put the implementation in driver.h file instead of driver.cc.
* Otherwise there would be linking error caused by unknown registration
* functions, becuase these function cannot be generated merely based on its
* declearation in driver.h.
*/
template<typename Subclass, typename Type>
int Driver::RegisterLayer(const Type& type) {
auto factory = Singleton<Factory<singa::Layer>>::Instance();
factory->Register(type, CreateInstance(Subclass, Layer));
return 1;
}
template<typename Subclass, typename Type>
int Driver::RegisterParam(const Type& type) {
auto factory = Singleton<Factory<singa::Param>>::Instance();
factory->Register(type, CreateInstance(Subclass, Param));
return 1;
}
template<typename Subclass, typename Type>
int Driver::RegisterParamGenerator(const Type& type) {
auto factory = Singleton<Factory<singa::ParamGenerator>>::Instance();
factory->Register(type, CreateInstance(Subclass, ParamGenerator));
return 1;
}
template<typename Subclass, typename Type>
int Driver::RegisterUpdater(const Type& type) {
auto factory = Singleton<Factory<singa::Updater>>::Instance();
factory->Register(type, CreateInstance(Subclass, Updater));
return 1;
}
template<typename Subclass, typename Type>
int Driver::RegisterLRGenerator(const Type& type) {
auto factory = Singleton<Factory<singa::LRGenerator>>::Instance();
factory->Register(type, CreateInstance(Subclass, LRGenerator));
return 1;
}
template<typename Subclass, typename Type>
int Driver::RegisterWorker(const Type& type) {
auto factory = Singleton<Factory<singa::Worker>>::Instance();
factory->Register(type, CreateInstance(Subclass, Worker));
return 1;
}
} // namespace singa
#endif // SINGA_SINGA_DRIVER_H_