1. Separate the master worker from the slave workers. 2. Make the master worker a dedicated merger. 3. Fail fast when the peer count is less than 2.
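For context, the split follows the standard BSP master/slave pattern: slave workers compute in one superstep, every peer hits a barrier, and the master merges in the next. A minimal sketch of the per-iteration control flow, where computeLocalUpdate() and mergeAndBroadcast() are hypothetical stand-ins for the trainer's real update and merge logic:

    // Sketch only: mirrors the bsp() loop below; the two helpers are
    // placeholders, not methods of SmallLayeredNeuralNetworkTrainer.
    private void superstep(
        BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, SmallLayeredNeuralNetworkMessage> peer)
        throws IOException, SyncException, InterruptedException {
      if (!isMaster(peer)) {
        computeLocalUpdate(peer); // slaves: compute updates from the local split
      }
      peer.sync();                // barrier: all updates delivered to the master
      if (isMaster(peer)) {
        mergeAndBroadcast(peer);  // master: merge updates, send the model back
      }
      peer.sync();                // barrier: slaves receive the merged model
    }
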
diff --git a/src/main/java/org/apache/horn/bsp/SmallLayeredNeuralNetworkTrainer.java b/src/main/java/org/apache/horn/bsp/SmallLayeredNeuralNetworkTrainer.java
index 132ec8c..002a9e5 100644
--- a/src/main/java/org/apache/horn/bsp/SmallLayeredNeuralNetworkTrainer.java
+++ b/src/main/java/org/apache/horn/bsp/SmallLayeredNeuralNetworkTrainer.java
@@ -17,7 +17,7 @@
*/
package org.apache.horn.bsp;
-import java.io.IOException;
+import com.google.common.base.Preconditions;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.LongWritable;
@@ -31,6 +31,8 @@
import org.apache.hama.commons.math.DoubleVector;
import org.mortbay.log.Log;
+import java.io.IOException;
+
/**
* The trainer that train the {@link SmallLayeredNeuralNetwork} based on BSP
* framework.
@@ -55,13 +57,26 @@
private String modelPath;
+ /**
+ * Returns true if this worker is the master worker (peer index 0).
+ *
+ * @param peer the BSP peer to check
+ */
+ private boolean isMaster(
+ BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, SmallLayeredNeuralNetworkMessage> peer) {
+ return peer.getPeerIndex() == 0;
+ }
+
@Override
/**
* If the model path is specified, load the existing model from the storage location.
*/
public void setup(
BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, SmallLayeredNeuralNetworkMessage> peer) {
- if (peer.getPeerIndex() == 0) {
+ // At least one master and one slave worker must exist.
+ Preconditions.checkArgument(peer.getNumPeers() >= 2);
+
+ if (isMaster(peer)) {
Log.info("Begin to train");
}
this.isConverge = false;
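
Guava's Preconditions.checkArgument throws an IllegalArgumentException when its condition is false, so an under-provisioned job now fails in setup() instead of deadlocking at the first barrier (a lone master would wait forever for updates no slave exists to send). A variant with a diagnostic message, which the patch itself does not add, could look like:

    // Illustrative only: the same check with an error-message template.
    Preconditions.checkArgument(peer.getNumPeers() >= 2,
        "Expected at least 2 peers (1 master + 1 slave) but got %s",
        peer.getNumPeers());
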
@@ -84,7 +99,7 @@
public void cleanup(
BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, SmallLayeredNeuralNetworkMessage> peer) {
// write model to modelPath
- if (peer.getPeerIndex() == 0) {
+ if (isMaster(peer)) {
try {
Log.info(String.format("End of training, number of iterations: %d.\n",
this.iterations));
@@ -102,12 +117,14 @@
BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, SmallLayeredNeuralNetworkMessage> peer)
throws IOException, SyncException, InterruptedException {
while (this.iterations++ < maxIterations) {
- // each groom calculate the matrices updates according to local data
- calculateUpdates(peer);
+ // each slave worker calculates the matrix updates from its local data
+ if (!isMaster(peer)) {
+ calculateUpdates(peer);
+ }
peer.sync();
// the master merges the model updates
- if (peer.getPeerIndex() == 0) {
+ if (isMaster(peer)) {
mergeUpdates(peer);
}
peer.sync();
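
Note that each iteration relies on both barriers: the first sync() guarantees the master has received every slave's update before mergeUpdates() runs, and the second guarantees the slaves see the merged model before computing the next round of updates.
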
@@ -188,7 +205,7 @@
int numMessages = peer.getNumCurrentMessages();
boolean isConverge = false;
if (numMessages == 0) { // converges
- isConverge = true;
+ this.isConverge = true;
return;
}
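
The last hunk fixes a field-shadowing bug rather than a style nit: the merge routine declares a local boolean isConverge (visible a few lines above the change) that hides the instance field of the same name, so the unqualified assignment only ever wrote to the local copy. A stripped-down illustration, with class and method names invented for the example:

    class ShadowingExample {
      boolean isConverge = false;     // instance field read after the call

      void merge(int numMessages) {
        boolean isConverge = false;   // local variable shadows the field
        if (numMessages == 0) {
          isConverge = true;          // BUG: writes the local; field stays false
          // this.isConverge = true;  // fix: qualify with 'this'
        }
      }
    }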