/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.horn.bsp;
import com.google.common.base.Preconditions;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hama.bsp.BSP;
import org.apache.hama.bsp.BSPPeer;
import org.apache.hama.bsp.sync.SyncException;
import org.apache.hama.commons.io.VectorWritable;
import org.apache.hama.commons.math.DenseDoubleMatrix;
import org.apache.hama.commons.math.DoubleMatrix;
import org.apache.hama.commons.math.DoubleVector;
import org.mortbay.log.Log;
import java.io.IOException;
import java.util.concurrent.atomic.AtomicBoolean;
/**
 * The trainer that trains the {@link SmallLayeredNeuralNetwork} based on the
 * BSP framework. Peer 0 acts as the master that merges and broadcasts the
 * weight updates; all other peers compute updates on their local data
 * partitions.
 */
public final class SmallLayeredNeuralNetworkTrainer
extends
BSP<LongWritable, VectorWritable, NullWritable, NullWritable, SmallLayeredNeuralNetworkMessage> {
private SmallLayeredNeuralNetwork inMemoryModel;
private Configuration conf;
  /* The mini-batch size used in each training iteration */
  private int batchSize;
  /* Average training error of the previous and current convergence-check windows */
  private double prevAvgTrainingError;
  private double curAvgTrainingError;
  /* Number of iterations between two convergence checks */
  private long convergenceCheckInterval;
private long iterations;
private long maxIterations;
private AtomicBoolean isConverge;
private String modelPath;
  /**
   * Returns true if this worker is the master worker.
   *
   * @param peer the BSP peer for this task
   */
private boolean isMaster(
BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, SmallLayeredNeuralNetworkMessage> peer) {
return peer.getPeerIndex() == 0;
}
  /**
   * If the model path is specified, load the existing model from the storage
   * location.
   */
  @Override
public void setup(
BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, SmallLayeredNeuralNetworkMessage> peer) {
    // At least one master and one slave worker must exist.
Preconditions.checkArgument(peer.getNumPeers() >= 2);
if (isMaster(peer)) {
Log.info("Begin to train");
}
this.isConverge = new AtomicBoolean(false);
this.conf = peer.getConfiguration();
this.iterations = 0;
this.modelPath = conf.get("modelPath");
this.maxIterations = conf.getLong("training.max.iterations", 100000);
this.convergenceCheckInterval = conf.getLong("convergence.check.interval",
2000);
this.inMemoryModel = new SmallLayeredNeuralNetwork(modelPath);
    this.prevAvgTrainingError = Double.MAX_VALUE;
this.batchSize = conf.getInt("training.batch.size", 50);
}
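  // A minimal sketch of how a driver might populate the configuration keys
  // consumed in setup(); the HDFS path below is hypothetical and the values
  // simply restate the defaults used above:
  //
  //   Configuration conf = new Configuration();
  //   conf.set("modelPath", "hdfs://host:9000/tmp/neuralnets.model"); // hypothetical path
  //   conf.setLong("training.max.iterations", 100000L);
  //   conf.setLong("convergence.check.interval", 2000L);
  //   conf.setInt("training.batch.size", 50);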
  /**
   * Write the trained model back to the storage location.
   */
  @Override
public void cleanup(
BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, SmallLayeredNeuralNetworkMessage> peer) {
// write model to modelPath
if (isMaster(peer)) {
try {
Log.info(String.format("End of training, number of iterations: %d.\n",
this.iterations));
Log.info(String.format("Write model back to %s\n",
inMemoryModel.getModelPath()));
this.inMemoryModel.writeModelToFile();
} catch (IOException e) {
e.printStackTrace();
}
}
}
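  // A minimal sketch of reloading the persisted model for prediction after
  // training; getOutput(DoubleVector) is assumed to be exposed by
  // SmallLayeredNeuralNetwork and is not shown in this class:
  //
  //   SmallLayeredNeuralNetwork model = new SmallLayeredNeuralNetwork(modelPath);
  //   DoubleVector prediction = model.getOutput(featureVector); // assumed API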
@Override
public void bsp(
BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, SmallLayeredNeuralNetworkMessage> peer)
throws IOException, SyncException, InterruptedException {
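    // Each iteration consists of two supersteps: the slave peers compute and
    // send their averaged mini-batch updates to the master (peer 0), then the
    // master merges them and broadcasts the updated weights back to all peers.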
while (this.iterations++ < maxIterations) {
      // each slave worker computes the weight matrix updates from its local data
if (!isMaster(peer)) {
calculateUpdates(peer);
}
peer.sync();
      // the master merges the updates and broadcasts the new model
if (isMaster(peer)) {
mergeUpdates(peer);
}
peer.sync();
if (this.isConverge.get()) {
break;
}
}
}
  /**
   * Calculate the weight matrix updates based on the local partition of the
   * data.
   *
   * @param peer the BSP peer for this task
   * @throws IOException
   */
private void calculateUpdates(
BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, SmallLayeredNeuralNetworkMessage> peer)
throws IOException {
// receive update information from master
if (peer.getNumCurrentMessages() != 0) {
SmallLayeredNeuralNetworkMessage inMessage = peer.getCurrentMessage();
DoubleMatrix[] newWeights = inMessage.getCurMatrices();
DoubleMatrix[] preWeightUpdates = inMessage.getPrevMatrices();
this.inMemoryModel.setWeightMatrices(newWeights);
this.inMemoryModel.setPrevWeightMatrices(preWeightUpdates);
this.isConverge.set(inMessage.isConverge());
      // stop local training once the model has converged
if (isConverge.get()) {
return;
}
}
DoubleMatrix[] weightUpdates = new DoubleMatrix[this.inMemoryModel.weightMatrixList
.size()];
for (int i = 0; i < weightUpdates.length; ++i) {
int row = this.inMemoryModel.weightMatrixList.get(i).getRowCount();
int col = this.inMemoryModel.weightMatrixList.get(i).getColumnCount();
weightUpdates[i] = new DenseDoubleMatrix(row, col);
}
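    // The zero-initialized matrices accumulate the per-instance weight updates
    // over one mini-batch; they are averaged below before being sent to the
    // master.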
// continue to train
double avgTrainingError = 0.0;
LongWritable key = new LongWritable();
VectorWritable value = new VectorWritable();
for (int recordsRead = 0; recordsRead < batchSize; ++recordsRead) {
if (!peer.readNext(key, value)) {
peer.reopenInput();
peer.readNext(key, value);
}
DoubleVector trainingInstance = value.getVector();
SmallLayeredNeuralNetwork.matricesAdd(weightUpdates,
this.inMemoryModel.trainByInstance(trainingInstance));
avgTrainingError += this.inMemoryModel.trainingError;
}
avgTrainingError /= batchSize;
// calculate the average of updates
for (int i = 0; i < weightUpdates.length; ++i) {
weightUpdates[i] = weightUpdates[i].divide(batchSize);
}
DoubleMatrix[] prevWeightUpdates = this.inMemoryModel
.getPrevMatricesUpdates();
SmallLayeredNeuralNetworkMessage outMessage = new SmallLayeredNeuralNetworkMessage(
avgTrainingError, false, weightUpdates, prevWeightUpdates);
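    // peer 0 is the master (see isMaster()), so the local updates go there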
peer.send(peer.getPeerName(0), outMessage);
}
  /**
   * Merge the weight matrix updates received from the groom (slave) peers and
   * broadcast the updated model.
   *
   * @param peer the BSP peer for this task
   * @throws IOException
   */
private void mergeUpdates(
BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, SmallLayeredNeuralNetworkMessage> peer)
throws IOException {
int numMessages = peer.getNumCurrentMessages();
boolean isConverge = false;
    if (numMessages == 0) { // no updates received, training has converged
this.isConverge.set(true);
return;
}
double avgTrainingError = 0;
DoubleMatrix[] matricesUpdates = null;
DoubleMatrix[] prevMatricesUpdates = null;
while (peer.getNumCurrentMessages() > 0) {
SmallLayeredNeuralNetworkMessage message = peer.getCurrentMessage();
if (matricesUpdates == null) {
matricesUpdates = message.getCurMatrices();
prevMatricesUpdates = message.getPrevMatrices();
} else {
SmallLayeredNeuralNetwork.matricesAdd(matricesUpdates,
message.getCurMatrices());
SmallLayeredNeuralNetwork.matricesAdd(prevMatricesUpdates,
message.getPrevMatrices());
}
avgTrainingError += message.getTrainingError();
}
if (numMessages != 1) {
avgTrainingError /= numMessages;
for (int i = 0; i < matricesUpdates.length; ++i) {
matricesUpdates[i] = matricesUpdates[i].divide(numMessages);
prevMatricesUpdates[i] = prevMatricesUpdates[i].divide(numMessages);
}
}
this.inMemoryModel.updateWeightMatrices(matricesUpdates);
this.inMemoryModel.setPrevWeightMatrices(prevMatricesUpdates);
// check convergence
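    // curAvgTrainingError accumulates a window of convergenceCheckInterval
    // per-iteration errors (see the += below); training is declared converged
    // once the previous window's error is already lower than the current one.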
if (iterations % convergenceCheckInterval == 0) {
if (prevAvgTrainingError < curAvgTrainingError) {
// error cannot decrease any more
isConverge = true;
}
// update
prevAvgTrainingError = curAvgTrainingError;
curAvgTrainingError = 0;
}
curAvgTrainingError += avgTrainingError / convergenceCheckInterval;
    // broadcast the updated weight matrices together with the convergence flag
for (String peerName : peer.getAllPeerNames()) {
SmallLayeredNeuralNetworkMessage msg = new SmallLayeredNeuralNetworkMessage(
0, isConverge, this.inMemoryModel.getWeightMatrices(),
this.inMemoryModel.getPrevMatricesUpdates());
peer.send(peerName, msg);
}
}
}