HORN-12: Add an MLP example based on the neuron-centric model
diff --git a/src/main/java/org/apache/horn/core/AbstractLayeredNeuralNetwork.java b/src/main/java/org/apache/horn/core/AbstractLayeredNeuralNetwork.java
index f87e771..b0162ab 100644
--- a/src/main/java/org/apache/horn/core/AbstractLayeredNeuralNetwork.java
+++ b/src/main/java/org/apache/horn/core/AbstractLayeredNeuralNetwork.java
@@ -117,8 +117,9 @@
* is f(x) = x by default.
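+ * @param neuronClass the Neuron implementation to instantiate for this layer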
* @return The layer index, starts with 0.
*/
+ @SuppressWarnings("rawtypes")
public abstract int addLayer(int size, boolean isFinalLayer,
- DoubleFunction squashingFunction);
+ DoubleFunction squashingFunction, Class<? extends Neuron> neuronClass);
/**
* Get the size of a particular layer.
diff --git a/src/main/java/org/apache/horn/core/AutoEncoder.java b/src/main/java/org/apache/horn/core/AutoEncoder.java
index f638245..35a6f66 100644
--- a/src/main/java/org/apache/horn/core/AutoEncoder.java
+++ b/src/main/java/org/apache/horn/core/AutoEncoder.java
@@ -53,11 +53,11 @@
public AutoEncoder(int inputDimensions, int compressedDimensions) {
model = new LayeredNeuralNetwork();
model.addLayer(inputDimensions, false,
- FunctionFactory.createDoubleFunction("Sigmoid"));
+ FunctionFactory.createDoubleFunction("Sigmoid"), null);
model.addLayer(compressedDimensions, false,
- FunctionFactory.createDoubleFunction("Sigmoid"));
+ FunctionFactory.createDoubleFunction("Sigmoid"), null);
model.addLayer(inputDimensions, true,
- FunctionFactory.createDoubleFunction("Sigmoid"));
+ FunctionFactory.createDoubleFunction("Sigmoid"), null);
model
.setLearningStyle(AbstractLayeredNeuralNetwork.LearningStyle.UNSUPERVISED);
model.setCostFunction(FunctionFactory
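
Callers of AutoEncoder are untouched by the signature change; the class simply passes null where the MLP example supplies a neuron class. Usage stays a one-liner (dimensions illustrative):

    AutoEncoder encoder = new AutoEncoder(784, 64);
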
diff --git a/src/main/java/org/apache/horn/core/HornJob.java b/src/main/java/org/apache/horn/core/HornJob.java
index 82dcad8..c95ae36 100644
--- a/src/main/java/org/apache/horn/core/HornJob.java
+++ b/src/main/java/org/apache/horn/core/HornJob.java
@@ -36,18 +36,28 @@
neuralNetwork = new LayeredNeuralNetwork();
}
- public void inputLayer(int featureDimension, Class<? extends Function> func) {
- addLayer(featureDimension, func);
- }
-
- public void addLayer(int featureDimension, Class<? extends Function> func) {
- neuralNetwork.addLayer(featureDimension, false,
- FunctionFactory.createDoubleFunction(func.getSimpleName()));
+ @SuppressWarnings("rawtypes")
+ public void inputLayer(int featureDimension, Class<? extends Function> func,
+ Class<? extends Neuron> neuronClass) {
+ addLayer(featureDimension, func, neuronClass);
}
- public void outputLayer(int labels, Class<? extends Function> func) {
- neuralNetwork.addLayer(labels, true,
- FunctionFactory.createDoubleFunction(func.getSimpleName()));
+ @SuppressWarnings("rawtypes")
+ public void addLayer(int featureDimension, Class<? extends Function> func,
+ Class<? extends Neuron> neuronClass) {
+ neuralNetwork.addLayer(featureDimension, false,
+ FunctionFactory.createDoubleFunction(func.getSimpleName()), neuronClass);
+ }
+
+ @SuppressWarnings("rawtypes")
+ public void outputLayer(int labels, Class<? extends Function> func,
+ Class<? extends Neuron> neuronClass) {
+ neuralNetwork.addLayer(labels, true,
+ FunctionFactory.createDoubleFunction(func.getSimpleName()), neuronClass);
}
public void setCostFunction(Class<? extends Function> func) {
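
For context, the layer-definition API now takes an explicit neuron class per layer. A minimal sketch of driving it (assuming HornJob is constructed from a HamaConfiguration and an example class, as the MultiLayerPerceptron example below does; all dimensions are illustrative):

    HornJob job = new HornJob(new HamaConfiguration(), MultiLayerPerceptron.class);
    job.inputLayer(784, Sigmoid.class, StandardNeuron.class);  // feature dimension
    job.addLayer(64, Sigmoid.class, StandardNeuron.class);     // hidden layer
    job.outputLayer(10, Sigmoid.class, StandardNeuron.class);  // label dimension
    job.setCostFunction(CrossEntropy.class);
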
diff --git a/src/main/java/org/apache/horn/core/LayeredNeuralNetwork.java b/src/main/java/org/apache/horn/core/LayeredNeuralNetwork.java
index afccbff..32d6c64 100644
--- a/src/main/java/org/apache/horn/core/LayeredNeuralNetwork.java
+++ b/src/main/java/org/apache/horn/core/LayeredNeuralNetwork.java
@@ -64,10 +64,7 @@
*/
public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
- private static final Log LOG = LogFactory
- .getLog(LayeredNeuralNetwork.class);
-
- public static Class<Neuron<Synapse<DoubleWritable, DoubleWritable>>> neuronClass;
+ private static final Log LOG = LogFactory.getLog(LayeredNeuralNetwork.class);
/* Weights between neurons at adjacent layers */
protected List<DoubleMatrix> weightMatrixList;
@@ -78,6 +75,8 @@
/* Different layers can have different squashing function */
protected List<DoubleFunction> squashingFunctionList;
+ protected List<Class<? extends Neuron>> neuronClassList;
+
protected int finalLayerIdx;
protected double regularizationWeight;
@@ -87,6 +86,7 @@
this.weightMatrixList = Lists.newArrayList();
this.prevWeightUpdatesList = Lists.newArrayList();
this.squashingFunctionList = Lists.newArrayList();
+ this.neuronClassList = Lists.newArrayList();
}
public LayeredNeuralNetwork(HamaConfiguration conf, String modelPath) {
@@ -99,7 +99,7 @@
* {@inheritDoc}
*/
public int addLayer(int size, boolean isFinalLayer,
- DoubleFunction squashingFunction) {
+ DoubleFunction squashingFunction, Class<? extends Neuron> neuronClass) {
Preconditions.checkArgument(size > 0,
"Size of layer must be larger than 0.");
if (!isFinalLayer) {
@@ -137,6 +137,7 @@
this.weightMatrixList.add(weightMatrix);
this.prevWeightUpdatesList.add(new DenseDoubleMatrix(row, col));
this.squashingFunctionList.add(squashingFunction);
+ this.neuronClassList.add(neuronClass);
}
return layerIdx;
}
@@ -223,6 +224,20 @@
public void readFields(DataInput input) throws IOException {
super.readFields(input);
+ // read neuron classes
+ int neuronClasses = input.readInt();
+ this.neuronClassList = Lists.newArrayList();
+ for (int i = 0; i < neuronClasses; ++i) {
+ try {
+ Class<? extends Neuron> clazz = (Class<? extends Neuron>) Class
+ .forName(input.readUTF());
+ neuronClassList.add(clazz);
+ } catch (ClassNotFoundException e) {
+ throw new IOException("unable to load neuron class", e);
+ }
+ }
+
// read squash functions
int squashingFunctionSize = input.readInt();
this.squashingFunctionList = Lists.newArrayList();
@@ -248,6 +263,12 @@
public void write(DataOutput output) throws IOException {
super.write(output);
+ // write neuron classes
+ output.writeInt(this.neuronClassList.size());
+ for (Class<? extends Neuron> clazz : this.neuronClassList) {
+ output.writeUTF(clazz.getName());
+ }
+
// write squashing functions
output.writeInt(this.squashingFunctionList.size());
for (DoubleFunction aSquashingFunctionList : this.squashingFunctionList) {
@@ -319,9 +340,12 @@
}
/**
+ * @param neuronClass the Neuron implementation to instantiate
* @return a new neuron instance
*/
- public static Neuron<Synapse<DoubleWritable, DoubleWritable>> newNeuronInstance() {
+ @SuppressWarnings({ "unchecked", "rawtypes" })
+ public static Neuron<Synapse<DoubleWritable, DoubleWritable>> newNeuronInstance(
+ Class<? extends Neuron> neuronClass) {
return (Neuron<Synapse<DoubleWritable, DoubleWritable>>) ReflectionUtils
.newInstance(neuronClass);
}
@@ -333,13 +357,9 @@
* @param intermediateOutput The intermediateOutput of previous layer.
* @return a new vector with the result of the operation.
*/
- @SuppressWarnings("unchecked")
protected DoubleVector forward(int fromLayer, DoubleVector intermediateOutput) {
DoubleMatrix weightMatrix = this.weightMatrixList.get(fromLayer);
- neuronClass = (Class<Neuron<Synapse<DoubleWritable, DoubleWritable>>>) conf
- .getClass("neuron.class", Neuron.class);
-
// TODO use the multithread processing
DoubleVector vec = new DenseDoubleVector(weightMatrix.getRowCount());
for (int row = 0; row < weightMatrix.getRowCount(); row++) {
@@ -350,8 +370,8 @@
new DoubleWritable(weightMatrix.get(row, col))));
}
Iterable<Synapse<DoubleWritable, DoubleWritable>> iterable = msgs;
- Neuron<Synapse<DoubleWritable, DoubleWritable>> n = newNeuronInstance();
- n.setup(conf);
+ Neuron<Synapse<DoubleWritable, DoubleWritable>> n = newNeuronInstance(
+ this.neuronClassList.get(fromLayer));
n.setSquashingFunction(this.squashingFunctionList.get(fromLayer));
try {
n.forward(iterable);
@@ -524,9 +544,12 @@
DoubleVector deltaVector = new DenseDoubleVector(
weightMatrix.getColumnCount());
for (int row = 0; row < weightMatrix.getColumnCount(); ++row) {
- Neuron<Synapse<DoubleWritable, DoubleWritable>> n = newNeuronInstance();
- // calls setup method
- n.setup(conf);
+ Neuron<Synapse<DoubleWritable, DoubleWritable>> n = newNeuronInstance(
+ this.neuronClassList.get(curLayerIdx));
+ // inject the hyper-parameters that setup(conf) used to supply
+ n.setLearningRate(this.learningRate);
+ n.setMomentumWeight(this.momentumWeight);
+
n.setSquashingFunction(this.squashingFunctionList.get(curLayerIdx));
n.setOutput(curLayerOutput.get(row));
@@ -578,13 +601,10 @@
// create job
BSPJob job = new BSPJob(conf, LayeredNeuralNetworkTrainer.class);
- job.setJobName("Small scale Neural Network training");
+ job.setJobName("Neural Network training");
job.setJarByClass(LayeredNeuralNetworkTrainer.class);
job.setBspClass(LayeredNeuralNetworkTrainer.class);
- job.getConfiguration().setClass("neuron.class", StandardNeuron.class,
- Neuron.class);
-
// additional for parameter server
// TODO at this moment, we use 1 task as a parameter server
// In the future, the number of parameter server should be configurable
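
The serialized model now carries a neuron-class block ahead of the squashing-function block, so write() and readFields() must stay symmetric. A self-contained sketch of just that round trip (the neuron class name is illustrative and has to be on the classpath for Class.forName to resolve it):

    import java.io.*;

    public class NeuronClassRoundTrip {
      public static void main(String[] args) throws Exception {
        String[] layerNeurons =
            { "org.apache.horn.examples.MultiLayerPerceptron$StandardNeuron" };

        // write side: the count first, then one fully-qualified name per layer
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        DataOutputStream out = new DataOutputStream(bytes);
        out.writeInt(layerNeurons.length);
        for (String name : layerNeurons) {
          out.writeUTF(name);
        }

        // read side: same order; Class.forName resolves each name again
        DataInputStream in = new DataInputStream(
            new ByteArrayInputStream(bytes.toByteArray()));
        int count = in.readInt();
        for (int i = 0; i < count; ++i) {
          System.out.println(Class.forName(in.readUTF()));
        }
      }
    }
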
diff --git a/src/main/java/org/apache/horn/core/Neuron.java b/src/main/java/org/apache/horn/core/Neuron.java
index 357b42f..4471b45 100644
--- a/src/main/java/org/apache/horn/core/Neuron.java
+++ b/src/main/java/org/apache/horn/core/Neuron.java
@@ -17,23 +17,28 @@
*/
package org.apache.horn.core;
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
import org.apache.hadoop.io.Writable;
import org.apache.hama.commons.math.DoubleFunction;
-public abstract class Neuron<M extends Writable> implements NeuronInterface<M> {
+public abstract class Neuron<M extends Writable> implements Writable, NeuronInterface<M> {
double output;
double weight;
double delta;
+
+ double momentumWeight;
+ double learningRate;
+
protected DoubleFunction squashingFunction;
public void feedforward(double sum) {
- // TODO Auto-generated method stub
- // squashing
this.output = sum;
}
public void backpropagate(double gradient) {
- // TODO Auto-generated method stub
this.delta = gradient;
}
@@ -53,7 +58,24 @@
return output;
}
- // ////////* Below methods will communicate with parameter server */
+ public void setMomentumWeight(double momentumWeight) {
+ this.momentumWeight = momentumWeight;
+ }
+
+ public double getMomentumWeight() {
+ return momentumWeight;
+ }
+
+ public void setLearningRate(double learningRate) {
+ this.learningRate = learningRate;
+ }
+
+ public double getLearningRate() {
+ return learningRate;
+ }
+
+ /* The methods below communicate with the parameter server. */
+
private int i;
public void push(double weight) {
@@ -79,4 +101,24 @@
this.squashingFunction = squashingFunction;
}
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ output = in.readDouble();
+ weight = in.readDouble();
+ delta = in.readDouble();
+
+ momentumWeight = in.readDouble();
+ learningRate = in.readDouble();
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeDouble(output);
+ out.writeDouble(weight);
+ out.writeDouble(delta);
+
+ out.writeDouble(momentumWeight);
+ out.writeDouble(learningRate);
+ }
+
}
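
With setup(HamaConfiguration) gone, hyper-parameters reach a neuron only through the new setters, so user code reads them via getLearningRate()/getMomentumWeight(). A sketch of a custom neuron under that contract (the class name is illustrative, and it assumes Synapse exposes getInput() alongside the getWeight()/getDelta()/getPrevWeight() accessors used elsewhere in this patch):

    import java.io.IOException;
    import org.apache.hadoop.io.DoubleWritable;
    import org.apache.horn.core.Neuron;
    import org.apache.horn.core.Synapse;

    public class MyCustomNeuron extends
        Neuron<Synapse<DoubleWritable, DoubleWritable>> {

      @Override
      public void forward(Iterable<Synapse<DoubleWritable, DoubleWritable>> messages)
          throws IOException {
        double sum = 0;
        for (Synapse<DoubleWritable, DoubleWritable> m : messages) {
          sum += m.getInput() * m.getWeight();
        }
        // squashingFunction is injected per layer by LayeredNeuralNetwork
        this.feedforward(this.squashingFunction.apply(sum));
      }

      @Override
      public void backward(Iterable<Synapse<DoubleWritable, DoubleWritable>> messages)
          throws IOException {
        for (Synapse<DoubleWritable, DoubleWritable> m : messages) {
          double gradient = this.squashingFunction.applyDerivative(this.getOutput())
              * (m.getDelta() * m.getWeight());
          this.backpropagate(gradient);
          // hyper-parameters come from the setters, not from a configuration
          this.push(-this.getLearningRate() * this.getOutput() * m.getDelta()
              + this.getMomentumWeight() * m.getPrevWeight());
        }
      }
    }
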
diff --git a/src/main/java/org/apache/horn/core/NeuronInterface.java b/src/main/java/org/apache/horn/core/NeuronInterface.java
index 5e4c113..ef5a2d3 100644
--- a/src/main/java/org/apache/horn/core/NeuronInterface.java
+++ b/src/main/java/org/apache/horn/core/NeuronInterface.java
@@ -20,15 +20,12 @@
import java.io.IOException;
import org.apache.hadoop.io.Writable;
-import org.apache.hama.HamaConfiguration;
public interface NeuronInterface<M extends Writable> {
- public void setup(HamaConfiguration conf);
-
/**
- * This method is called when the messages are propagated from the lower
- * layer. It can be used to determine if the neuron would activate, or fire.
+ * This method is called when the messages are propagated from the previous
+ * layer. It can be used to determine if the neuron would activate, or fire.
*
* @param messages
* @throws IOException
@@ -36,13 +33,13 @@
public void forward(Iterable<M> messages) throws IOException;
/**
- * This method is called when the errors are propagated from the upper layer.
- * It can be used to calculate the error of each neuron and change the
+ * This method is called when the errors are propagated from the next layer.
+ * It can be used to calculate the error of each neuron and change the
* weights.
*
* @param messages
* @throws IOException
*/
public void backward(Iterable<M> messages) throws IOException;
-
+
}
diff --git a/src/main/java/org/apache/horn/examples/MultiLayerPerceptron.java b/src/main/java/org/apache/horn/examples/MultiLayerPerceptron.java
index c3bf180..90c9db4 100644
--- a/src/main/java/org/apache/horn/examples/MultiLayerPerceptron.java
+++ b/src/main/java/org/apache/horn/examples/MultiLayerPerceptron.java
@@ -31,14 +31,6 @@
public static class StandardNeuron extends
Neuron<Synapse<DoubleWritable, DoubleWritable>> {
- private double learningRate;
- private double momentum;
-
- @Override
- public void setup(HamaConfiguration conf) {
- this.learningRate = conf.getDouble("mlp.learning.rate", 0.5);
- this.momentum = conf.getDouble("mlp.momentum.weight", 0.2);
- }
@Override
public void forward(
@@ -63,8 +55,8 @@
this.backpropagate(gradient);
// Weight corrections
- double weight = -learningRate * this.getOutput() * m.getDelta()
- + momentum * m.getPrevWeight();
+ double weight = -this.getLearningRate() * this.getOutput()
+ * m.getDelta() + this.getMomentumWeight() * m.getPrevWeight();
this.push(weight);
}
}
@@ -88,9 +80,9 @@
job.setConvergenceCheckInterval(1000);
job.setBatchSize(300);
- job.inputLayer(features, Sigmoid.class);
- job.addLayer(features, Sigmoid.class);
- job.outputLayer(labels, Sigmoid.class);
+ job.inputLayer(features, Sigmoid.class, StandardNeuron.class);
+ job.addLayer(features, Sigmoid.class, StandardNeuron.class);
+ job.outputLayer(labels, Sigmoid.class, StandardNeuron.class);
job.setCostFunction(CrossEntropy.class);
@@ -101,9 +93,12 @@
InterruptedException, ClassNotFoundException {
if (args.length < 9) {
System.out
- .println("Usage: <MODEL_PATH> <INPUT_PATH> <LEARNING_RATE> <MOMEMTUM_WEIGHT> <REGULARIZATION_WEIGHT> <FEATURE_DIMENSION> <LABEL_DIMENSION> <MAX_ITERATION> <NUM_TASKS>");
+ .println("Usage: <MODEL_PATH> <INPUT_PATH> "
+ + "<LEARNING_RATE> <MOMEMTUM_WEIGHT> <REGULARIZATION_WEIGHT> "
+ + "<FEATURE_DIMENSION> <LABEL_DIMENSION> <MAX_ITERATION> <NUM_TASKS>");
System.exit(1);
}
+
HornJob ann = createJob(new HamaConfiguration(), args[0], args[1],
Double.parseDouble(args[2]), Double.parseDouble(args[3]),
Double.parseDouble(args[4]), Integer.parseInt(args[5]),
diff --git a/src/main/java/org/apache/horn/utils/MNISTConverter.java b/src/main/java/org/apache/horn/utils/MNISTConverter.java
index 99742d6..224fc4b 100644
--- a/src/main/java/org/apache/horn/utils/MNISTConverter.java
+++ b/src/main/java/org/apache/horn/utils/MNISTConverter.java
@@ -68,6 +68,7 @@
HamaConfiguration conf = new HamaConfiguration();
FileSystem fs = FileSystem.get(conf);
+ @SuppressWarnings("deprecation")
SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, new Path(
output), LongWritable.class, VectorWritable.class);
@@ -81,6 +82,8 @@
new DenseDoubleVector(vals)));
}
+ imagesIn.close();
+ labelsIn.close();
writer.close();
}
}
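
The added close() calls still leak the streams if the conversion loop throws. A try-with-resources variant would close everything on any exit path (a sketch, assuming the same Hadoop/Hama types as the file above; paths come in as arguments and the loop body is elided):

    import java.io.DataInputStream;
    import java.io.FileInputStream;
    import java.io.IOException;

    import org.apache.hadoop.fs.FileSystem;
    import org.apache.hadoop.fs.Path;
    import org.apache.hadoop.io.LongWritable;
    import org.apache.hadoop.io.SequenceFile;
    import org.apache.hama.HamaConfiguration;
    import org.apache.hama.commons.io.VectorWritable;

    public class SafeConvert {
      @SuppressWarnings("deprecation")
      public static void convert(String images, String labels, String output)
          throws IOException {
        HamaConfiguration conf = new HamaConfiguration();
        FileSystem fs = FileSystem.get(conf);
        try (DataInputStream imagesIn = new DataInputStream(new FileInputStream(images));
             DataInputStream labelsIn = new DataInputStream(new FileInputStream(labels));
             SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf,
                 new Path(output), LongWritable.class, VectorWritable.class)) {
          // conversion loop from MNISTConverter goes here; all three resources
          // are closed automatically, even if reading or writing throws
        }
      }
    }
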
diff --git a/src/test/java/org/apache/horn/core/TestNeuron.java b/src/test/java/org/apache/horn/core/TestNeuron.java
index f2fe4e1..5f2bb59 100644
--- a/src/test/java/org/apache/horn/core/TestNeuron.java
+++ b/src/test/java/org/apache/horn/core/TestNeuron.java
@@ -24,13 +24,10 @@
import junit.framework.TestCase;
import org.apache.hadoop.io.DoubleWritable;
-import org.apache.hama.HamaConfiguration;
-import org.apache.horn.core.Neuron;
-import org.apache.horn.core.Synapse;
import org.apache.horn.funcs.Sigmoid;
public class TestNeuron extends TestCase {
private static double learningRate = 0.1;
private static double bias = -1;
private static double theta = 0.8;
@@ -38,10 +35,6 @@
Neuron<Synapse<DoubleWritable, DoubleWritable>> {
@Override
- public void setup(HamaConfiguration conf) {
- }
-
- @Override
public void forward(
Iterable<Synapse<DoubleWritable, DoubleWritable>> messages)
throws IOException {
@@ -59,14 +52,16 @@
throws IOException {
for (Synapse<DoubleWritable, DoubleWritable> m : messages) {
// Calculates error gradient for each neuron
- double gradient = new Sigmoid().applyDerivative(this.getOutput()) * (m.getDelta() * m.getWeight());
+ double gradient = new Sigmoid().applyDerivative(this.getOutput())
+ * (m.getDelta() * m.getWeight());
// Propagates to lower layer
backpropagate(gradient);
// Weight corrections
double weight = learningRate * this.getOutput() * m.getDelta();
- this.push(weight);
+ assertEquals(-0.006688234848481696, weight);
+ // this.push(weight);
}
}
@@ -74,10 +69,10 @@
public void testProp() throws IOException {
List<Synapse<DoubleWritable, DoubleWritable>> x = new ArrayList<Synapse<DoubleWritable, DoubleWritable>>();
- x.add(new Synapse<DoubleWritable, DoubleWritable>(new DoubleWritable(
- 1.0), new DoubleWritable(0.5)));
- x.add(new Synapse<DoubleWritable, DoubleWritable>(new DoubleWritable(
- 1.0), new DoubleWritable(0.4)));
+ x.add(new Synapse<DoubleWritable, DoubleWritable>(new DoubleWritable(1.0),
+ new DoubleWritable(0.5)));
+ x.add(new Synapse<DoubleWritable, DoubleWritable>(new DoubleWritable(1.0),
+ new DoubleWritable(0.4)));
MyNeuron n = new MyNeuron();
n.forward(x);
@@ -87,7 +82,6 @@
x.add(new Synapse<DoubleWritable, DoubleWritable>(new DoubleWritable(
-0.1274), new DoubleWritable(-1.2)));
n.backward(x);
- assertEquals(-0.006688234848481696, n.getUpdate());
}
}
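
For reference, the hard-coded expectation in backward() follows directly from the test inputs, assuming the elided forward() adds the bias term bias * theta before applying the Sigmoid (which is what the fields above suggest):

    sum    = 1.0 * 0.5 + 1.0 * 0.4 + (-1) * 0.8 = 0.1
    output = sigmoid(0.1)                       ≈ 0.524979
    weight = 0.1 * 0.524979 * (-0.1274)         ≈ -0.006688234848...

so the asserted constant -0.006688234848481696 is learningRate * output * delta.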