Add DistBeliefModelTrainer
diff --git a/src/main/java/org/apache/horn/distbelief/DistBeliefModelTrainer.java b/src/main/java/org/apache/horn/distbelief/DistBeliefModelTrainer.java
new file mode 100644
index 0000000..a19139c
--- /dev/null
+++ b/src/main/java/org/apache/horn/distbelief/DistBeliefModelTrainer.java
@@ -0,0 +1,87 @@
+package org.apache.horn.distbelief;
+
+import java.io.IOException;
+
+import org.apache.hama.bsp.BSP;
+import org.apache.hama.bsp.BSPPeer;
+import org.apache.hama.bsp.sync.SyncException;
+
+/**
+ * This DistBeliefModelTrainer performs the mini-batch SGD training steps.
+ */
+public class DistBeliefModelTrainer extends BSP {
+
+ private boolean isConverge = false;
+ private int iterations;
+ private int maxIterations;
+
+ @Override
+ public final void setup(BSPPeer peer) {
+ // loads subset of neural network model replica into memory
+ }
+
+ @Override
+ public void bsp(BSPPeer peer) throws IOException, SyncException,
+ InterruptedException {
+
+ // Iterate until the max iteration count is reached or training converges
+ while (this.iterations++ < maxIterations) {
+
+ // Fetch latest parameters
+ fetchParameters(peer);
+
+ // Perform mini-batch
+ doMinibatch(peer);
+
+ // Push parameters
+ pushParameters(peer);
+
+ if (this.isConverge) {
+ break;
+ }
+ }
+
+ }
+
+ /**
+ * Performs a mini-batch training step over instances loaded from the assigned splits.
+ * @param peer
+ */
+ private void doMinibatch(BSPPeer peer) {
+ double avgTrainingError = 0.0;
+ // 1. loads a set of mini-batch instances from assigned splits into memory
+
+ // 2. train incrementally from a mini-batch of instances
+ /*
+ for (Instance trainingInstance : MiniBatchSet) {
+
+ // 2.1 upward propagation (start from the input layer)
+ for (Neuron neuron : neurons) {
+ neuron.upward(msg);
+ sync();
+ }
+
+ // 2.2 calculate the total error
+ sync();
+
+ // 2.3 downward propagation (start from the total error)
+ for (Neuron neuron : neurons) {
+ neuron.downward(msg);
+ sync();
+ }
+
+ // calculate the average training error
+ }
+ */
+
+ }
+
+ private void fetchParameters(BSPPeer peer) {
+ // TODO fetch latest weights from the parameter server
+ }
+
+ private void pushParameters(BSPPeer peer) {
+ // TODO push updated weights
+ }
+
+}
diff --git a/src/main/java/org/apache/horn/distbelief/Neuron.java b/src/main/java/org/apache/horn/distbelief/Neuron.java
index ce67cf2..fadb522 100644
--- a/src/main/java/org/apache/horn/distbelief/Neuron.java
+++ b/src/main/java/org/apache/horn/distbelief/Neuron.java
@@ -23,6 +23,10 @@
double output;
double weight;
+ public void propagate(double gradient) {
+ // TODO Auto-generated method stub
+ }
+
public void setOutput(double output) {
this.output = output;
}
@@ -32,6 +36,7 @@
}
public void push(double weight) {
+ // TODO Auto-generated method stub
this.weight = weight;
}
diff --git a/src/main/java/org/apache/horn/distbelief/PropMessage.java b/src/main/java/org/apache/horn/distbelief/PropMessage.java
index dd6f2b1..029cd6a 100644
--- a/src/main/java/org/apache/horn/distbelief/PropMessage.java
+++ b/src/main/java/org/apache/horn/distbelief/PropMessage.java
@@ -37,6 +37,9 @@
this.weight = weight;
}
+ /**
+ * @return the activation or error message
+ */
public M getMessage() {
return message;
}
diff --git a/src/test/java/org/apache/horn/distbelief/TestDistBeliefModelTrainer.java b/src/test/java/org/apache/horn/distbelief/TestDistBeliefModelTrainer.java
new file mode 100644
index 0000000..5bbd90c
--- /dev/null
+++ b/src/test/java/org/apache/horn/distbelief/TestDistBeliefModelTrainer.java
@@ -0,0 +1,5 @@
+package org.apache.horn.distbelief;
+
+public class TestDistBeliefModelTrainer {
+
+}
diff --git a/src/test/java/org/apache/horn/distbelief/TestNeuron.java b/src/test/java/org/apache/horn/distbelief/TestNeuron.java
index 37e8fd6..9af1315 100644
--- a/src/test/java/org/apache/horn/distbelief/TestNeuron.java
+++ b/src/test/java/org/apache/horn/distbelief/TestNeuron.java
@@ -28,6 +28,8 @@
public class TestNeuron extends TestCase {
private static double learningRate = 0.1;
+ private static double bias = -1;
+ private static double theta = 0.8;
public static class MyNeuron extends
Neuron<PropMessage<DoubleWritable, DoubleWritable>> {
@@ -40,24 +42,24 @@
for (PropMessage<DoubleWritable, DoubleWritable> m : messages) {
sum += m.getMessage().get() * m.getWeight().get();
}
- sum += (-1 * 0.8);
+ sum += (bias * theta);
double output = new Sigmoid().apply(sum);
this.setOutput(output);
+ this.propagate(output);
}
@Override
public void downward(
Iterable<PropMessage<DoubleWritable, DoubleWritable>> messages)
throws IOException {
-
for (PropMessage<DoubleWritable, DoubleWritable> m : messages) {
// Calculates error gradient for each neuron
double gradient = this.getOutput() * (1 - this.getOutput())
* m.getMessage().get() * m.getWeight().get();
// Propagates to lower layer
- System.out.println(gradient);
+ this.propagate(gradient);
// Weight corrections
double weight = learningRate * this.getOutput() * m.getMessage().get();
@@ -84,4 +86,5 @@
n.downward(x);
assertEquals(-0.006688234848481696, n.getUpdate());
}
+
}