HORN-12: Add neuron-centric model based MLP example
diff --git a/src/main/java/org/apache/horn/core/AbstractLayeredNeuralNetwork.java b/src/main/java/org/apache/horn/core/AbstractLayeredNeuralNetwork.java
index f87e771..b0162ab 100644
--- a/src/main/java/org/apache/horn/core/AbstractLayeredNeuralNetwork.java
+++ b/src/main/java/org/apache/horn/core/AbstractLayeredNeuralNetwork.java
@@ -117,8 +117,9 @@
    *          is f(x) = x by default.
    * @return The layer index, starts with 0.
    */
+  @SuppressWarnings("rawtypes")
   public abstract int addLayer(int size, boolean isFinalLayer,
-      DoubleFunction squashingFunction);
+      DoubleFunction squashingFunction, Class<? extends Neuron> neuronClass);
 
   /**
    * Get the size of a particular layer.
diff --git a/src/main/java/org/apache/horn/core/AutoEncoder.java b/src/main/java/org/apache/horn/core/AutoEncoder.java
index f638245..35a6f66 100644
--- a/src/main/java/org/apache/horn/core/AutoEncoder.java
+++ b/src/main/java/org/apache/horn/core/AutoEncoder.java
@@ -53,11 +53,11 @@
   public AutoEncoder(int inputDimensions, int compressedDimensions) {
     model = new LayeredNeuralNetwork();
     model.addLayer(inputDimensions, false,
-        FunctionFactory.createDoubleFunction("Sigmoid"));
+        FunctionFactory.createDoubleFunction("Sigmoid"), null);
     model.addLayer(compressedDimensions, false,
-        FunctionFactory.createDoubleFunction("Sigmoid"));
+        FunctionFactory.createDoubleFunction("Sigmoid"), null);
     model.addLayer(inputDimensions, true,
-        FunctionFactory.createDoubleFunction("Sigmoid"));
+        FunctionFactory.createDoubleFunction("Sigmoid"), null);
     model
         .setLearningStyle(AbstractLayeredNeuralNetwork.LearningStyle.UNSUPERVISED);
     model.setCostFunction(FunctionFactory
diff --git a/src/main/java/org/apache/horn/core/HornJob.java b/src/main/java/org/apache/horn/core/HornJob.java
index 82dcad8..c95ae36 100644
--- a/src/main/java/org/apache/horn/core/HornJob.java
+++ b/src/main/java/org/apache/horn/core/HornJob.java
@@ -36,18 +36,28 @@
     neuralNetwork = new LayeredNeuralNetwork();
   }
 
-  public void inputLayer(int featureDimension, Class<? extends Function> func) {
-    addLayer(featureDimension, func);
-  }
-  
-  public void addLayer(int featureDimension, Class<? extends Function> func) {
-    neuralNetwork.addLayer(featureDimension, false,
-        FunctionFactory.createDoubleFunction(func.getSimpleName()));
+  @SuppressWarnings("rawtypes")
+  public void inputLayer(int featureDimension, Class<? extends Function> func,
+      Class<? extends Neuron> neuronClass) {
+    addLayer(featureDimension, func, neuronClass);
   }
 
-  public void outputLayer(int labels, Class<? extends Function> func) {
-    neuralNetwork.addLayer(labels, true,
-        FunctionFactory.createDoubleFunction(func.getSimpleName()));
+  @SuppressWarnings("rawtypes")
+  public void addLayer(int featureDimension, Class<? extends Function> func,
+      Class<? extends Neuron> neuronClass) {
+    neuralNetwork
+        .addLayer(featureDimension, false,
+            FunctionFactory.createDoubleFunction(func.getSimpleName()),
+            neuronClass);
+  }
+
+  @SuppressWarnings("rawtypes")
+  public void outputLayer(int labels, Class<? extends Function> func,
+      Class<? extends Neuron> neuronClass) {
+    neuralNetwork
+        .addLayer(labels, true,
+            FunctionFactory.createDoubleFunction(func.getSimpleName()),
+            neuronClass);
   }
 
   public void setCostFunction(Class<? extends Function> func) {
diff --git a/src/main/java/org/apache/horn/core/LayeredNeuralNetwork.java b/src/main/java/org/apache/horn/core/LayeredNeuralNetwork.java
index afccbff..32d6c64 100644
--- a/src/main/java/org/apache/horn/core/LayeredNeuralNetwork.java
+++ b/src/main/java/org/apache/horn/core/LayeredNeuralNetwork.java
@@ -64,10 +64,7 @@
  */
 public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
 
-  private static final Log LOG = LogFactory
-      .getLog(LayeredNeuralNetwork.class);
-
-  public static Class<Neuron<Synapse<DoubleWritable, DoubleWritable>>> neuronClass;
+  private static final Log LOG = LogFactory.getLog(LayeredNeuralNetwork.class);
 
   /* Weights between neurons at adjacent layers */
   protected List<DoubleMatrix> weightMatrixList;
@@ -78,6 +75,8 @@
   /* Different layers can have different squashing function */
   protected List<DoubleFunction> squashingFunctionList;
 
+  protected List<Class<? extends Neuron>> neuronClassList;
+
   protected int finalLayerIdx;
 
   protected double regularizationWeight;
@@ -87,6 +86,7 @@
     this.weightMatrixList = Lists.newArrayList();
     this.prevWeightUpdatesList = Lists.newArrayList();
     this.squashingFunctionList = Lists.newArrayList();
+    this.neuronClassList = Lists.newArrayList();
   }
 
   public LayeredNeuralNetwork(HamaConfiguration conf, String modelPath) {
@@ -99,7 +99,7 @@
    * {@inheritDoc}
    */
   public int addLayer(int size, boolean isFinalLayer,
-      DoubleFunction squashingFunction) {
+      DoubleFunction squashingFunction, Class<? extends Neuron> neuronClass) {
     Preconditions.checkArgument(size > 0,
         "Size of layer must be larger than 0.");
     if (!isFinalLayer) {
@@ -137,6 +137,7 @@
       this.weightMatrixList.add(weightMatrix);
       this.prevWeightUpdatesList.add(new DenseDoubleMatrix(row, col));
       this.squashingFunctionList.add(squashingFunction);
+      this.neuronClassList.add(neuronClass);
     }
     return layerIdx;
   }
@@ -223,6 +224,20 @@
   public void readFields(DataInput input) throws IOException {
     super.readFields(input);
 
+    // read neuron classes
+    int neuronClasses = input.readInt();
+    this.neuronClassList = Lists.newArrayList();
+    for (int i = 0; i < neuronClasses; ++i) {
+      try {
+        Class<? extends Neuron> clazz = (Class<? extends Neuron>) Class
+            .forName(input.readUTF());
+        neuronClassList.add(clazz);
+      } catch (ClassNotFoundException e) {
+        // TODO Auto-generated catch block
+        e.printStackTrace();
+      }
+    }
+
     // read squash functions
     int squashingFunctionSize = input.readInt();
     this.squashingFunctionList = Lists.newArrayList();
@@ -248,6 +263,12 @@
   public void write(DataOutput output) throws IOException {
     super.write(output);
 
+    // write neuron classes
+    output.writeInt(this.neuronClassList.size());
+    for (Class<? extends Neuron> clazz : this.neuronClassList) {
+      output.writeUTF(clazz.getName());
+    }
+
     // write squashing functions
     output.writeInt(this.squashingFunctionList.size());
     for (DoubleFunction aSquashingFunctionList : this.squashingFunctionList) {
@@ -319,9 +340,12 @@
   }
 
   /**
+   * @param neuronClass
    * @return a new neuron instance
    */
-  public static Neuron<Synapse<DoubleWritable, DoubleWritable>> newNeuronInstance() {
+  @SuppressWarnings({ "unchecked", "rawtypes" })
+  public static Neuron<Synapse<DoubleWritable, DoubleWritable>> newNeuronInstance(
+      Class<? extends Neuron> neuronClass) {
     return (Neuron<Synapse<DoubleWritable, DoubleWritable>>) ReflectionUtils
         .newInstance(neuronClass);
   }
@@ -333,13 +357,9 @@
    * @param intermediateOutput The intermediateOutput of previous layer.
    * @return a new vector with the result of the operation.
    */
-  @SuppressWarnings("unchecked")
   protected DoubleVector forward(int fromLayer, DoubleVector intermediateOutput) {
     DoubleMatrix weightMatrix = this.weightMatrixList.get(fromLayer);
 
-    neuronClass = (Class<Neuron<Synapse<DoubleWritable, DoubleWritable>>>) conf
-        .getClass("neuron.class", Neuron.class);
-
     // TODO use the multithread processing
     DoubleVector vec = new DenseDoubleVector(weightMatrix.getRowCount());
     for (int row = 0; row < weightMatrix.getRowCount(); row++) {
@@ -350,8 +370,8 @@
             new DoubleWritable(weightMatrix.get(row, col))));
       }
       Iterable<Synapse<DoubleWritable, DoubleWritable>> iterable = msgs;
-      Neuron<Synapse<DoubleWritable, DoubleWritable>> n = newNeuronInstance();
-      n.setup(conf);
+      Neuron<Synapse<DoubleWritable, DoubleWritable>> n = newNeuronInstance(this.neuronClassList
+          .get(fromLayer));
       n.setSquashingFunction(this.squashingFunctionList.get(fromLayer));
       try {
         n.forward(iterable);
@@ -524,9 +544,12 @@
     DoubleVector deltaVector = new DenseDoubleVector(
         weightMatrix.getColumnCount());
     for (int row = 0; row < weightMatrix.getColumnCount(); ++row) {
-      Neuron<Synapse<DoubleWritable, DoubleWritable>> n = newNeuronInstance();
+      Neuron<Synapse<DoubleWritable, DoubleWritable>> n = newNeuronInstance(this.neuronClassList
+          .get(curLayerIdx));
       // calls setup method
-      n.setup(conf);
+      n.setLearningRate(this.learningRate);
+      n.setMomentumWeight(this.momentumWeight);
+
       n.setSquashingFunction(this.squashingFunctionList.get(curLayerIdx));
       n.setOutput(curLayerOutput.get(row));
 
@@ -578,13 +601,10 @@
 
     // create job
     BSPJob job = new BSPJob(conf, LayeredNeuralNetworkTrainer.class);
-    job.setJobName("Small scale Neural Network training");
+    job.setJobName("Neural Network training");
     job.setJarByClass(LayeredNeuralNetworkTrainer.class);
     job.setBspClass(LayeredNeuralNetworkTrainer.class);
 
-    job.getConfiguration().setClass("neuron.class", StandardNeuron.class,
-        Neuron.class);
-
     // additional for parameter server
     // TODO at this moment, we use 1 task as a parameter server
     // In the future, the number of parameter server should be configurable
diff --git a/src/main/java/org/apache/horn/core/Neuron.java b/src/main/java/org/apache/horn/core/Neuron.java
index 357b42f..4471b45 100644
--- a/src/main/java/org/apache/horn/core/Neuron.java
+++ b/src/main/java/org/apache/horn/core/Neuron.java
@@ -17,23 +17,28 @@
  */
 package org.apache.horn.core;
 
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
 import org.apache.hadoop.io.Writable;
 import org.apache.hama.commons.math.DoubleFunction;
 
-public abstract class Neuron<M extends Writable> implements NeuronInterface<M> {
+public abstract class Neuron<M extends Writable> implements Writable, NeuronInterface<M> {
   double output;
   double weight;
   double delta;
+
+  double momentumWeight;
+  double learningRate;
+
   protected DoubleFunction squashingFunction;
 
   public void feedforward(double sum) {
-    // TODO Auto-generated method stub
-    // squashing
     this.output = sum;
   }
 
   public void backpropagate(double gradient) {
-    // TODO Auto-generated method stub
     this.delta = gradient;
   }
 
@@ -53,7 +58,24 @@
     return output;
   }
 
-  // ////////* Below methods will communicate with parameter server */
+  public void setMomentumWeight(double momentumWeight) {
+    this.momentumWeight = momentumWeight;
+  }
+
+  public double getMomentumWeight() {
+    return momentumWeight;
+  }
+
+  public void setLearningRate(double learningRate) {
+    this.learningRate = learningRate;
+  }
+
+  public double getLearningRate() {
+    return learningRate;
+  }
+
+  // ////////
+
   private int i;
 
   public void push(double weight) {
@@ -79,4 +101,24 @@
     this.squashingFunction = squashingFunction;
   }
 
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    output = in.readDouble();
+    weight = in.readDouble();
+    delta = in.readDouble();
+
+    momentumWeight = in.readDouble();
+    learningRate = in.readDouble();
+  }
+
+  @Override
+  public void write(DataOutput out) throws IOException {
+    out.writeDouble(output);
+    out.writeDouble(weight);
+    out.writeDouble(delta);
+    
+    out.writeDouble(momentumWeight);
+    out.writeDouble(learningRate);
+  }
+
 }
diff --git a/src/main/java/org/apache/horn/core/NeuronInterface.java b/src/main/java/org/apache/horn/core/NeuronInterface.java
index 5e4c113..ef5a2d3 100644
--- a/src/main/java/org/apache/horn/core/NeuronInterface.java
+++ b/src/main/java/org/apache/horn/core/NeuronInterface.java
@@ -20,15 +20,12 @@
 import java.io.IOException;
 
 import org.apache.hadoop.io.Writable;
-import org.apache.hama.HamaConfiguration;
 
 public interface NeuronInterface<M extends Writable> {
 
-  public void setup(HamaConfiguration conf);
-  
   /**
-   * This method is called when the messages are propagated from the lower
-   * layer. It can be used to determine if the neuron would activate, or fire.
+   * This method is called when the messages are propagated from the next layer.
+   * It can be used to determine if the neuron would activate, or fire.
    * 
    * @param messages
    * @throws IOException
@@ -36,13 +33,13 @@
   public void forward(Iterable<M> messages) throws IOException;
 
   /**
-   * This method is called when the errors are propagated from the upper layer.
-   * It can be used to calculate the error of each neuron and change the
+   * This method is called when the errors are propagated from the previous
+   * layer. It can be used to calculate the error of each neuron and change the
    * weights.
    * 
    * @param messages
    * @throws IOException
    */
   public void backward(Iterable<M> messages) throws IOException;
-  
+
 }
diff --git a/src/main/java/org/apache/horn/examples/MultiLayerPerceptron.java b/src/main/java/org/apache/horn/examples/MultiLayerPerceptron.java
index c3bf180..90c9db4 100644
--- a/src/main/java/org/apache/horn/examples/MultiLayerPerceptron.java
+++ b/src/main/java/org/apache/horn/examples/MultiLayerPerceptron.java
@@ -31,14 +31,6 @@
 
   public static class StandardNeuron extends
       Neuron<Synapse<DoubleWritable, DoubleWritable>> {
-    private double learningRate;
-    private double momentum;
-
-    @Override
-    public void setup(HamaConfiguration conf) {
-      this.learningRate = conf.getDouble("mlp.learning.rate", 0.5);
-      this.momentum = conf.getDouble("mlp.momentum.weight", 0.2);
-    }
 
     @Override
     public void forward(
@@ -63,8 +55,8 @@
         this.backpropagate(gradient);
 
         // Weight corrections
-        double weight = -learningRate * this.getOutput() * m.getDelta()
-            + momentum * m.getPrevWeight();
+        double weight = -this.getLearningRate() * this.getOutput()
+            * m.getDelta() + this.getMomentumWeight() * m.getPrevWeight();
         this.push(weight);
       }
     }
@@ -88,9 +80,9 @@
     job.setConvergenceCheckInterval(1000);
     job.setBatchSize(300);
 
-    job.inputLayer(features, Sigmoid.class);
-    job.addLayer(features, Sigmoid.class);
-    job.outputLayer(labels, Sigmoid.class);
+    job.inputLayer(features, Sigmoid.class, StandardNeuron.class);
+    job.addLayer(features, Sigmoid.class, StandardNeuron.class);
+    job.outputLayer(labels, Sigmoid.class, StandardNeuron.class);
 
     job.setCostFunction(CrossEntropy.class);
 
@@ -101,9 +93,12 @@
       InterruptedException, ClassNotFoundException {
     if (args.length < 9) {
       System.out
-          .println("Usage: <MODEL_PATH> <INPUT_PATH> <LEARNING_RATE> <MOMEMTUM_WEIGHT> <REGULARIZATION_WEIGHT> <FEATURE_DIMENSION> <LABEL_DIMENSION> <MAX_ITERATION> <NUM_TASKS>");
+          .println("Usage: <MODEL_PATH> <INPUT_PATH> "
+              + "<LEARNING_RATE> <MOMEMTUM_WEIGHT> <REGULARIZATION_WEIGHT> "
+              + "<FEATURE_DIMENSION> <LABEL_DIMENSION> <MAX_ITERATION> <NUM_TASKS>");
       System.exit(1);
     }
+
     HornJob ann = createJob(new HamaConfiguration(), args[0], args[1],
         Double.parseDouble(args[2]), Double.parseDouble(args[3]),
         Double.parseDouble(args[4]), Integer.parseInt(args[5]),
diff --git a/src/main/java/org/apache/horn/utils/MNISTConverter.java b/src/main/java/org/apache/horn/utils/MNISTConverter.java
index 99742d6..224fc4b 100644
--- a/src/main/java/org/apache/horn/utils/MNISTConverter.java
+++ b/src/main/java/org/apache/horn/utils/MNISTConverter.java
@@ -68,6 +68,7 @@
     HamaConfiguration conf = new HamaConfiguration();
     FileSystem fs = FileSystem.get(conf);
 
+    @SuppressWarnings("deprecation")
     SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, new Path(
         output), LongWritable.class, VectorWritable.class);
 
@@ -81,6 +82,8 @@
           new DenseDoubleVector(vals)));
     }
     
+    imagesIn.close();
+    labelsIn.close();
     writer.close();
   }
 }
diff --git a/src/test/java/org/apache/horn/core/TestNeuron.java b/src/test/java/org/apache/horn/core/TestNeuron.java
index f2fe4e1..5f2bb59 100644
--- a/src/test/java/org/apache/horn/core/TestNeuron.java
+++ b/src/test/java/org/apache/horn/core/TestNeuron.java
@@ -24,13 +24,10 @@
 import junit.framework.TestCase;
 
 import org.apache.hadoop.io.DoubleWritable;
-import org.apache.hama.HamaConfiguration;
-import org.apache.horn.core.Neuron;
-import org.apache.horn.core.Synapse;
 import org.apache.horn.funcs.Sigmoid;
 
 public class TestNeuron extends TestCase {
-  private static double learningRate = 0.1;
+  private static double learningrate = 0.1;
   private static double bias = -1;
   private static double theta = 0.8;
 
@@ -38,10 +35,6 @@
       Neuron<Synapse<DoubleWritable, DoubleWritable>> {
 
     @Override
-    public void setup(HamaConfiguration conf) {
-    }
-
-    @Override
     public void forward(
         Iterable<Synapse<DoubleWritable, DoubleWritable>> messages)
         throws IOException {
@@ -59,14 +52,16 @@
         throws IOException {
       for (Synapse<DoubleWritable, DoubleWritable> m : messages) {
         // Calculates error gradient for each neuron
-        double gradient = new Sigmoid().applyDerivative(this.getOutput()) * (m.getDelta() * m.getWeight());
+        double gradient = new Sigmoid().applyDerivative(this.getOutput())
+            * (m.getDelta() * m.getWeight());
 
         // Propagates to lower layer
         backpropagate(gradient);
 
         // Weight corrections
-        double weight = learningRate * this.getOutput() * m.getDelta();
-        this.push(weight);
+        double weight = learningrate * this.getOutput() * m.getDelta();
+        assertEquals(-0.006688234848481696, weight);
+        // this.push(weight);
       }
     }
 
@@ -74,10 +69,10 @@
 
   public void testProp() throws IOException {
     List<Synapse<DoubleWritable, DoubleWritable>> x = new ArrayList<Synapse<DoubleWritable, DoubleWritable>>();
-    x.add(new Synapse<DoubleWritable, DoubleWritable>(new DoubleWritable(
-        1.0), new DoubleWritable(0.5)));
-    x.add(new Synapse<DoubleWritable, DoubleWritable>(new DoubleWritable(
-        1.0), new DoubleWritable(0.4)));
+    x.add(new Synapse<DoubleWritable, DoubleWritable>(new DoubleWritable(1.0),
+        new DoubleWritable(0.5)));
+    x.add(new Synapse<DoubleWritable, DoubleWritable>(new DoubleWritable(1.0),
+        new DoubleWritable(0.4)));
 
     MyNeuron n = new MyNeuron();
     n.forward(x);
@@ -87,7 +82,6 @@
     x.add(new Synapse<DoubleWritable, DoubleWritable>(new DoubleWritable(
         -0.1274), new DoubleWritable(-1.2)));
     n.backward(x);
-    assertEquals(-0.006688234848481696, n.getUpdate());
   }
 
 }