/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.horn.core;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang.math.RandomUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.WritableUtils;
import org.apache.hama.Constants;
import org.apache.hama.HamaConfiguration;
import org.apache.hama.bsp.BSPJob;
import org.apache.hama.commons.io.FloatMatrixWritable;
import org.apache.hama.commons.io.VectorWritable;
import org.apache.hama.commons.math.DenseFloatMatrix;
import org.apache.hama.commons.math.DenseFloatVector;
import org.apache.hama.commons.math.FloatFunction;
import org.apache.hama.commons.math.FloatMatrix;
import org.apache.hama.commons.math.FloatVector;
import org.apache.hama.util.ReflectionUtils;
import org.apache.horn.core.Constants.LearningStyle;
import org.apache.horn.core.Constants.TrainingMethod;
import org.apache.horn.examples.MultiLayerPerceptron.StandardNeuron;
import org.apache.horn.funcs.FunctionFactory;
import org.apache.horn.funcs.IdentityFunction;
import org.apache.horn.funcs.SoftMax;
import org.apache.horn.utils.MathUtils;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
/**
 * RecurrentLayeredNeuralNetwork defines the general operations for recurrent
 * layered models, i.e. layered networks that are unrolled over a fixed number
 * of time steps, such as simple (Elman-style) recurrent neural networks. The
 * training can be conducted in parallel, but the parameters of the model are
 * assumed to be stored on a single machine.
 *
 * In general, these models consist of neurons which are aligned in layers.
 * Between any two adjacent layers, the neurons are connected to form a
 * bipartite weighted graph; a recurrent layer additionally receives the
 * output of the same layer at the previous time step.
*
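 * A minimal construction sketch (layer sizes and the "Sigmoid"/"SoftMax"
 * function names are illustrative assumptions; training itself is normally
 * launched through the BSP job built by {@link #trainInternal}):
 *
 * <pre>{@code
 * RecurrentLayeredNeuralNetwork ann = new RecurrentLayeredNeuralNetwork();
 * ann.setRecurrentStepSize(3); // unroll the network over 3 time steps
 * ann.addLayer(10, false, FunctionFactory.createFloatFunction("Sigmoid"),
 *     StandardNeuron.class); // input layer
 * ann.addLayer(20, false, FunctionFactory.createFloatFunction("Sigmoid"),
 *     StandardNeuron.class); // hidden layer, recurrent by default
 * ann.addLayer(5, true, FunctionFactory.createFloatFunction("SoftMax"),
 *     StandardNeuron.class); // output layer
 * ann.initializeWeightMatrixLists(); // create the per-step weight matrices
 * }</pre>
 *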
*/
public class RecurrentLayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
private static final Log LOG = LogFactory.getLog(RecurrentLayeredNeuralNetwork.class);
  /* Weights between neurons at adjacent layers (current step only) */
  protected List<FloatMatrix> weightMatrixList;

  /* Weights between neurons at adjacent layers, one list per recurrent step */
  protected List<List<FloatMatrix>> weightMatrixLists;
/* Previous weight updates between neurons at adjacent layers */
protected List<FloatMatrix> prevWeightUpdatesList;
protected List<List<FloatMatrix>> prevWeightUpdatesLists;
  /* Different layers can have different squashing functions */
protected List<FloatFunction> squashingFunctionList;
protected List<Class<? extends Neuron<?>>> neuronClassList;
  /* Records, per layer, whether the layer is recurrent */
protected List<Boolean> recurrentLayerList;
/* Recurrent step size */
protected int recurrentStepSize;
protected int finalLayerIdx;
private List<Neuron<?>[]> neurons;
private List<List<Neuron<?>[]>> neuronLists;
private float dropRate;
private long iterations;
private int numOutCells;
public RecurrentLayeredNeuralNetwork() {
this.layerSizeList = Lists.newArrayList();
this.weightMatrixList = Lists.newArrayList();
this.prevWeightUpdatesList = Lists.newArrayList();
this.squashingFunctionList = Lists.newArrayList();
this.neuronClassList = Lists.newArrayList();
this.weightMatrixLists = Lists.newArrayList();
this.prevWeightUpdatesLists = Lists.newArrayList();
this.neuronLists = Lists.newArrayList();
this.recurrentLayerList = Lists.newArrayList();
}
public RecurrentLayeredNeuralNetwork(HamaConfiguration conf, String modelPath) {
super(conf, modelPath);
initializeNeurons(false);
initializeWeightMatrixLists();
}
public RecurrentLayeredNeuralNetwork(HamaConfiguration conf, String modelPath,
boolean isTraining) {
super(conf, modelPath);
initializeNeurons(isTraining);
initializeWeightMatrixLists();
}
  /**
   * Initialize the neuron objects of every layer for each unrolled recurrent
   * step. Steps without an output cell omit the output layer, and layers
   * feeding a recurrent layer are widened with context neurons.
   *
   * @param isTraining whether the neurons are created for training
   */
private void initializeNeurons(boolean isTraining) {
this.neuronLists = Lists.newArrayListWithExpectedSize(recurrentStepSize);
for (int stepIdx = 0; stepIdx < this.recurrentStepSize; stepIdx++) {
neurons = new ArrayList<Neuron<?>[]>();
int expectedNeuronsSize = this.layerSizeList.size();
if (stepIdx < this.recurrentStepSize - this.numOutCells) {
expectedNeuronsSize--;
}
for (int neuronLayerIdx = 0; neuronLayerIdx < expectedNeuronsSize; neuronLayerIdx++) {
int numOfNeurons = layerSizeList.get(neuronLayerIdx);
        // if this is not the first step, this is not the final layer, and the
        // next layer is recurrent, widen this layer with context neurons
if (stepIdx > 0 && neuronLayerIdx < layerSizeList.size() - 1
&& this.recurrentLayerList.get(neuronLayerIdx+1)) {
numOfNeurons = numOfNeurons + layerSizeList.get(neuronLayerIdx+1) - 1;
}
Class<? extends Neuron<?>> neuronClass;
if (neuronLayerIdx == 0)
          neuronClass = StandardNeuron.class; // placeholder; the input layer does not use it
else
neuronClass = neuronClassList.get(neuronLayerIdx - 1);
Neuron<?>[] tmp = new Neuron[numOfNeurons];
for (int neuronIdx = 0; neuronIdx < numOfNeurons; neuronIdx++) {
Neuron<?> n = newNeuronInstance(neuronClass);
if (n instanceof RecurrentDropoutNeuron)
((RecurrentDropoutNeuron) n).setDropRate(dropRate);
if (neuronLayerIdx > 0 && neuronIdx < layerSizeList.get(neuronLayerIdx))
n.setSquashingFunction(squashingFunctionList.get(neuronLayerIdx - 1));
else
n.setSquashingFunction(new IdentityFunction());
n.setLayerIndex(neuronLayerIdx);
n.setNeuronID(neuronIdx);
n.setLearningRate(this.learningRate);
n.setMomentumWeight(this.momentumWeight);
n.setTraining(isTraining);
tmp[neuronIdx] = n;
}
neurons.add(tmp);
}
this.neuronLists.add(neurons);
}
}
  /**
   * Initialize the per-step weight matrix lists: every unrolled recurrent
   * step receives its own randomly initialized matrices, and the matrices
   * built by {@link #addLayer} become those of the last step. Steps without
   * an output cell omit the matrix feeding the output layer.
   */
public void initializeWeightMatrixLists() {
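    // if the number of output cells was never set, assume every recurrent
    // step emits an output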
    this.numOutCells = (numOutCells == 0 ? this.recurrentStepSize : numOutCells);
this.weightMatrixLists.clear();
this.weightMatrixLists = Lists.newArrayListWithExpectedSize(this.recurrentStepSize);
this.prevWeightUpdatesLists.clear();
this.prevWeightUpdatesLists = Lists.newArrayListWithExpectedSize(this.recurrentStepSize);
for (int stepIdx = 0; stepIdx < recurrentStepSize - 1; stepIdx++) {
int expectedMatrixListSize = this.layerSizeList.size() - 1;
if (stepIdx < this.recurrentStepSize - this.numOutCells) {
expectedMatrixListSize--;
}
List<FloatMatrix> aWeightMatrixList = Lists.newArrayListWithExpectedSize(
expectedMatrixListSize);
List<FloatMatrix> aPrevWeightUpdatesList = Lists.newArrayListWithExpectedSize(
expectedMatrixListSize);
for (int matrixIdx = 0; matrixIdx < expectedMatrixListSize; matrixIdx++) {
int rows = this.weightMatrixList.get(matrixIdx).getRowCount();
int cols = this.weightMatrixList.get(matrixIdx).getColumnCount();
        // the first recurrent step has no context inputs, so the column
        // count is just the size of the previous layer
        if (stepIdx == 0)
          cols = this.layerSizeList.get(matrixIdx);
FloatMatrix weightMatrix = new DenseFloatMatrix(rows, cols);
weightMatrix.applyToElements(new FloatFunction() {
@Override
public float apply(float value) {
return RandomUtils.nextFloat() - 0.5f;
}
        @Override
        public float applyDerivative(float value) {
          throw new UnsupportedOperationException(
              "derivative is not supported for weight initialization");
        }
});
aWeightMatrixList.add(weightMatrix);
aPrevWeightUpdatesList.add(
new DenseFloatMatrix(
this.prevWeightUpdatesList.get(matrixIdx).getRowCount(),
this.prevWeightUpdatesList.get(matrixIdx).getColumnCount()));
}
this.weightMatrixLists.add(aWeightMatrixList);
this.prevWeightUpdatesLists.add(aPrevWeightUpdatesList);
}
    // the matrices built by addLayer serve as the matrices of the last step
this.weightMatrixLists.add(this.weightMatrixList);
this.prevWeightUpdatesLists.add(this.prevWeightUpdatesList);
this.weightMatrixList = Lists.newArrayList();
this.prevWeightUpdatesList = Lists.newArrayList();
}
  /**
   * {@inheritDoc}
   */
  @Override
public int addLayer(int size, boolean isFinalLayer,
FloatFunction squashingFunction, Class<? extends Neuron<?>> neuronClass) {
return addLayer(size, isFinalLayer, squashingFunction, neuronClass, null, true);
}
public int addLayer(int size, boolean isFinalLayer,
FloatFunction squashingFunction, Class<? extends Neuron<?>> neuronClass, int numOutCells) {
if (isFinalLayer)
      this.numOutCells = (numOutCells == 0 ? this.recurrentStepSize : numOutCells);
return addLayer(size, isFinalLayer, squashingFunction, neuronClass, null, false);
}
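  /**
   * Add a layer of neurons with the specified size. For non-final layers one
   * extra bias neuron is added, and a recurrent layer widens its incoming
   * weight matrix so that it also receives the layer's own output from the
   * previous time step.
   *
   * @param size number of neurons in the layer, excluding the bias neuron
   * @param isFinalLayer whether this layer is the output layer
   * @param squashingFunction the squashing (activation) function of the layer
   * @param neuronClass the neuron implementation to instantiate
   * @param interlayer an optional intermediate-output computer (may be null)
   * @param isRecurrent whether this layer feeds itself at the next time step
   * @return the index of the layer
   */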
public int addLayer(int size, boolean isFinalLayer,
FloatFunction squashingFunction, Class<? extends Neuron<?>> neuronClass,
Class<? extends IntermediateOutput> interlayer, boolean isRecurrent) {
Preconditions.checkArgument(size > 0,
"Size of layer must be larger than 0.");
if (!isFinalLayer) {
if (this.layerSizeList.size() == 0) {
this.recurrentLayerList.add(false);
LOG.info("add input layer: " + size + " neurons");
} else {
this.recurrentLayerList.add(isRecurrent);
LOG.info("add hidden layer: " + size + " neurons");
}
      size += 1; // reserve one slot for the bias neuron
} else {
this.recurrentLayerList.add(false);
}
this.layerSizeList.add(size);
int layerIdx = this.layerSizeList.size() - 1;
if (isFinalLayer) {
this.finalLayerIdx = layerIdx;
LOG.info("add output layer: " + size + " neurons");
}
// add weights between current layer and previous layer, and input layer has
// no squashing function
if (layerIdx > 0) {
int sizePrevLayer = this.layerSizeList.get(layerIdx - 1);
      // row count equals the size of the current layer and column count
      // equals the size of the previous layer
int row = isFinalLayer ? size : size - 1;
// expand matrix for recurrent layer
int col = !(this.recurrentLayerList.get(layerIdx)) ?
sizePrevLayer : sizePrevLayer + this.layerSizeList.get(layerIdx) - 1;
FloatMatrix weightMatrix = new DenseFloatMatrix(row, col);
// initialize weights
weightMatrix.applyToElements(new FloatFunction() {
@Override
public float apply(float value) {
return RandomUtils.nextFloat() - 0.5f;
}
        @Override
        public float applyDerivative(float value) {
          throw new UnsupportedOperationException(
              "derivative is not supported for weight initialization");
        }
});
this.weightMatrixList.add(weightMatrix);
this.prevWeightUpdatesList.add(new DenseFloatMatrix(row, col));
this.squashingFunctionList.add(squashingFunction);
this.neuronClassList.add(neuronClass);
}
return layerIdx;
}
  /**
   * Update the weight matrices by adding the given update matrices onto the
   * current weights.
   *
   * @param matrices the weight update matrices, flattened over all steps
   */
public void updateWeightMatrices(FloatMatrix[] matrices) {
int matrixIdx = 0;
for (List<FloatMatrix> aWeightMatrixList: this.weightMatrixLists) {
for (int weightMatrixIdx = 0; weightMatrixIdx < aWeightMatrixList.size(); weightMatrixIdx++) {
FloatMatrix matrix = aWeightMatrixList.get(weightMatrixIdx);
aWeightMatrixList.set(weightMatrixIdx, matrix.add(matrices[matrixIdx++]));
}
}
}
  /**
   * Set the previous weight update matrices.
   *
   * @param prevUpdates the update matrices, flattened over all steps
   */
void setPrevWeightMatrices(FloatMatrix[] prevUpdates) {
int matrixIdx = 0;
for (List<FloatMatrix> aWeightUpdateMatrixList: this.prevWeightUpdatesLists) {
for (int weightMatrixIdx = 0; weightMatrixIdx < aWeightUpdateMatrixList.size();
weightMatrixIdx++) {
aWeightUpdateMatrixList.set(weightMatrixIdx, prevUpdates[matrixIdx++]);
}
}
}
/**
* Add a batch of matrices onto the given destination matrices.
*
* @param destMatrices
* @param sourceMatrices
*/
static void matricesAdd(FloatMatrix[] destMatrices,
FloatMatrix[] sourceMatrices) {
for (int i = 0; i < destMatrices.length; ++i) {
destMatrices[i] = destMatrices[i].add(sourceMatrices[i]);
}
}
  /**
   * Get all the weight matrices.
   *
   * @return the matrices of every recurrent step, flattened into one array.
   */
FloatMatrix[] getWeightMatrices() {
FloatMatrix[] matrices = new FloatMatrix[this.getSizeOfWeightmatrix()];
int matrixIdx = 0;
for (List<FloatMatrix> aWeightMatrixList: this.weightMatrixLists) {
for (FloatMatrix aWeightMatrix : aWeightMatrixList) {
matrices[matrixIdx++] = aWeightMatrix;
}
}
return matrices;
}
/**
* Set the weight matrices.
*
* @param matrices
*/
public void setWeightMatrices(FloatMatrix[] matrices) {
int matrixIdx = 0;
for (List<FloatMatrix> aWeightMatrixList: this.weightMatrixLists) {
for (int weightMatrixIdx = 0; weightMatrixIdx < aWeightMatrixList.size(); weightMatrixIdx++) {
aWeightMatrixList.set(weightMatrixIdx, matrices[matrixIdx++]);
}
}
}
  /**
   * Get the previous weight update matrices in the form of an array.
   *
   * @return the update matrices of every recurrent step, flattened into one
   *         array.
   */
public FloatMatrix[] getPrevMatricesUpdates() {
FloatMatrix[] matrices = new FloatMatrix[this.getSizeOfWeightmatrix()];
int matrixIdx = 0;
for (List<FloatMatrix> aWeightMatrixList: this.prevWeightUpdatesLists) {
for (FloatMatrix aWeightMatrix : aWeightMatrixList) {
matrices[matrixIdx++] = aWeightMatrix;
}
}
return matrices;
}
public void setWeightMatrix(int index, FloatMatrix matrix) {
    Preconditions.checkArgument(
        0 <= index && index < this.weightMatrixList.size(), String.format(
            "index [%d] should be in range [%d, %d).", index, 0,
            this.weightMatrixList.size()));
this.weightMatrixList.set(index, matrix);
}
@Override
public void readFields(DataInput input) throws IOException {
super.readFields(input);
this.finalLayerIdx = input.readInt();
this.dropRate = input.readFloat();
// read neuron classes
int neuronClasses = input.readInt();
this.neuronClassList = Lists.newArrayList();
for (int i = 0; i < neuronClasses; ++i) {
try {
Class<? extends Neuron<?>> clazz = (Class<? extends Neuron<?>>) Class
.forName(input.readUTF());
neuronClassList.add(clazz);
      } catch (ClassNotFoundException e) {
        throw new IOException("Neuron class not found.", e);
      }
}
    // read squashing functions
int squashingFunctionSize = input.readInt();
this.squashingFunctionList = Lists.newArrayList();
for (int i = 0; i < squashingFunctionSize; ++i) {
this.squashingFunctionList.add(FunctionFactory
.createFloatFunction(WritableUtils.readString(input)));
}
this.recurrentStepSize = input.readInt();
this.numOutCells = input.readInt();
int recurrentLayerListSize = input.readInt();
this.recurrentLayerList = Lists.newArrayList();
for (int i = 0; i < recurrentLayerListSize; i++) {
this.recurrentLayerList.add(input.readBoolean());
}
// read weights and construct matrices of previous updates
int numOfMatrices = input.readInt();
this.weightMatrixLists = Lists.newArrayListWithExpectedSize(this.recurrentStepSize);
this.prevWeightUpdatesLists = Lists.newArrayList();
for (int step = 0; step < this.recurrentStepSize; step++) {
this.weightMatrixList = Lists.newArrayList();
this.prevWeightUpdatesList = Lists.newArrayList();
for (int j = 0; j < this.layerSizeList.size() - 2; j++) {
FloatMatrix matrix = FloatMatrixWritable.read(input);
this.weightMatrixList.add(matrix);
this.prevWeightUpdatesList.add(new DenseFloatMatrix(matrix.getRowCount(),
matrix.getColumnCount()));
}
// if the cell has output layer, read from input
if (step >= this.recurrentStepSize - this.numOutCells) {
FloatMatrix matrix = FloatMatrixWritable.read(input);
this.weightMatrixList.add(matrix);
this.prevWeightUpdatesList.add(new DenseFloatMatrix(matrix.getRowCount(),
matrix.getColumnCount()));
}
this.weightMatrixLists.add(this.weightMatrixList);
this.prevWeightUpdatesLists.add(this.prevWeightUpdatesList);
}
}
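  /**
   * Total number of weight matrices over all recurrent steps: every step
   * carries (number of layers - 2) matrices between the non-output layers,
   * and only the output cells add the matrix feeding the output layer. For
   * example, 4 layers, 3 steps and 1 output cell give 3 * 2 + 1 = 7 matrices.
   */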
protected int getSizeOfWeightmatrix() {
return this.recurrentStepSize * (this.layerSizeList.size() - 2) + this.numOutCells;
}
@Override
public void write(DataOutput output) throws IOException {
super.write(output);
output.writeInt(finalLayerIdx);
output.writeFloat(dropRate);
// write neuron classes
output.writeInt(this.neuronClassList.size());
for (Class<? extends Neuron<?>> clazz : this.neuronClassList) {
output.writeUTF(clazz.getName());
}
// write squashing functions
output.writeInt(this.squashingFunctionList.size());
for (FloatFunction aSquashingFunctionList : this.squashingFunctionList) {
WritableUtils.writeString(output,
aSquashingFunctionList.getFunctionName());
}
    // write recurrent step size
    output.writeInt(this.recurrentStepSize);
    // write number of output cells
    output.writeInt(this.numOutCells);
// write recurrent layer list
output.writeInt(this.recurrentLayerList.size());
    for (Boolean isRecurrentLayer : recurrentLayerList) {
      output.writeBoolean(isRecurrentLayer);
}
// write weight matrices
output.writeInt(this.getSizeOfWeightmatrix());
    for (List<FloatMatrix> aWeightMatrixList : this.weightMatrixLists) {
      for (FloatMatrix aWeightMatrix : aWeightMatrixList) {
        FloatMatrixWritable.write(aWeightMatrix, output);
      }
}
    // NOTE: the previous weight updates are intentionally not serialized
}
@Override
public FloatMatrix getWeightsByLayer(int layerIdx) {
return this.weightMatrixList.get(layerIdx);
}
public FloatMatrix getWeightsByLayer(int stepIdx, int layerIdx) {
return this.weightMatrixLists.get(stepIdx).get(layerIdx);
}
/**
 * Get the output of the model for the given feature instance. The instance
 * must contain the features of every recurrent step, concatenated, i.e. its
 * dimension must be (input layer size - 1) * recurrentStepSize.
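 *
 * A minimal call sketch (sizes are illustrative assumptions): with an input
 * layer of size 11 (10 features plus the bias slot) and 3 recurrent steps,
 * the instance carries 10 * 3 = 30 features:
 *
 * <pre>{@code
 * FloatVector instance = new DenseFloatVector(30); // filled with features
 * FloatVector output = ann.getOutput(instance);    // outputs of all out-cells
 * }</pre>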
*/
@Override
public FloatVector getOutput(FloatVector instance) {
    Preconditions.checkArgument((this.layerSizeList.get(0) - 1) * this.recurrentStepSize
        == instance.getDimension(), String.format(
        "The dimension of input instance should be %d.",
        (this.layerSizeList.get(0) - 1) * this.recurrentStepSize));
// transform the features to another space
FloatVector transformedInstance = this.featureTransformer
.transform(instance);
// add bias feature
FloatVector instanceWithBias = new DenseFloatVector(
transformedInstance.getDimension() + 1);
    instanceWithBias.set(0, 0.99999f); // set bias to be a little bit less than 1.0
for (int i = 1; i < instanceWithBias.getDimension(); ++i) {
instanceWithBias.set(i, transformedInstance.get(i - 1));
}
// return the output of the last layer
return getOutputInternal(instanceWithBias);
}
public void setDropRateOfInputLayer(float dropRate) {
this.dropRate = dropRate;
}
  /**
   * Calculate the output internally; the intermediate output of each layer is
   * cached in the neuron objects.
   *
   * @param instanceWithBias The instance containing the features of every
   *          recurrent step, including a bias slot per step.
   * @return The concatenated output of the final layer of every output cell.
   */
public FloatVector getOutputInternal(FloatVector instanceWithBias) {
    // set the outputs of the input layer for each recurrent step
Neuron<?>[] inputLayer;
for (int stepIdx = 0; stepIdx < this.weightMatrixLists.size(); stepIdx++) {
inputLayer = neuronLists.get(stepIdx).get(0);
for (int inputNeuronIdx = 0; inputNeuronIdx < this.layerSizeList.get(0); inputNeuronIdx++) {
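        // sample a 0/1 mask for this input neuron; when the sample is 0 the
        // neuron is marked as dropped and its output is forced to zero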
float m2 = MathUtils.getBinomial(1, dropRate);
        inputLayer[inputNeuronIdx].setDrop(m2 == 0);
inputLayer[inputNeuronIdx].setOutput(
instanceWithBias.get(stepIdx * this.layerSizeList.get(0) + inputNeuronIdx) * m2);
}
      // feed forward through all layers of this recurrent step
this.weightMatrixList = this.weightMatrixLists.get(stepIdx);
for (int layerIdx = 0; layerIdx < weightMatrixList.size(); ++layerIdx) {
forward(stepIdx, layerIdx);
}
}
    // collect the outputs of the output cells (the last numOutCells steps)
int singleOutputLength =
neuronLists.get(this.recurrentStepSize-1).get(this.finalLayerIdx).length;
FloatVector output = new DenseFloatVector(singleOutputLength * this.numOutCells);
int outputNeuronIdx = 0;
for (int step = this.recurrentStepSize - this.numOutCells;
step < this.recurrentStepSize; step++) {
neurons = neuronLists.get(step);
for (int neuronIdx = 0; neuronIdx < singleOutputLength; neuronIdx++) {
output.set(outputNeuronIdx, neurons.get(this.finalLayerIdx)[neuronIdx].getOutput());
outputNeuronIdx++;
}
}
return output;
}
  /**
   * @param neuronClass the neuron implementation to instantiate
   * @return a new neuron instance
   */
@SuppressWarnings({ "rawtypes" })
public static Neuron newNeuronInstance(Class<? extends Neuron> neuronClass) {
return (Neuron) ReflectionUtils.newInstance(neuronClass);
}
public class InputMessageIterable implements
Iterable<Synapse<FloatWritable, FloatWritable>> {
private int currNeuronID;
private int prevNeuronID;
private int end;
private FloatMatrix weightMat;
private Neuron<?>[] layer;
public InputMessageIterable(int fromLayer, int row) {
this.currNeuronID = row;
this.prevNeuronID = -1;
this.end = weightMatrixList.get(fromLayer).getColumnCount() - 1;
this.weightMat = weightMatrixList.get(fromLayer);
this.layer = neurons.get(fromLayer);
}
@Override
public Iterator<Synapse<FloatWritable, FloatWritable>> iterator() {
return new MessageIterator();
}
private class MessageIterator implements
Iterator<Synapse<FloatWritable, FloatWritable>> {
      @Override
      public boolean hasNext() {
        return prevNeuronID < end;
      }
      private FloatWritable i = new FloatWritable();
      private FloatWritable w = new FloatWritable();

      @Override
      public Synapse<FloatWritable, FloatWritable> next() {
        prevNeuronID++;
        i.set(layer[prevNeuronID].getOutput());
        w.set(weightMat.get(currNeuronID, prevNeuronID));
        // note: the FloatWritable payloads are reused across calls
        return new Synapse<FloatWritable, FloatWritable>(prevNeuronID, i, w);
      }
@Override
public void remove() {
}
}
}
  /**
   * Forward the calculation for one layer of one recurrent step.
   *
   * @param stepIdx The index of the recurrent step.
   * @param fromLayerIdx The index of the previous layer.
   */
protected void forward(int stepIdx, int fromLayerIdx) {
neurons = this.neuronLists.get(stepIdx);
int curLayerIdx = fromLayerIdx + 1;
// weight matrix for current layer
FloatMatrix weightMatrix = this.weightMatrixList.get(fromLayerIdx);
FloatFunction squashingFunction = getSquashingFunction(fromLayerIdx);
FloatVector vec = new DenseFloatVector(weightMatrix.getRowCount());
for (int row = 0; row < weightMatrix.getRowCount(); row++) {
Neuron<?> n;
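      // the output layer has no bias neuron, so its rows map one-to-one to
      // neurons; hidden layers reserve index 0 for the bias neuron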
if (curLayerIdx == finalLayerIdx)
n = neurons.get(curLayerIdx)[row];
else
n = neurons.get(curLayerIdx)[row + 1];
      try {
        Iterable msgs = new InputMessageIterable(fromLayerIdx, row);
        // reset the recurrent delta before the forward pass
        if (n instanceof RecurrentDropoutNeuron)
          ((RecurrentDropoutNeuron) n).setRecurrentDelta(0);
        n.setIterationNumber(iterations);
        n.forward(msgs);
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
vec.set(row, n.getOutput());
}
if (squashingFunction.getFunctionName().equalsIgnoreCase(
SoftMax.class.getSimpleName())) {
IntermediateOutput interlayer = (IntermediateOutput) ReflectionUtils
.newInstance(SoftMax.SoftMaxOutputComputer.class);
try {
vec = interlayer.interlayer(vec);
for (int i = 0; i < vec.getDimension(); i++) {
neurons.get(curLayerIdx)[i].setOutput(vec.get(i));
}
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
}
// add bias
if (curLayerIdx != finalLayerIdx)
neurons.get(curLayerIdx)[0].setOutput(1);
    // copy this layer's output into the context neurons appended to the
    // previous layer of the next step, so that the recurrent layer sees its
    // own output from the previous time step
if (this.recurrentLayerList.get(curLayerIdx) && stepIdx < this.recurrentStepSize - 1) {
for (int i = 0; i < vec.getDimension(); i++) {
this.neuronLists.get(stepIdx+1).get(
fromLayerIdx)[this.layerSizeList.get(fromLayerIdx)+i].setOutput(vec.get(i));
}
}
}
/**
   * Train the model online, i.e. update the weights with a single training
   * instance.
*
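   * A minimal sketch of an online training loop ({@code trainingSet} is an
   * assumed in-memory collection of instances; the instance layout is the one
   * expected by {@link #trainByInstance}):
   *
   * <pre>{@code
   * for (FloatVector instance : trainingSet) {
   *   ann.trainOnline(instance); // one forward/backward pass, weights updated
   * }
   * }</pre>
   *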
   * @param trainingInstance the features of every recurrent step followed by
   *          the labels
*/
public void trainOnline(FloatVector trainingInstance) {
FloatMatrix[] updateMatrices = this.trainByInstance(trainingInstance);
this.updateWeightMatrices(updateMatrices);
}
@Override
public FloatMatrix[] trainByInstance(FloatVector trainingInstance) {
int inputDimension = (this.layerSizeList.get(0) - 1) * this.recurrentStepSize;
FloatVector transformedVector = this.featureTransformer.transform(
trainingInstance.sliceUnsafe(inputDimension));
int outputDimension;
FloatVector inputInstance = null;
FloatVector labels = null;
if (this.learningStyle == LearningStyle.SUPERVISED) {
outputDimension = this.layerSizeList.get(this.layerSizeList.size() - 1);
// validate training instance
      Preconditions.checkArgument(
          (inputDimension + outputDimension == trainingInstance.getDimension()
          || inputDimension + outputDimension * recurrentStepSize == trainingInstance.getDimension()),
          String
              .format(
                  "The dimension of training instance is %d, but requires %d or %d.",
                  trainingInstance.getDimension(), inputDimension + outputDimension,
                  inputDimension + outputDimension * recurrentStepSize));
inputInstance = new DenseFloatVector(this.layerSizeList.get(0) * this.recurrentStepSize);
// get the features from the transformed vector
int vecIdx = 0;
for (int i = 0; i < inputInstance.getLength(); ++i) {
if (i % this.layerSizeList.get(0) == 0) {
inputInstance.set(i, 1); // add bias
} else {
inputInstance.set(i, transformedVector.get(vecIdx));
vecIdx++;
}
}
// get the labels from the original training instance
labels = trainingInstance.sliceUnsafe(transformedVector.getDimension(),
trainingInstance.getDimension() - 1);
} else if (this.learningStyle == LearningStyle.UNSUPERVISED) {
// labels are identical to input features
outputDimension = inputDimension;
// validate training instance
Preconditions.checkArgument(inputDimension == trainingInstance
.getDimension(), String.format(
"The dimension of training instance is %d, but requires %d.",
trainingInstance.getDimension(), inputDimension));
inputInstance = new DenseFloatVector(this.layerSizeList.get(0) * this.recurrentStepSize);
// get the features from the transformed vector
int vecIdx = 0;
for (int i = 0; i < inputInstance.getLength(); ++i) {
if (i % this.layerSizeList.get(0) == 0) {
inputInstance.set(i, 1); // add bias
} else {
inputInstance.set(i, transformedVector.get(vecIdx));
vecIdx++;
}
}
// get the labels by copying the transformed vector
labels = transformedVector.deepCopy();
}
FloatVector output = this.getOutputInternal(inputInstance);
// get the training error
calculateTrainingError(labels, output);
    if (this.trainingMethod.equals(TrainingMethod.GRADIENT_DESCENT)) {
      return this.trainByInstanceGradientDescent(labels);
    } else {
      throw new IllegalArgumentException(
          String.format("Training method %s is not supported.", this.trainingMethod));
    }
}
/**
* Train by gradient descent. Get the updated weights using one training
* instance.
*
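   * For an output neuron with activation y and label t, the loop below
   * computes the delta as delta = dJ(t, y)/dy + lambda * sum(w_row), further
   * multiplied by f'(y) unless the squashing function is SoftMax; here J is
   * the cost function, lambda the regularization weight, and w_row the row of
   * the last weight matrix feeding the neuron.
   *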
   * @param labels the expected output values of every output cell
* @return The weight update matrices.
*/
private FloatMatrix[] trainByInstanceGradientDescent(FloatVector labels) {
// initialize weight update matrices
DenseFloatMatrix[] weightUpdateMatrices = new DenseFloatMatrix[this.getSizeOfWeightmatrix()];
int matrixIdx = 0;
for (List<FloatMatrix> aWeightMatrixList: this.weightMatrixLists) {
for (FloatMatrix aWeightMatrix : aWeightMatrixList) {
weightUpdateMatrices[matrixIdx++] =
new DenseFloatMatrix(aWeightMatrix.getRowCount(), aWeightMatrix.getColumnCount());
}
}
FloatVector deltaVec = new DenseFloatVector(
this.layerSizeList.get(layerSizeList.size() - 1) * this.numOutCells);
FloatFunction squashingFunction = this.squashingFunctionList
.get(this.squashingFunctionList.size() - 1);
int labelIdx = 0;
    // compute the output-layer deltas for every output cell, in step order
    for (int step = this.recurrentStepSize - this.numOutCells; step < this.recurrentStepSize; step++) {
FloatMatrix lastWeightMatrix = this.weightMatrixLists.get(step)
.get(this.weightMatrixLists.get(step).size() - 1);
int neuronIdx = 0;
      for (Neuron<?> aNeuron : this.neuronLists.get(step).get(this.finalLayerIdx)) {
        float finalOut = aNeuron.getOutput();
float costFuncDerivative = this.costFunction.applyDerivative(
labels.get(labelIdx), finalOut);
// add regularization
costFuncDerivative += this.regularizationWeight
* lastWeightMatrix.getRowVector(neuronIdx).sum();
if (!squashingFunction.getFunctionName().equalsIgnoreCase(
SoftMax.class.getSimpleName())) {
costFuncDerivative *= squashingFunction.applyDerivative(finalOut);
}
        aNeuron.backpropagate(costFuncDerivative);
deltaVec.set(labelIdx, costFuncDerivative);
neuronIdx++;
labelIdx++;
}
}
// start from last recurrent step to first recurrent step
boolean skipLastLayer = false;
int weightMatrixIdx = weightUpdateMatrices.length - 1;
    for (int step = this.recurrentStepSize - 1; step >= 0; --step) {
this.weightMatrixList = this.weightMatrixLists.get(step);
this.prevWeightUpdatesList = this.prevWeightUpdatesLists.get(step);
this.neurons = this.neuronLists.get(step);
if (step < this.recurrentStepSize - this.numOutCells)
skipLastLayer = true;
// start from previous layer of output layer
for (int layer = this.layerSizeList.size() - 2; layer >= 0; --layer) {
if (skipLastLayer) {
          skipLastLayer = false;
          continue;
}
backpropagate(step, layer, weightUpdateMatrices[weightMatrixIdx--]);
}
}
// TODO eliminate non-output cells from weightUpdateLists
this.setPrevWeightMatrices(weightUpdateMatrices);
return weightUpdateMatrices;
}
public class ErrorMessageIterable implements
Iterable<Synapse<FloatWritable, FloatWritable>> {
private int row;
private int neuronID;
private int end;
private FloatMatrix weightMat;
private FloatMatrix prevWeightMat;
private float[] nextLayerDelta;
public ErrorMessageIterable(int recurrentStepIdx, int curLayerIdx, int row) {
this.row = row;
this.neuronID = -1;
this.weightMat = weightMatrixLists.get(recurrentStepIdx).get(curLayerIdx);
this.end = weightMat.getRowCount() - 1;
this.prevWeightMat = prevWeightUpdatesLists.get(recurrentStepIdx).get(curLayerIdx);
Neuron<?>[] nextLayer = neuronLists.get(recurrentStepIdx).get(curLayerIdx + 1);
nextLayerDelta = new float[weightMat.getRowCount()];
      for (int i = 0; i <= end; ++i) {
if (curLayerIdx + 1 == finalLayerIdx) {
nextLayerDelta[i] = nextLayer[i].getDelta();
} else {
nextLayerDelta[i] = nextLayer[i + 1].getDelta();
}
}
}
@Override
public Iterator<Synapse<FloatWritable, FloatWritable>> iterator() {
return new MessageIterator();
}
private class MessageIterator implements
Iterator<Synapse<FloatWritable, FloatWritable>> {
      @Override
      public boolean hasNext() {
        return neuronID < end;
      }
private FloatWritable d = new FloatWritable();
private FloatWritable w = new FloatWritable();
private FloatWritable p = new FloatWritable();
private Synapse<FloatWritable, FloatWritable> msg = new Synapse<FloatWritable, FloatWritable>();
@Override
public Synapse<FloatWritable, FloatWritable> next() {
neuronID++;
d.set(nextLayerDelta[neuronID]);
w.set(weightMat.get(neuronID, row));
p.set(prevWeightMat.get(neuronID, row));
msg.set(neuronID, d, w, p);
return msg;
}
@Override
public void remove() {
}
}
}
  /**
   * Back-propagate the errors from the next layer to the current layer. The
   * weight update information is accumulated into weightUpdateMatrix, and the
   * deltas are stored in the neurons of the current layer.
   *
   * @param recurrentStepIdx Index of the recurrent step.
   * @param curLayerIdx Index of the current layer.
   * @param weightUpdateMatrix Matrix receiving the accumulated weight updates.
   */
private void backpropagate(int recurrentStepIdx, int curLayerIdx,
DenseFloatMatrix weightUpdateMatrix) {
// get layer related information
    int cols = this.weightMatrixList.get(curLayerIdx).getColumnCount();
    int rows = this.weightMatrixList.get(curLayerIdx).getRowCount();
    FloatVector deltaVector = new DenseFloatVector(cols);
    Neuron<?>[] ns = this.neuronLists.get(recurrentStepIdx).get(curLayerIdx);
    for (int row = 0; row < cols; ++row) {
      Neuron<?> n = ns[row];
      n.setWeightVector(rows);
try {
Iterable msgs = new ErrorMessageIterable(recurrentStepIdx, curLayerIdx, row);
n.backward(msgs);
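        // rows beyond the layer's declared size are context neurons holding
        // the recurrent (next) layer's output from the previous step; their
        // delta flows back to that neuron (backpropagation through time)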
if (row >= layerSizeList.get(curLayerIdx) && recurrentStepIdx > 0
&& recurrentLayerList.get(curLayerIdx+1)) {
Neuron<?> recurrentNeuron = neuronLists.get(recurrentStepIdx-1).get(curLayerIdx+1)
[row-layerSizeList.get(curLayerIdx)+1];
recurrentNeuron.backpropagate(n.getDelta());
}
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
// update weights
weightUpdateMatrix.setColumn(row, n.getWeights());
deltaVector.set(row, n.getDelta());
}
}
@Override
protected BSPJob trainInternal(HamaConfiguration conf) throws IOException,
InterruptedException, ClassNotFoundException {
this.conf = conf;
this.fs = FileSystem.get(conf);
String modelPath = conf.get("model.path");
if (modelPath != null) {
this.modelPath = modelPath;
}
// modelPath must be set before training
    if (this.modelPath == null) {
      throw new IllegalArgumentException(
          "Please specify the model path, either through setModelPath() "
              + "or by adding 'model.path' to the training parameters.");
    }
this.setRecurrentStepSize(conf.getInt("training.recurrent.step.size", 1));
this.initializeWeightMatrixLists();
this.writeModelToFile();
// create job
BSPJob job = new BSPJob(conf, RecurrentLayeredNeuralNetworkTrainer.class);
job.setJobName("Neural Network training");
job.setJarByClass(RecurrentLayeredNeuralNetworkTrainer.class);
job.setBspClass(RecurrentLayeredNeuralNetworkTrainer.class);
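    // reserve one task beyond the input splits (in Hama this extra task is
    // commonly used as the master that merges the per-task weight updates)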
job.getConfiguration().setInt(Constants.ADDITIONAL_BSP_TASKS, 1);
job.setBoolean("training.mode", true);
job.setInputPath(new Path(conf.get("training.input.path")));
job.setInputFormat(org.apache.hama.bsp.SequenceFileInputFormat.class);
job.setInputKeyClass(LongWritable.class);
job.setInputValueClass(VectorWritable.class);
job.setOutputKeyClass(NullWritable.class);
job.setOutputValueClass(NullWritable.class);
job.setOutputFormat(org.apache.hama.bsp.NullOutputFormat.class);
return job;
}
@Override
protected void calculateTrainingError(FloatVector labels, FloatVector output) {
FloatVector errors = labels.deepCopy().applyToElements(output,
this.costFunction);
this.trainingError = errors.sum();
}
  /**
   * Get the squashing function of a specified layer.
   *
   * @param idx the layer index
   * @return the squashing function applied to the outputs of layer idx + 1.
   */
public FloatFunction getSquashingFunction(int idx) {
return this.squashingFunctionList.get(idx);
}
public void setIterationNumber(long iterations) {
this.iterations = iterations;
}
public void setRecurrentStepSize(int recurrentStepSize) {
this.recurrentStepSize = recurrentStepSize;
}
}