HORN-23: Add softmax function
diff --git a/README.md b/README.md
index cb217f6..b797783 100644
--- a/README.md
+++ b/README.md
@@ -49,9 +49,8 @@
   ..
 
   job.inputLayer(784, Sigmoid.class, StandardNeuron.class);
-  job.addLayer(500, Sigmoid.class, StandardNeuron.class);
-  job.addLayer(500, Sigmoid.class, StandardNeuron.class);
-  job.outputLayer(10, Sigmoid.class, StandardNeuron.class);
+  job.addLayer(100, Sigmoid.class, StandardNeuron.class);
+  job.outputLayer(10, SoftMax.class, StandardNeuron.class);
   job.setCostFunction(CrossEntropy.class);
 ```
 
@@ -59,16 +58,18 @@
 
 Download the MNIST training and label datasets, and convert them into an HDFS sequence file with the following command:
 ```
- % bin/horn jar horn-0.x.0.jar MNISTConverter train-images.idx3-ubyte train-labels.idx1-ubyte /tmp/mnist.seq 
+ % bin/horn jar horn-0.x.0.jar MNISTConverter \
+   train-images.idx3-ubyte train-labels.idx1-ubyte /tmp/mnist.seq 
 ```
 
-Then, train it with following command (in this example, we used η 0.002, λ 0.1, 100 hidden units, and minibatch 10):
+Then, train it with the following command (in this example, we used η 0.01, α 0.9, λ 0.00075, 100 hidden units, and a mini-batch size of 10):
 ```
  % bin/horn jar horn-0.x.0.jar MultiLayerPerceptron /tmp/model /tmp/mnist.seq \
-   0.002 0.0 0.1 784 100 10 10 12000
- 
+   0.01 0.9 0.00075 784 100 10 10 12000
 ```
 
+With this default example, you'll reach over 95% accuracy. In local mode, the multithread-based parallel synchronous SGD takes around one hour to train.
+
 ## High Scalability
 
 Apache Horn is a hybrid synchronous and asynchronous distributed training framework. Within a single BSP job, each task group works asynchronously using region barrier synchronization instead of global barrier synchronization, and trains the large-scale neural network model on its assigned data set in a synchronous way.
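
Note: the eight positional arguments after the model and data paths appear to follow the parameter order of `MultiLayerPerceptron.createJob` (learning rate, momentum, regularization weight, input features, hidden units, output labels, mini-batch size, max iterations). The sketch below expresses roughly the same configuration through the `HornJob` API used elsewhere in this patch; treat the argument-to-setter mapping as an assumption rather than documented behaviour.

```
// A rough sketch of the training command above expressed through the HornJob API
// (as used by MultiLayerPerceptron.createJob in this patch). The mapping of the
// positional CLI arguments onto these setters is an assumption.
HamaConfiguration conf = new HamaConfiguration();
HornJob job = new HornJob(conf, MultiLayerPerceptron.class);
job.setModelPath("/tmp/model");
job.setTrainingSetPath("/tmp/mnist.seq");

job.setLearningRate(0.01);            // η
job.setMomentumWeight(0.9);           // α
job.setRegularizationWeight(0.00075); // λ, the value used in the command line
job.setBatchSize(10);                 // mini-batch size
job.setMaxIteration(12000);

job.inputLayer(784, Sigmoid.class, StandardNeuron.class);
job.addLayer(100, Sigmoid.class, StandardNeuron.class);
job.outputLayer(10, SoftMax.class, StandardNeuron.class);
job.setCostFunction(CrossEntropy.class);
```
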
diff --git a/bin/horn b/bin/horn
index e697695..8cbd106 100755
--- a/bin/horn
+++ b/bin/horn
@@ -58,7 +58,7 @@
 fi
 
 JAVA=$JAVA_HOME/bin/java
-JAVA_HEAP_MAX=-Xmx512m
+JAVA_HEAP_MAX=-Xmx2048m
 
 # check envvars which might override default args
 if [ "$HORN_HEAPSIZE" != "" ]; then
diff --git a/conf/horn-env.sh b/conf/horn-env.sh
index a033fe0..ca7ed32 100644
--- a/conf/horn-env.sh
+++ b/conf/horn-env.sh
@@ -22,5 +22,5 @@
 # Set environment variables here.
 
 # The java implementation to use.  Required.
-export JAVA_HOME=/usr/lib/jvm/java-8-oracle/
+export JAVA_HOME=/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home
 
diff --git a/src/main/java/org/apache/horn/core/AbstractLayeredNeuralNetwork.java b/src/main/java/org/apache/horn/core/AbstractLayeredNeuralNetwork.java
index e415a25..b82ad41 100644
--- a/src/main/java/org/apache/horn/core/AbstractLayeredNeuralNetwork.java
+++ b/src/main/java/org/apache/horn/core/AbstractLayeredNeuralNetwork.java
@@ -30,7 +30,10 @@
 import org.apache.hama.commons.math.DoubleVector;
 import org.apache.horn.core.Constants.LearningStyle;
 import org.apache.horn.core.Constants.TrainingMethod;
+import org.apache.horn.funcs.CategoricalCrossEntropy;
+import org.apache.horn.funcs.CrossEntropy;
 import org.apache.horn.funcs.FunctionFactory;
+import org.mortbay.log.Log;
 
 import com.google.common.base.Preconditions;
 import com.google.common.collect.Lists;
@@ -65,7 +68,7 @@
   protected List<Integer> layerSizeList;
 
   protected TrainingMethod trainingMethod;
-  
+
   protected LearningStyle learningStyle;
 
   public AbstractLayeredNeuralNetwork() {
@@ -77,6 +80,11 @@
 
   public AbstractLayeredNeuralNetwork(HamaConfiguration conf, String modelPath) {
     super(conf, modelPath);
+    if (this.layerSizeList.get(this.layerSizeList.size() - 1) > 1
+        && this.costFunction.getFunctionName().equalsIgnoreCase(
+            CrossEntropy.class.getSimpleName())) {
+      this.setCostFunction(new CategoricalCrossEntropy());
+    }
   }
 
   /**
@@ -118,11 +126,11 @@
   public TrainingMethod getTrainingMethod() {
     return this.trainingMethod;
   }
-  
+
   public void setLearningStyle(LearningStyle style) {
     this.learningStyle = style;
   }
-  
+
   public LearningStyle getLearningStyle() {
     return this.learningStyle;
   }
diff --git a/src/main/java/org/apache/horn/core/AbstractNeuralNetwork.java b/src/main/java/org/apache/horn/core/AbstractNeuralNetwork.java
index 5624e49..77d6af0 100644
--- a/src/main/java/org/apache/horn/core/AbstractNeuralNetwork.java
+++ b/src/main/java/org/apache/horn/core/AbstractNeuralNetwork.java
@@ -24,7 +24,6 @@
 import java.lang.reflect.InvocationTargetException;
 import java.net.URI;
 import java.net.URISyntaxException;
-import java.util.Map;
 
 import org.apache.commons.lang.SerializationUtils;
 import org.apache.hadoop.conf.Configuration;
diff --git a/src/main/java/org/apache/horn/core/IntermediateOutput.java b/src/main/java/org/apache/horn/core/IntermediateOutput.java
new file mode 100644
index 0000000..272fed0
--- /dev/null
+++ b/src/main/java/org/apache/horn/core/IntermediateOutput.java
@@ -0,0 +1,23 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.core;
+
+
+public abstract class IntermediateOutput implements LayerInterface {
+
+}
diff --git a/src/main/java/org/apache/horn/core/LayerInterface.java b/src/main/java/org/apache/horn/core/LayerInterface.java
new file mode 100644
index 0000000..c010cc9
--- /dev/null
+++ b/src/main/java/org/apache/horn/core/LayerInterface.java
@@ -0,0 +1,28 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.core;
+
+import java.io.IOException;
+
+import org.apache.hama.commons.math.DoubleVector;
+
+public interface LayerInterface {
+
+  public DoubleVector interlayer(DoubleVector intermediateOutput) throws IOException;
+  
+}
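
The two new types above define the hook that SoftMax uses: a layer-level post-processing step that sees the whole raw output vector at once, rather than one neuron at a time. A minimal, hypothetical implementation (a no-op pass-through, not part of this patch) just to illustrate the contract:

```
package org.apache.horn.core;

import java.io.IOException;

import org.apache.hama.commons.math.DoubleVector;

// Hypothetical example only: returns the layer output unchanged. A real handler,
// such as SoftMax.SoftMaxOutputComputer in this patch, transforms the whole vector here.
public class IdentityOutputComputer extends IntermediateOutput {

  @Override
  public DoubleVector interlayer(DoubleVector output) throws IOException {
    return output;
  }
}
```
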
diff --git a/src/main/java/org/apache/horn/core/LayeredNeuralNetwork.java b/src/main/java/org/apache/horn/core/LayeredNeuralNetwork.java
index fe5d3a3..d4f2f3e 100644
--- a/src/main/java/org/apache/horn/core/LayeredNeuralNetwork.java
+++ b/src/main/java/org/apache/horn/core/LayeredNeuralNetwork.java
@@ -46,6 +46,7 @@
 import org.apache.horn.core.Constants.LearningStyle;
 import org.apache.horn.core.Constants.TrainingMethod;
 import org.apache.horn.funcs.FunctionFactory;
+import org.apache.horn.funcs.SoftMax;
 
 import com.google.common.base.Preconditions;
 import com.google.common.collect.Lists;
@@ -76,7 +77,7 @@
   protected List<DoubleFunction> squashingFunctionList;
 
   protected List<Class<? extends Neuron>> neuronClassList;
-  
+
   protected int finalLayerIdx;
 
   public LayeredNeuralNetwork() {
@@ -97,9 +98,20 @@
    */
   public int addLayer(int size, boolean isFinalLayer,
       DoubleFunction squashingFunction, Class<? extends Neuron> neuronClass) {
+    return addLayer(size, isFinalLayer, squashingFunction, neuronClass, null);
+  }
+
+  public int addLayer(int size, boolean isFinalLayer,
+      DoubleFunction squashingFunction, Class<? extends Neuron> neuronClass,
+      Class<? extends IntermediateOutput> interlayer) {
     Preconditions.checkArgument(size > 0,
         "Size of layer must be larger than 0.");
     if (!isFinalLayer) {
+      if (this.layerSizeList.size() == 0) {
+        LOG.info("add input layer: " + size + " neurons");
+      } else {
+        LOG.info("add hidden layer: " + size + " neurons");
+      }
       size += 1;
     }
 
@@ -107,6 +119,7 @@
     int layerIdx = this.layerSizeList.size() - 1;
     if (isFinalLayer) {
       this.finalLayerIdx = layerIdx;
+      LOG.info("add output layer: " + size + " neurons");
     }
 
     // add weights between current layer and previous layer, and input layer has
@@ -133,6 +146,7 @@
       this.weightMatrixList.add(weightMatrix);
       this.prevWeightUpdatesList.add(new DenseDoubleMatrix(row, col));
       this.squashingFunctionList.add(squashingFunction);
+
       this.neuronClassList.add(neuronClass);
     }
     return layerIdx;
@@ -152,6 +166,7 @@
 
   /**
    * Set the previous weight matrices.
+   * 
    * @param prevUpdates
    */
   void setPrevWeightMatrices(DoubleMatrix[] prevUpdates) {
@@ -263,12 +278,12 @@
     for (Class<? extends Neuron> clazz : this.neuronClassList) {
       output.writeUTF(clazz.getName());
     }
-    
+
     // write squashing functions
     output.writeInt(this.squashingFunctionList.size());
     for (DoubleFunction aSquashingFunctionList : this.squashingFunctionList) {
-      WritableUtils.writeString(output, aSquashingFunctionList
-              .getFunctionName());
+      WritableUtils.writeString(output,
+          aSquashingFunctionList.getFunctionName());
     }
 
     // write weight matrices
@@ -327,21 +342,18 @@
     outputCache.add(intermediateOutput);
 
     for (int i = 0; i < this.layerSizeList.size() - 1; ++i) {
-      intermediateOutput = forward(i, intermediateOutput);
-      outputCache.add(intermediateOutput);
+      forward(i, outputCache);
     }
     return outputCache;
   }
-  
+
   /**
    * @param neuronClass
    * @return a new neuron instance
    */
-  @SuppressWarnings({ "unchecked", "rawtypes" })
-  public static Neuron<Synapse<DoubleWritable, DoubleWritable>> newNeuronInstance(
-      Class<? extends Neuron> neuronClass) {
-    return (Neuron<Synapse<DoubleWritable, DoubleWritable>>) ReflectionUtils
-        .newInstance(neuronClass);
+  @SuppressWarnings({ "rawtypes" })
+  public static Neuron newNeuronInstance(Class<? extends Neuron> neuronClass) {
+    return (Neuron) ReflectionUtils.newInstance(neuronClass);
   }
 
   /**
@@ -351,25 +363,33 @@
    * @param intermediateOutput The intermediateOutput of previous layer.
    * @return a new vector with the result of the operation.
    */
-  protected DoubleVector forward(int fromLayer, DoubleVector intermediateOutput) {
+  protected void forward(int fromLayer, List<DoubleVector> outputCache) {
+    // intermediate outputs are interleaved in the cache, so the previous
+    // layer's output sits at index fromLayer * 2
+    DoubleVector previousOutput = outputCache.get(fromLayer * 2);
+
     DoubleMatrix weightMatrix = this.weightMatrixList.get(fromLayer);
 
     // LOG.info("intermediate: " + intermediateOutput.toString());
     // DoubleVector vec = weightMatrix.multiplyVectorUnsafe(intermediateOutput);
     // vec = vec.applyToElements(this.squashingFunctionList.get(fromLayer));
-   
+
+    DoubleFunction squashingFunction = getSquashingFunction(fromLayer);
+
     DoubleVector vec = new DenseDoubleVector(weightMatrix.getRowCount());
+
     for (int row = 0; row < weightMatrix.getRowCount(); row++) {
       List<Synapse<DoubleWritable, DoubleWritable>> msgs = new ArrayList<Synapse<DoubleWritable, DoubleWritable>>();
       for (int col = 0; col < weightMatrix.getColumnCount(); col++) {
         msgs.add(new Synapse<DoubleWritable, DoubleWritable>(
-            new DoubleWritable(intermediateOutput.get(col)),
-            new DoubleWritable(weightMatrix.get(row, col))));
+            new DoubleWritable(previousOutput.get(col)), new DoubleWritable(
+                weightMatrix.get(row, col))));
       }
       Iterable<Synapse<DoubleWritable, DoubleWritable>> iterable = msgs;
-      Neuron<Synapse<DoubleWritable, DoubleWritable>> n = newNeuronInstance(this.neuronClassList
-          .get(fromLayer));
-      n.setSquashingFunction(this.squashingFunctionList.get(fromLayer));
+      Neuron n = newNeuronInstance(this.neuronClassList.get(fromLayer));
+      n.setSquashingFunction(squashingFunction);
+      n.setLayerIndex(fromLayer);
+
       try {
         n.forward(iterable);
       } catch (IOException e) {
@@ -378,14 +398,30 @@
       }
       vec.set(row, n.getOutput());
     }
-    
+
+    if (squashingFunction.getFunctionName().equalsIgnoreCase(
+        SoftMax.class.getSimpleName())) {
+      IntermediateOutput interlayer = (IntermediateOutput) ReflectionUtils
+          .newInstance(SoftMax.SoftMaxOutputComputer.class);
+      try {
+        outputCache.add(vec);
+        vec = interlayer.interlayer(vec);
+      } catch (IOException e) {
+        // TODO Auto-generated catch block
+        e.printStackTrace();
+      }
+    } else {
+      outputCache.add(null);
+    }
+
     // add bias
     DoubleVector vecWithBias = new DenseDoubleVector(vec.getDimension() + 1);
     vecWithBias.set(0, 1);
     for (int i = 0; i < vec.getDimension(); ++i) {
       vecWithBias.set(i + 1, vec.get(i));
     }
-    return vecWithBias;
+
+    outputCache.add(vecWithBias);
   }
 
   /**
@@ -472,6 +508,7 @@
       List<DoubleVector> internalResults) {
 
     DoubleVector output = internalResults.get(internalResults.size() - 1);
+
     // initialize weight update matrices
     DenseDoubleMatrix[] weightUpdateMatrices = new DenseDoubleMatrix[this.weightMatrixList
         .size()];
@@ -487,22 +524,27 @@
 
     DoubleMatrix lastWeightMatrix = this.weightMatrixList
         .get(this.weightMatrixList.size() - 1);
+
     for (int i = 0; i < deltaVec.getDimension(); ++i) {
       double costFuncDerivative = this.costFunction.applyDerivative(
           labels.get(i), output.get(i + 1));
       // add regularization
       costFuncDerivative += this.regularizationWeight
           * lastWeightMatrix.getRowVector(i).sum();
-      deltaVec.set(
-          i,
-          costFuncDerivative
-              * squashingFunction.applyDerivative(output.get(i + 1)));
+
+      if (!squashingFunction.getFunctionName().equalsIgnoreCase(
+          SoftMax.class.getSimpleName())) {
+        costFuncDerivative *= squashingFunction.applyDerivative(output
+            .get(i + 1));
+      }
+
+      deltaVec.set(i, costFuncDerivative);
     }
 
     // start from previous layer of output layer
     for (int layer = this.layerSizeList.size() - 2; layer >= 0; --layer) {
-      output = internalResults.get(layer);
-      deltaVec = backpropagate(layer, deltaVec, internalResults,
+      output = internalResults.get(layer * 2); // skip intermediate output
+      deltaVec = backpropagate(layer, deltaVec, output,
           weightUpdateMatrices[layer]);
     }
 
@@ -521,13 +563,12 @@
    * @return the squashing function of the specified position.
    */
   private DoubleVector backpropagate(int curLayerIdx,
-      DoubleVector nextLayerDelta, List<DoubleVector> outputCache,
+      DoubleVector nextLayerDelta, DoubleVector curLayerOutput,
       DenseDoubleMatrix weightUpdateMatrix) {
 
     // get layer related information
     DoubleFunction squashingFunction = this.squashingFunctionList
         .get(curLayerIdx);
-    DoubleVector curLayerOutput = outputCache.get(curLayerIdx);
     DoubleMatrix weightMatrix = this.weightMatrixList.get(curLayerIdx);
     DoubleMatrix prevWeightMatrix = this.prevWeightUpdatesList.get(curLayerIdx);
 
@@ -536,16 +577,16 @@
       nextLayerDelta = nextLayerDelta.slice(1,
           nextLayerDelta.getDimension() - 1);
     }
-    
+
     DoubleVector deltaVector = new DenseDoubleVector(
         weightMatrix.getColumnCount());
-    
+
     for (int row = 0; row < weightMatrix.getColumnCount(); ++row) {
-      Neuron<Synapse<DoubleWritable, DoubleWritable>> n = newNeuronInstance(this.neuronClassList
-          .get(curLayerIdx));
+      Neuron n = newNeuronInstance(this.neuronClassList.get(curLayerIdx));
       // calls setup method
       n.setLearningRate(this.learningRate);
       n.setMomentumWeight(this.momentumWeight);
+      n.setLayerIndex(curLayerIdx);
 
       n.setSquashingFunction(squashingFunction);
       n.setOutput(curLayerOutput.get(row));
@@ -568,7 +609,7 @@
         // TODO Auto-generated catch block
         e.printStackTrace();
       }
-      
+
       // update weights
       weightUpdateMatrix.setColumn(row, n.getWeights());
       deltaVector.set(row, n.getDelta());
@@ -628,5 +669,5 @@
   public DoubleFunction getSquashingFunction(int idx) {
     return this.squashingFunctionList.get(idx);
   }
-  
+
 }
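
One note on the backpropagation change above: when the output squashing function is SoftMax, the output-layer delta is left as the raw cost derivative instead of being multiplied by `squashingFunction.applyDerivative`. This is the standard softmax-with-cross-entropy simplification: with softmax outputs o_i and a one-hot target y, the gradient of the categorical cross-entropy with respect to the pre-activations collapses to o_i − y_i, which is exactly what `CategoricalCrossEntropy.applyDerivative` already returns.

```
E = -\sum_k y_k \log o_k, \qquad o_i = \frac{e^{z_i}}{\sum_j e^{z_j}}

\frac{\partial E}{\partial z_i}
  = \sum_k \frac{\partial E}{\partial o_k}\,\frac{\partial o_k}{\partial z_i}
  = o_i \sum_k y_k - y_i
  = o_i - y_i \qquad \left(\text{since } \sum_k y_k = 1\right)
```
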
diff --git a/src/main/java/org/apache/horn/core/LayeredNeuralNetworkTrainer.java b/src/main/java/org/apache/horn/core/LayeredNeuralNetworkTrainer.java
index ce6d6e4..350200f 100644
--- a/src/main/java/org/apache/horn/core/LayeredNeuralNetworkTrainer.java
+++ b/src/main/java/org/apache/horn/core/LayeredNeuralNetworkTrainer.java
@@ -18,6 +18,9 @@
 package org.apache.horn.core;
 
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
@@ -41,8 +44,9 @@
     extends
     BSP<LongWritable, VectorWritable, NullWritable, NullWritable, ParameterMessage> {
 
-  private static final Log LOG = LogFactory.getLog(LayeredNeuralNetworkTrainer.class);
-  
+  private static final Log LOG = LogFactory
+      .getLog(LayeredNeuralNetworkTrainer.class);
+
   private LayeredNeuralNetwork inMemoryModel;
   private HamaConfiguration conf;
   /* Default batch size */
@@ -90,10 +94,9 @@
     // write model to modelPath
     if (peer.getPeerIndex() == 0) {
       try {
-        LOG.info(String.format("End of training, number of iterations: %d.\n",
+        LOG.info(String.format("End of training, number of iterations: %d.",
             this.iterations));
-        LOG.info(String.format("Write model back to %s\n",
-            inMemoryModel.getModelPath()));
+        LOG.info(String.format("Write model back to %s", inMemoryModel.getModelPath()));
         this.inMemoryModel.writeModelToFile();
       } catch (IOException e) {
         e.printStackTrace();
@@ -101,10 +104,21 @@
     }
   }
 
+  private List<DoubleVector> trainingSet = new ArrayList<DoubleVector>();
+  private Random r = new Random();
+
   @Override
   public void bsp(
       BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, ParameterMessage> peer)
       throws IOException, SyncException, InterruptedException {
+    // load local data into memory
+    LongWritable key = new LongWritable();
+    VectorWritable value = new VectorWritable();
+    while (peer.readNext(key, value)) {
+      DoubleVector v = value.getVector();
+      trainingSet.add(v);
+    }
+
     while (this.iterations++ < maxIterations) {
       // each groom calculate the matrices updates according to local data
       calculateUpdates(peer);
@@ -121,6 +135,10 @@
     }
   }
 
+  private DoubleVector getRandomInstance() {
+    return trainingSet.get(r.nextInt(trainingSet.size()));
+  }
+
   /**
    * Calculate the matrices updates according to local partition of data.
    * 
@@ -154,18 +172,8 @@
 
     // continue to train
     double avgTrainingError = 0.0;
-    LongWritable key = new LongWritable();
-    VectorWritable value = new VectorWritable();
     for (int recordsRead = 0; recordsRead < batchSize; ++recordsRead) {
-      if (!peer.readNext(key, value)) {
-        peer.reopenInput();
-        if (peer.getPeerIndex() == 0) {
-          epoch++;
-          LOG.info("Training loss: " + curAvgTrainingError + " at " + (epoch) + " epoch.");
-        }
-        peer.readNext(key, value);
-      }
-      DoubleVector trainingInstance = value.getVector();
+      DoubleVector trainingInstance = getRandomInstance();
       LayeredNeuralNetwork.matricesAdd(weightUpdates,
           this.inMemoryModel.trainByInstance(trainingInstance));
       avgTrainingError += this.inMemoryModel.trainingError;
@@ -179,8 +187,8 @@
 
     DoubleMatrix[] prevWeightUpdates = this.inMemoryModel
         .getPrevMatricesUpdates();
-    ParameterMessage outMessage = new ParameterMessage(
-        avgTrainingError, false, weightUpdates, prevWeightUpdates);
+    ParameterMessage outMessage = new ParameterMessage(avgTrainingError, false,
+        weightUpdates, prevWeightUpdates);
     peer.send(peer.getPeerName(0), outMessage);
   }
 
@@ -215,7 +223,7 @@
         LayeredNeuralNetwork.matricesAdd(prevMatricesUpdates,
             message.getPrevMatrices());
       }
-      
+
       avgTrainingError += message.getTrainingError();
     }
 
@@ -229,7 +237,7 @@
 
     this.inMemoryModel.updateWeightMatrices(matricesUpdates);
     this.inMemoryModel.setPrevWeightMatrices(prevMatricesUpdates);
-    
+
     // check convergence
     if (iterations % convergenceCheckInterval == 0) {
       if (prevAvgTrainingError < curAvgTrainingError) {
@@ -238,14 +246,16 @@
       }
       // update
       prevAvgTrainingError = curAvgTrainingError;
+      LOG.info("Training error: " + curAvgTrainingError + " at " + (iterations)
+          + " iteration.");
       curAvgTrainingError = 0;
     }
     curAvgTrainingError += avgTrainingError / convergenceCheckInterval;
 
     // broadcast updated weight matrices
     for (String peerName : peer.getAllPeerNames()) {
-      ParameterMessage msg = new ParameterMessage(
-          0, isConverge, this.inMemoryModel.getWeightMatrices(),
+      ParameterMessage msg = new ParameterMessage(0, isConverge,
+          this.inMemoryModel.getWeightMatrices(),
           this.inMemoryModel.getPrevMatricesUpdates());
       peer.send(peerName, msg);
     }
diff --git a/src/main/java/org/apache/horn/core/Neuron.java b/src/main/java/org/apache/horn/core/Neuron.java
index 4471b45..af18c79 100644
--- a/src/main/java/org/apache/horn/core/Neuron.java
+++ b/src/main/java/org/apache/horn/core/Neuron.java
@@ -25,6 +25,7 @@
 import org.apache.hama.commons.math.DoubleFunction;
 
 public abstract class Neuron<M extends Writable> implements Writable, NeuronInterface<M> {
+  int id;
   double output;
   double weight;
   double delta;
@@ -32,8 +33,27 @@
   double momentumWeight;
   double learningRate;
 
+  int layerIndex;
+  boolean isOutputLayer;
+  
   protected DoubleFunction squashingFunction;
 
+  public void setNeuronID(int id) {
+    this.id = id;
+  }
+  
+  public int getID() {
+    return id;
+  }
+  
+  public int getLayerIndex() {
+    return layerIndex;
+  }
+
+  public void setLayerIndex(int index) {
+    this.layerIndex = index;
+  }
+  
   public void feedforward(double sum) {
     this.output = sum;
   }
@@ -103,6 +123,7 @@
 
   @Override
   public void readFields(DataInput in) throws IOException {
+    id = in.readInt();
     output = in.readDouble();
     weight = in.readDouble();
     delta = in.readDouble();
@@ -113,6 +134,7 @@
 
   @Override
   public void write(DataOutput out) throws IOException {
+    out.writeInt(id);
     out.writeDouble(output);
     out.writeDouble(weight);
     out.writeDouble(delta);
diff --git a/src/main/java/org/apache/horn/core/NeuronInterface.java b/src/main/java/org/apache/horn/core/NeuronInterface.java
index ef5a2d3..73d8220 100644
--- a/src/main/java/org/apache/horn/core/NeuronInterface.java
+++ b/src/main/java/org/apache/horn/core/NeuronInterface.java
@@ -25,7 +25,7 @@
 
   /**
    * This method is called when the messages are propagated from the next layer.
-   * It can be used to determine if the neuron would activate, or fire.
+   * It can be used to calculate the activation or intermediate output.
    * 
    * @param messages
    * @throws IOException
diff --git a/src/main/java/org/apache/horn/core/Synapse.java b/src/main/java/org/apache/horn/core/Synapse.java
index 714767b..6dbada8 100644
--- a/src/main/java/org/apache/horn/core/Synapse.java
+++ b/src/main/java/org/apache/horn/core/Synapse.java
@@ -69,7 +69,7 @@
   public double getPrevWeight() {
     return prevWeight.get();
   }
-
+  
   @Override
   public void readFields(DataInput in) throws IOException {
     message.readFields(in);
diff --git a/src/main/java/org/apache/horn/examples/MultiLayerPerceptron.java b/src/main/java/org/apache/horn/examples/MultiLayerPerceptron.java
index 4c0df95..ac17cc4 100644
--- a/src/main/java/org/apache/horn/examples/MultiLayerPerceptron.java
+++ b/src/main/java/org/apache/horn/examples/MultiLayerPerceptron.java
@@ -27,6 +27,7 @@
 import org.apache.horn.core.Synapse;
 import org.apache.horn.funcs.CrossEntropy;
 import org.apache.horn.funcs.Sigmoid;
+import org.apache.horn.funcs.SoftMax;
 
 public class MultiLayerPerceptron {
 
@@ -41,8 +42,7 @@
       for (Synapse<DoubleWritable, DoubleWritable> m : messages) {
         sum += m.getInput() * m.getWeight();
       }
-
-      this.feedforward(this.squashingFunction.apply(sum));
+      this.feedforward(squashingFunction.apply(sum));
     }
 
     @Override
@@ -61,7 +61,7 @@
       }
 
       this.backpropagate(gradient
-          * this.squashingFunction.applyDerivative(this.getOutput()));
+          * squashingFunction.applyDerivative(getOutput()));
     }
   }
 
@@ -78,15 +78,15 @@
     job.setLearningRate(learningRate);
     job.setMomentumWeight(momemtumWeight);
     job.setRegularizationWeight(regularizationWeight);
-    
+
     job.setConvergenceCheckInterval(600);
     job.setBatchSize(miniBatch);
-    
+
     job.setTrainingMethod(TrainingMethod.GRADIENT_DESCENT);
 
     job.inputLayer(features, Sigmoid.class, StandardNeuron.class);
     job.addLayer(hu, Sigmoid.class, StandardNeuron.class);
-    job.outputLayer(labels, Sigmoid.class, StandardNeuron.class);
+    job.outputLayer(labels, SoftMax.class, StandardNeuron.class);
 
     job.setCostFunction(CrossEntropy.class);
 
diff --git a/src/main/java/org/apache/horn/funcs/CategoricalCrossEntropy.java b/src/main/java/org/apache/horn/funcs/CategoricalCrossEntropy.java
new file mode 100644
index 0000000..96c228a
--- /dev/null
+++ b/src/main/java/org/apache/horn/funcs/CategoricalCrossEntropy.java
@@ -0,0 +1,40 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.funcs;
+
+import org.apache.hama.commons.math.DoubleDoubleFunction;
+
+/**
+ * Categorical cross-entropy cost function, intended for softmax outputs.
+ */
+public class CategoricalCrossEntropy extends DoubleDoubleFunction {
+  
+  private static final double epsilon = 1e-8;
+  
+  @Override
+  public double apply(double target, double actual) {
+    return -target * Math.log(Math.max(actual, epsilon));
+  }
+
+  @Override
+  public double applyDerivative(double target, double actual) {
+    // o - y
+    return -(target - actual);
+  }
+
+}
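
A quick hand-check of the new cost function, as a sketch only: assume a two-class softmax output of [0.3, 0.7] against the one-hot target [0, 1].

```
// Sketch: evaluating CategoricalCrossEntropy by hand (values assumed, not from the patch).
CategoricalCrossEntropy cce = new CategoricalCrossEntropy();
double loss  = cce.apply(0.0, 0.3) + cce.apply(1.0, 0.7); // = -log(0.7) ≈ 0.357
double delta = cce.applyDerivative(1.0, 0.7);             // o - y = -0.3
```

Because the one-hot target zeroes out every term except the true class, the loss reduces to −log of the probability assigned to the correct label.
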
diff --git a/src/main/java/org/apache/horn/funcs/CrossEntropy.java b/src/main/java/org/apache/horn/funcs/CrossEntropy.java
index 7cc5e6a..a096be0 100644
--- a/src/main/java/org/apache/horn/funcs/CrossEntropy.java
+++ b/src/main/java/org/apache/horn/funcs/CrossEntropy.java
@@ -29,6 +29,8 @@
  */
 public class CrossEntropy extends DoubleDoubleFunction {
 
+  private static final double epsilon = 1e-8;
+  
   @Override
   public double apply(double target, double actual) {
     double adjustedTarget = (target == 0 ? 0.000001 : target);
@@ -36,10 +38,11 @@
     double adjustedActual = (actual == 0 ? 0.000001 : actual);
     adjustedActual = (actual == 1 ? 0.999999 : adjustedActual);
     
-    return -adjustedTarget * Math.log(adjustedActual) - (1 - adjustedTarget)
-        * Math.log(1 - adjustedActual);
+    return -target * Math.log(Math.max(actual, epsilon)) - (1 - target)
+        * Math.log(Math.max(1 - actual, epsilon));
+    // return -adjustedTarget * Math.log(adjustedActual) - (1 - adjustedTarget) *  Math.log(adjustedActual);
   }
-
+  
   @Override
   public double applyDerivative(double target, double actual) {
     double adjustedTarget = (target == 0 ? 0.000001 : target);
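
The epsilon clamp above keeps the binary cross-entropy finite when the output saturates; a small sketch of the boundary behaviour (values assumed for illustration):

```
// Sketch: boundary behaviour of the clamped binary cross-entropy.
CrossEntropy ce = new CrossEntropy();
double saturated = ce.apply(1.0, 0.0); // -log(1e-8) ≈ 18.42 instead of +Infinity
double typical   = ce.apply(1.0, 0.9); // -log(0.9)  ≈ 0.105
```
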
diff --git a/src/main/java/org/apache/horn/funcs/FunctionFactory.java b/src/main/java/org/apache/horn/funcs/FunctionFactory.java
index 9b38a0d..4310a95 100644
--- a/src/main/java/org/apache/horn/funcs/FunctionFactory.java
+++ b/src/main/java/org/apache/horn/funcs/FunctionFactory.java
@@ -37,6 +37,10 @@
       return new Sigmoid();
     } else if (functionName.equalsIgnoreCase(Tanh.class.getSimpleName())) {
       return new Tanh();
+    } else if (functionName.equalsIgnoreCase(ReLU.class.getSimpleName())) {
+      return new ReLU();
+    } else if (functionName.equalsIgnoreCase(SoftMax.class.getSimpleName())) {
+      return new SoftMax();
     } else if (functionName.equalsIgnoreCase(IdentityFunction.class
         .getSimpleName())) {
       return new IdentityFunction();
@@ -59,7 +63,10 @@
     } else if (functionName
         .equalsIgnoreCase(CrossEntropy.class.getSimpleName())) {
       return new CrossEntropy();
-    }
+    } else if (functionName
+        .equalsIgnoreCase(CategoricalCrossEntropy.class.getSimpleName())) {
+      return new CategoricalCrossEntropy();
+    } 
 
     throw new IllegalArgumentException(String.format(
         "No double double function with name '%s' exists.", functionName));
diff --git a/src/main/java/org/apache/horn/funcs/ReLU.java b/src/main/java/org/apache/horn/funcs/ReLU.java
index 425137f..85af867 100644
--- a/src/main/java/org/apache/horn/funcs/ReLU.java
+++ b/src/main/java/org/apache/horn/funcs/ReLU.java
@@ -30,12 +30,15 @@
 
   @Override
   public double apply(double value) {
-    return Math.max(0, value);
+    return Math.max(0.001, value);
   }
 
   @Override
   public double applyDerivative(double value) {
-    return (value > Double.MIN_VALUE) ? 1 : 0;
+    if (value > 0)
+      return 0.999;
+    else
+      return 0.001;
   }
 
 }
diff --git a/src/main/java/org/apache/horn/funcs/Sigmoid.java b/src/main/java/org/apache/horn/funcs/Sigmoid.java
index cc393e3..bcccf76 100644
--- a/src/main/java/org/apache/horn/funcs/Sigmoid.java
+++ b/src/main/java/org/apache/horn/funcs/Sigmoid.java
@@ -30,6 +30,11 @@
 
   @Override
   public double apply(double value) {
+    if(value > 100) { // to avoid overflow and underflow
+      return 0.9999;
+    } else if (value < -100) {
+      return 0.0001;
+    }
     return 1.0 / (1 + Math.exp(-value));
   }
 
diff --git a/src/main/java/org/apache/horn/funcs/SoftMax.java b/src/main/java/org/apache/horn/funcs/SoftMax.java
new file mode 100644
index 0000000..6e0bf76
--- /dev/null
+++ b/src/main/java/org/apache/horn/funcs/SoftMax.java
@@ -0,0 +1,58 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.funcs;
+
+import java.io.IOException;
+
+import org.apache.hama.commons.math.DenseDoubleVector;
+import org.apache.hama.commons.math.DoubleFunction;
+import org.apache.hama.commons.math.DoubleVector;
+import org.apache.horn.core.IntermediateOutput;
+
+public class SoftMax extends DoubleFunction {
+
+  @Override
+  public double apply(double value) {
+    // it will be handled by intermediate output handler
+    return value;
+  }
+
+  @Override
+  public double applyDerivative(double value) {
+    return value * (1d - value);
+  }
+  
+  public static class SoftMaxOutputComputer extends IntermediateOutput {
+
+    @Override
+    public DoubleVector interlayer(DoubleVector output) throws IOException {
+      DoubleVector expVec = new DenseDoubleVector(output.getDimension());
+      double sum = 0.0;
+      for(int i = 0; i < output.getDimension(); ++i) {
+        double exp = Math.exp(output.get(i));
+        sum += exp;
+        expVec.set(i, exp);
+      }
+      // divide by the sum of exponential of the whole vector
+      DoubleVector softmaxed = expVec.divide(sum);
+      return softmaxed;
+    }
+
+  }
+
+}
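
To make the intermediate-output path concrete, here is a small sketch of what `SoftMaxOutputComputer.interlayer` produces for an assumed raw output vector (it declares `throws IOException`, so call it from code that handles that):

```
// Sketch: softmax over an assumed raw layer output of [2.0, 1.0, 0.1].
DoubleVector raw = new DenseDoubleVector(new double[] { 2.0, 1.0, 0.1 });
DoubleVector probs = new SoftMax.SoftMaxOutputComputer().interlayer(raw);
// probs ≈ [0.659, 0.242, 0.099]; the entries sum to 1.0
```

Note that `Math.exp` of a very large pre-activation can overflow; a common variant subtracts the vector's maximum before exponentiating, which leaves the result unchanged but keeps the exponentials bounded.
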
diff --git a/src/main/java/org/apache/horn/utils/MNISTEvaluator.java b/src/main/java/org/apache/horn/utils/MNISTEvaluator.java
index a5b68e0..839be97 100644
--- a/src/main/java/org/apache/horn/utils/MNISTEvaluator.java
+++ b/src/main/java/org/apache/horn/utils/MNISTEvaluator.java
@@ -21,6 +21,7 @@
 import java.io.File;
 import java.io.FileInputStream;
 import java.io.IOException;
+import java.util.Random;
 
 import org.apache.hama.HamaConfiguration;
 import org.apache.hama.commons.math.DenseDoubleVector;
@@ -30,11 +31,11 @@
 public class MNISTEvaluator {
 
   private static int PIXELS = 28 * 28;
-  
+
   private static double rescale(double x) {
     return 1 - (255 - x) / 255;
   }
-  
+
   public static void main(String[] args) throws IOException {
     if (args.length < 3) {
       System.out.println("Usage: <TRAINED_MODEL> <TEST_IMAGES> <TEST_LABELS>");
@@ -51,15 +52,13 @@
         new File(training_data)));
     DataInputStream labelsIn = new DataInputStream(new FileInputStream(
         new File(labels_data)));
-    
+
     imagesIn.readInt(); // Magic number
     int count = imagesIn.readInt();
     labelsIn.readInt(); // Magic number
     labelsIn.readInt(); // Count
     imagesIn.readInt(); // Rows
     imagesIn.readInt(); // Cols
-    
-    System.out.println("Evaluating " + count + " images");
 
     byte[][] images = new byte[count][PIXELS];
     byte[] labels = new byte[count];
@@ -70,28 +69,33 @@
 
     HamaConfiguration conf = new HamaConfiguration();
     LayeredNeuralNetwork ann = new LayeredNeuralNetwork(conf, modelPath);
-    
-    int correct = 0;
-    for (int i = 0; i < count; i++) {
-      double[] vals = new double[PIXELS];
-      for (int j = 0; j < PIXELS; j++) {
-        vals[j] = rescale((images[i][j] & 0xff));
-      }
-      int label = (labels[i] & 0xff);
 
-      DoubleVector instance = new DenseDoubleVector(vals);
-      DoubleVector result = ann.getOutput(instance);
-      
-      if(getNumber(result) == label) {
-        correct++;
+    Random generator = new Random();
+    int correct = 0;
+    int total = 0;
+    for (int i = 0; i < count; i++) {
+      if (generator.nextInt(10) == 1) {
+        double[] vals = new double[PIXELS];
+        for (int j = 0; j < PIXELS; j++) {
+          vals[j] = rescale((images[i][j] & 0xff));
+        }
+        int label = (labels[i] & 0xff);
+
+        DoubleVector instance = new DenseDoubleVector(vals);
+        DoubleVector result = ann.getOutput(instance);
+
+        if (getNumber(result) == label) {
+          correct++;
+        }
+        total++;
       }
     }
 
-    System.out.println((double) correct / count);
+    System.out.println(((double) correct / total * 100) + "%");
     // TODO System.out.println("Precision = " + (tp / (tp + fp)));
-    //System.out.println("Recall = " + (tp / (tp + fn)));
-    //System.out.println("Accuracy = " + ((tp + tn) / (tp + tn + fp + fn)));
-    
+    // System.out.println("Recall = " + (tp / (tp + fn)));
+    // System.out.println("Accuracy = " + ((tp + tn) / (tp + tn + fp + fn)));
+
     imagesIn.close();
     labelsIn.close();
   }
@@ -99,9 +103,9 @@
   private static int getNumber(DoubleVector result) {
     double max = 0;
     int index = -1;
-    for(int x = 0; x < result.getLength(); x++) {
+    for (int x = 0; x < result.getLength(); x++) {
       double curr = result.get(x);
-      if(max < curr) {
+      if (max < curr) {
         max = curr;
         index = x;
       }
diff --git a/src/test/java/org/apache/horn/examples/MultiLayerPerceptronTest.java b/src/test/java/org/apache/horn/examples/MultiLayerPerceptronTest.java
index bb404d6..9110088 100644
--- a/src/test/java/org/apache/horn/examples/MultiLayerPerceptronTest.java
+++ b/src/test/java/org/apache/horn/examples/MultiLayerPerceptronTest.java
@@ -23,8 +23,6 @@
 import java.io.InputStreamReader;
 import java.net.URI;
 
-import junit.framework.TestCase;
-
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.fs.FileSystem;
@@ -32,18 +30,24 @@
 import org.apache.hadoop.io.LongWritable;
 import org.apache.hadoop.io.SequenceFile;
 import org.apache.hama.Constants;
+import org.apache.hama.HamaCluster;
 import org.apache.hama.HamaConfiguration;
 import org.apache.hama.commons.io.VectorWritable;
 import org.apache.hama.commons.math.DenseDoubleVector;
 import org.apache.hama.commons.math.DoubleVector;
 import org.apache.horn.core.HornJob;
 import org.apache.horn.core.LayeredNeuralNetwork;
+import org.apache.horn.core.Constants.TrainingMethod;
+import org.apache.horn.examples.MultiLayerPerceptron.StandardNeuron;
+import org.apache.horn.funcs.CrossEntropy;
+import org.apache.horn.funcs.Sigmoid;
 
 /**
  * Test the functionality of NeuralNetwork Example.
  */
-public class MultiLayerPerceptronTest extends TestCase { // HamaCluster {
-  private static final Log LOG = LogFactory.getLog(MultiLayerPerceptronTest.class);
+public class MultiLayerPerceptronTest extends HamaCluster {
+  private static final Log LOG = LogFactory
+      .getLog(MultiLayerPerceptronTest.class);
   private HamaConfiguration conf;
   private FileSystem fs;
   private String MODEL_PATH = "/tmp/neuralnets.model";
@@ -51,7 +55,7 @@
   private String SEQTRAIN_DATA = "/tmp/test-neuralnets.data";
 
   public MultiLayerPerceptronTest() {
-    conf = new HamaConfiguration();/*
+    conf = new HamaConfiguration();
     conf.set("bsp.master.address", "localhost");
     conf.setBoolean("hama.child.redirect.log.console", true);
     conf.setBoolean("hama.messenger.runtime.compression", false);
@@ -62,7 +66,7 @@
     conf.setInt(Constants.ZOOKEEPER_CLIENT_PORT, 21810);
     conf.set("hama.sync.client.class",
         org.apache.hama.bsp.sync.ZooKeeperSyncClientImpl.class
-            .getCanonicalName());*/
+            .getCanonicalName());
   }
 
   @Override
@@ -163,12 +167,28 @@
     }
 
     try {
-      HornJob ann = MultiLayerPerceptron.createJob(conf, MODEL_PATH,
-          SEQTRAIN_DATA, 0.4, 0.2, 0.01, featureDimension, 8, labelDimension,
-          300, 10000);
+      HornJob job = new HornJob(conf, MultiLayerPerceptronTest.class);
+      job.setTrainingSetPath(SEQTRAIN_DATA);
+      job.setModelPath(MODEL_PATH);
 
+      job.setMaxIteration(1000);
+      job.setLearningRate(0.4);
+      job.setMomentumWeight(0.2);
+      job.setRegularizationWeight(0.001);
+
+      job.setConvergenceCheckInterval(100);
+      job.setBatchSize(300);
+
+      job.setTrainingMethod(TrainingMethod.GRADIENT_DESCENT);
+
+      job.inputLayer(featureDimension, Sigmoid.class, StandardNeuron.class);
+      job.addLayer(featureDimension, Sigmoid.class, StandardNeuron.class);
+      job.outputLayer(labelDimension, Sigmoid.class, StandardNeuron.class);
+
+      job.setCostFunction(CrossEntropy.class);
+      
       long startTime = System.currentTimeMillis();
-      if (ann.waitForCompletion(true)) {
+      if (job.waitForCompletion(true)) {
         LOG.info("Job Finished in " + (System.currentTimeMillis() - startTime)
             / 1000.0 + " seconds");
       }