/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hama.ml.perception;
import java.io.IOException;
import java.util.Arrays;
import java.util.BitSet;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hama.bsp.BSPPeer;
import org.apache.hama.bsp.sync.SyncException;
import org.apache.hama.commons.io.VectorWritable;
import org.apache.hama.commons.math.DenseDoubleMatrix;
import org.apache.hama.ml.ann.NeuralNetworkTrainer;
/**
* The perceptron trainer for a small-scale (in-memory) MLP.
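*
* <p>Training runs as a BSP job: every task computes weight updates on its
* local mini-batches, and the master task (peer 0) averages the updates and
* broadcasts the new weights until every task has exhausted its input.
*
* <p>A rough configuration sketch (the keys below are the ones this class
* reads; the values and paths are only illustrative):
*
* <pre>
* conf.set("training.mode", "minibatch.gradient.descent");
* conf.setInt("training.batch.size", 100);
* conf.setInt("training.iteration", 1);
* conf.set("modelPath", "/tmp/mlp.model"); // required output path (illustrative)
* conf.set("learningRate", "0.1");
* conf.set("regularization", "0.01");
* conf.set("momentum", "0.9");
* conf.set("squashingFunctionName", "Sigmoid");
* conf.set("costFunctionName", "SquaredError");
* conf.set("layerSizeArray", "3 5 1"); // space-separated layer sizes
* // optionally set "existingModelPath" to continue training an existing model
* </pre>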
*/
class SmallMLPTrainer extends NeuralNetworkTrainer {
/* used by the master only, to check whether all slaves have finished reading */
private BitSet statusSet;
private int numTrainingInstanceRead = 0;
/* Once the reader reaches EOF, the training procedure is terminated */
private boolean terminateTraining = false;
private SmallMultiLayerPerceptron inMemoryPerceptron;
private int[] layerSizeArray;
@Override
protected void extraSetup(
BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, MLPMessage> peer) {
// obtain parameters
this.trainingMode = conf.get("training.mode", "minibatch.gradient.descent"); // mini-batch gradient descent by default
this.batchSize = conf.getInt("training.batch.size", 100); // mini-batch size, 100 by default
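// one bit per task; the master sets a task's bit once that task reports that
// it has exhausted its input (see mergeUpdate)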
this.statusSet = new BitSet(peer.getConfiguration().getInt("tasks", 1));
String outputModelPath = conf.get("modelPath");
if (outputModelPath == null || outputModelPath.trim().length() == 0) {
throw new IllegalArgumentException("Please specify the output model path.");
}
String modelPath = conf.get("existingModelPath");
// build model from scratch
if (modelPath == null || modelPath.trim().length() == 0) {
double learningRate = Double.parseDouble(conf.get("learningRate"));
double regularization = Double.parseDouble(conf.get("regularization"));
double momentum = Double.parseDouble(conf.get("momentum"));
String squashingFunctionName = conf.get("squashingFunctionName");
String costFunctionName = conf.get("costFunctionName");
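// space-separated layer sizes, e.g. "3 5 1" for a 3-5-1 topology (the value
// here is only an illustration)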
String[] layerSizeArrayStr = conf.get("layerSizeArray").trim().split(" ");
this.layerSizeArray = new int[layerSizeArrayStr.length];
for (int i = 0; i < this.layerSizeArray.length; ++i) {
this.layerSizeArray[i] = Integer.parseInt(layerSizeArrayStr[i]);
}
this.inMemoryPerceptron = new SmallMultiLayerPerceptron(learningRate,
regularization, momentum, squashingFunctionName, costFunctionName,
layerSizeArray);
LOG.info("Training model from scratch.");
} else { // read model from existing data
this.inMemoryPerceptron = new SmallMultiLayerPerceptron(modelPath);
LOG.info("Training with existing model.");
}
}
@Override
protected void extraCleanup(
BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, MLPMessage> peer) {
LOG.info(String.format("Task %d totally read %d records.\n",
peer.getPeerIndex(), this.numTrainingInstanceRead));
// the master writes the learned model to disk
if (peer.getPeerIndex() == 0) {
try {
LOG.info(String.format("Master write learned model to %s\n",
conf.get("modelPath")));
this.inMemoryPerceptron.writeModelToFile(conf.get("modelPath"));
} catch (IOException e) {
LOG.error("Failed to write the model. Please set a correct model path.", e);
}
}
}
@Override
public void bsp(
BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, MLPMessage> peer)
throws IOException, SyncException, InterruptedException {
LOG.info("Start training...");
if (trainingMode.equalsIgnoreCase("minibatch.gradient.descent")) {
LOG.info("Training Mode: minibatch.gradient.descent");
trainByMinibatch(peer);
} else {
LOG.error(String.format("Unsupported training mode: %s", trainingMode));
}
LOG.info(String.format("Task %d finished.", peer.getPeerIndex()));
}
/**
* Train the MLP with mini-batch gradient descent.
*
* @param peer
* @throws IOException
* @throws SyncException
* @throws InterruptedException
*/
private void trainByMinibatch(
BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, MLPMessage> peer)
throws IOException, SyncException, InterruptedException {
int maxIteration = conf.getInt("training.iteration", 1);
LOG.info("# of Training Iteration: " + maxIteration);
for (int i = 0; i < maxIteration; ++i) {
if (peer.getPeerIndex() == 0) {
LOG.info(String.format("Iteration [%d] begins...", i));
}
peer.reopenInput();
// reset status
if (peer.getPeerIndex() == 0) {
this.statusSet = new BitSet(peer.getConfiguration().getInt("tasks", 1));
}
this.terminateTraining = false;
peer.sync();
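// superstep structure: every task computes updates on its local mini-batch
// (updateWeights), sync, the master (peer 0) merges and broadcasts the new
// weights (mergeUpdate), sync, and this repeats until every task has reached
// the end of its input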
while (true) {
// each slave task updates its weights according to its local training data
boolean terminate = updateWeights(peer);
peer.sync();
// master merges the updates
if (peer.getPeerIndex() == 0) {
mergeUpdate(peer);
}
peer.sync();
if (terminate) {
break;
}
}
}
}
/**
* Merge the updates from the slave tasks and broadcast the merged weights.
*
* @param peer
* @throws IOException
*/
private void mergeUpdate(
BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, MLPMessage> peer)
throws IOException {
// initialize the cache
DenseDoubleMatrix[] mergedUpdates = this.getZeroWeightMatrices();
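// every task (the master included, since it also runs updateWeights) sends
// one message per superstep, so the message count is the divisor for the
// global mean computed below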
int numOfPartitions = peer.getNumCurrentMessages();
// aggregate the weight updates from all received messages
while (peer.getNumCurrentMessages() > 0) {
SmallMLPMessage message = (SmallMLPMessage) peer.getCurrentMessage();
if (message.isTerminated()) {
this.statusSet.set(message.getOwner());
}
DenseDoubleMatrix[] weightUpdates = message.getWeightUpdatedMatrices();
for (int m = 0; m < mergedUpdates.length; ++m) {
mergedUpdates[m] = (DenseDoubleMatrix) mergedUpdates[m]
.add(weightUpdates[m]);
}
}
if (numOfPartitions != 0) {
// calculate the global mean (the mean of batches from all slave tasks) of
// the weight updates
for (int m = 0; m < mergedUpdates.length; ++m) {
mergedUpdates[m] = (DenseDoubleMatrix) mergedUpdates[m]
.divide(numOfPartitions);
}
// check whether all tasks have finished reading their data
if (this.statusSet.cardinality() == conf.getInt("tasks", 1)) {
this.terminateTraining = true;
}
// update the weight matrices
this.inMemoryPerceptron.updateWeightMatrices(mergedUpdates);
this.inMemoryPerceptron.setPrevWeightUpdateMatrices(mergedUpdates);
}
// broadcast updated weight matrices
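// note: getAllPeerNames() includes the master itself, so the master also
// receives the updated weights in updateWeights during the next superstep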
for (String peerName : peer.getAllPeerNames()) {
SmallMLPMessage msg = new SmallMLPMessage(peer.getPeerIndex(),
this.terminateTraining, this.inMemoryPerceptron.getWeightMatrices(),
this.inMemoryPerceptron.getPrevWeightUpdateMatrices());
peer.send(peerName, msg);
}
}
/**
* Train the MLP with one local mini-batch and send the resulting weight
* updates to the master task.
*
* @param peer
* @return Whether this task has exhausted its training data.
* @throws IOException
*/
private boolean updateWeights(
BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, MLPMessage> peer)
throws IOException {
// receive update message sent by master
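// the master's message carries the terminate flag, the merged weight matrices
// and the previous weight updates (presumably consumed by the momentum term)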
if (peer.getNumCurrentMessages() > 0) {
SmallMLPMessage message = (SmallMLPMessage) peer.getCurrentMessage();
this.terminateTraining = message.isTerminated();
// each slave renews its weight matrices
this.inMemoryPerceptron.setWeightMatrices(message
.getWeightUpdatedMatrices());
this.inMemoryPerceptron.setPrevWeightUpdateMatrices(message
.getPrevWeightsUpdatedMatrices());
if (this.terminateTraining) {
return true;
}
}
// compute weight updates from the local training data
DenseDoubleMatrix[] weightUpdates = this.getZeroWeightMatrices();
int count = 0;
LongWritable recordId = new LongWritable();
VectorWritable trainingInstance = new VectorWritable();
boolean hasMore = false;
while (count < this.batchSize) {
hasMore = peer.readNext(recordId, trainingInstance);
if (!hasMore) { // stop once this task's input is exhausted
break;
}
try {
DenseDoubleMatrix[] singleTrainingInstanceUpdates = this.inMemoryPerceptron
.trainByInstance(trainingInstance.getVector());
// aggregate the updates
for (int m = 0; m < weightUpdates.length; ++m) {
weightUpdates[m] = (DenseDoubleMatrix) weightUpdates[m]
.add(singleTrainingInstanceUpdates[m]);
}
} catch (Exception e) {
LOG.error("Exception while training on an instance.", e);
}
++count;
++numTrainingInstanceRead;
}
// calculate the local mean (the mean of the local batch) of the weight updates
if (count > 0) {
for (int m = 0; m < weightUpdates.length; ++m) {
weightUpdates[m] = (DenseDoubleMatrix) weightUpdates[m].divide(count);
}
}
LOG.info(String.format("Task %d has read %d records.", peer.getPeerIndex(),
this.numTrainingInstanceRead));
// send the weight updates to master task
SmallMLPMessage message = new SmallMLPMessage(peer.getPeerIndex(),
!hasMore, weightUpdates);
peer.send(peer.getPeerName(0), message); // send status to master
return !hasMore;
}
/**
* Create zero-valued weight matrices matching the network topology, used as
* accumulators for weight updates.
*/
private DenseDoubleMatrix[] getZeroWeightMatrices() {
DenseDoubleMatrix[] weightUpdateCache = new DenseDoubleMatrix[this.layerSizeArray.length - 1];
// initialize one weight matrix per pair of adjacent layers
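// each matrix has (layerSizeArray[i] + 1) rows and layerSizeArray[i + 1]
// columns; the extra row presumably holds the bias weights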
for (int i = 0; i < weightUpdateCache.length; ++i) {
weightUpdateCache[i] = new DenseDoubleMatrix(this.layerSizeArray[i] + 1,
this.layerSizeArray[i + 1]);
}
return weightUpdateCache;
}
/**
* Render the weight matrices as a human-readable string.
*
* @param mat The weight matrices.
* @return The string representation of the weight matrices.
*/
protected static String weightsToString(DenseDoubleMatrix[] mat) {
StringBuilder sb = new StringBuilder();
for (int i = 0; i < mat.length; ++i) {
sb.append(String.format("Matrix [%d]\n", i));
double[][] values = mat[i].getValues();
for (double[] value : values) {
sb.append(Arrays.toString(value));
sb.append('\n');
}
sb.append('\n');
}
return sb.toString();
}
}