Fix the async parameter merger: exchange a single SmallLayeredNeuralNetworkMessage over the RPC instead of raw matrix arrays, and pass the HamaConfiguration through the NeuralNetwork constructors and train() methods. This closes #13.
Also add missing Apache license headers.
diff --git a/src/main/java/org/apache/horn/bsp/AbstractLayeredNeuralNetwork.java b/src/main/java/org/apache/horn/bsp/AbstractLayeredNeuralNetwork.java
index c29559d..b0d6ec5 100644
--- a/src/main/java/org/apache/horn/bsp/AbstractLayeredNeuralNetwork.java
+++ b/src/main/java/org/apache/horn/bsp/AbstractLayeredNeuralNetwork.java
@@ -23,6 +23,7 @@
import java.util.List;
import org.apache.hadoop.io.WritableUtils;
+import org.apache.hama.HamaConfiguration;
import org.apache.hama.commons.math.DoubleDoubleFunction;
import org.apache.hama.commons.math.DoubleFunction;
import org.apache.hama.commons.math.DoubleMatrix;
@@ -81,8 +82,8 @@
this.learningStyle = LearningStyle.SUPERVISED;
}
- public AbstractLayeredNeuralNetwork(String modelPath) {
- super(modelPath);
+ public AbstractLayeredNeuralNetwork(HamaConfiguration conf, String modelPath) {
+ super(conf, modelPath);
}
/**
diff --git a/src/main/java/org/apache/horn/bsp/AutoEncoder.java b/src/main/java/org/apache/horn/bsp/AutoEncoder.java
index 6c84dc2..135d63a 100644
--- a/src/main/java/org/apache/horn/bsp/AutoEncoder.java
+++ b/src/main/java/org/apache/horn/bsp/AutoEncoder.java
@@ -20,6 +20,7 @@
import java.util.Map;
import org.apache.hadoop.fs.Path;
+import org.apache.hama.HamaConfiguration;
import org.apache.hama.commons.math.DenseDoubleVector;
import org.apache.hama.commons.math.DoubleFunction;
import org.apache.hama.commons.math.DoubleMatrix;
@@ -55,13 +56,14 @@
FunctionFactory.createDoubleFunction("Sigmoid"));
model.addLayer(inputDimensions, true,
FunctionFactory.createDoubleFunction("Sigmoid"));
- model.setLearningStyle(AbstractLayeredNeuralNetwork.LearningStyle.UNSUPERVISED);
+ model
+ .setLearningStyle(AbstractLayeredNeuralNetwork.LearningStyle.UNSUPERVISED);
model.setCostFunction(FunctionFactory
.createDoubleDoubleFunction("SquaredError"));
}
- public AutoEncoder(String modelPath) {
- model = new SmallLayeredNeuralNetwork(modelPath);
+ public AutoEncoder(HamaConfiguration conf, String modelPath) {
+ model = new SmallLayeredNeuralNetwork(conf, modelPath);
}
public AutoEncoder setLearningRate(double learningRate) {
@@ -78,7 +80,7 @@
model.setRegularizationWeight(regularizationWeight);
return this;
}
-
+
public AutoEncoder setModelPath(String modelPath) {
model.setModelPath(modelPath);
return this;
@@ -91,8 +93,9 @@
* @param dataInputPath
* @param trainingParams
*/
- public void train(Path dataInputPath, Map<String, String> trainingParams) {
- model.train(dataInputPath, trainingParams);
+ public void train(HamaConfiguration conf, Path dataInputPath,
+ Map<String, String> trainingParams) {
+ model.train(conf, dataInputPath, trainingParams);
}
/**
@@ -129,13 +132,13 @@
* @return The compressed information.
*/
private DoubleVector transform(DoubleVector inputInstance, int inputLayer) {
- DoubleVector internalInstance = new DenseDoubleVector(inputInstance.getDimension() + 1);
+ DoubleVector internalInstance = new DenseDoubleVector(
+ inputInstance.getDimension() + 1);
internalInstance.set(0, 1);
for (int i = 0; i < inputInstance.getDimension(); ++i) {
internalInstance.set(i + 1, inputInstance.get(i));
}
- DoubleFunction squashingFunction = model
- .getSquashingFunction(inputLayer);
+ DoubleFunction squashingFunction = model.getSquashingFunction(inputLayer);
DoubleMatrix weightMatrix = null;
if (inputLayer == 0) {
weightMatrix = this.getEncodeWeightMatrix();
@@ -149,6 +152,7 @@
/**
* Encode the input instance.
+ *
* @param inputInstance
* @return a new vector with the encode input instance.
*/
@@ -156,13 +160,16 @@
Preconditions
.checkArgument(
inputInstance.getDimension() == model.getLayerSize(0) - 1,
- String.format("The dimension of input instance is %d, but the model requires dimension %d.",
+ String
+ .format(
+ "The dimension of input instance is %d, but the model requires dimension %d.",
inputInstance.getDimension(), model.getLayerSize(1) - 1));
return this.transform(inputInstance, 0);
}
/**
* Decode the input instance.
+ *
* @param inputInstance
* @return a new vector with the decode input instance.
*/
@@ -170,22 +177,27 @@
Preconditions
.checkArgument(
inputInstance.getDimension() == model.getLayerSize(1) - 1,
- String.format("The dimension of input instance is %d, but the model requires dimension %d.",
+ String
+ .format(
+ "The dimension of input instance is %d, but the model requires dimension %d.",
inputInstance.getDimension(), model.getLayerSize(1) - 1));
return this.transform(inputInstance, 1);
}
-
+
/**
* Get the label(s) according to the given features.
+ *
* @param inputInstance
- * @return a new vector with output of the model according to given feature instance.
+ * @return a new vector with output of the model according to given feature
+ * instance.
*/
public DoubleVector getOutput(DoubleVector inputInstance) {
return model.getOutput(inputInstance);
}
-
+
/**
* Set the feature transformer.
+ *
* @param featureTransformer
*/
public void setFeatureTransformer(FeatureTransformer featureTransformer) {
diff --git a/src/main/java/org/apache/horn/bsp/HornJob.java b/src/main/java/org/apache/horn/bsp/HornJob.java
new file mode 100644
index 0000000..bc79f54
--- /dev/null
+++ b/src/main/java/org/apache/horn/bsp/HornJob.java
@@ -0,0 +1,57 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.bsp;
+
+import java.io.IOException;
+
+import org.apache.hama.HamaConfiguration;
+import org.apache.hama.bsp.BSPJob;
+import org.apache.hama.commons.math.Function;
+import org.apache.horn.trainer.Neuron;
+import org.apache.horn.trainer.Trainer;
+
+public class HornJob extends BSPJob {
+
+ public HornJob(HamaConfiguration conf, Class<?> exampleClass)
+ throws IOException {
+ super(conf);
+ this.setBspClass(Trainer.class);
+ this.setJarByClass(exampleClass);
+ }
+
+ public void setDouble(String name, double value) {
+ conf.setDouble(name, value);
+ }
+
+ @SuppressWarnings("rawtypes")
+ public void addLayer(int i, Class<? extends Neuron> class1,
+ Class<? extends Function> class2) {
+ // TODO: record the layer size, neuron class, and squashing function
+
+ }
+
+ public void setCostFunction(Class<? extends Function> class1) {
+ // TODO: record the cost function class in the job configuration
+
+ }
+
+ public void setMaxIteration(int n) {
+ this.conf.setInt("horn.max.iteration", n);
+ }
+
+}
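Note: addLayer() and setCostFunction() are left as stubs above. One plausible way to fill them in, mirroring the pattern setMaxIteration() already uses, is to persist the topology into the job's HamaConfiguration. The following is only a sketch; the horn.layer.* and horn.cost.function keys are invented for illustration and are not part of this patch.

    // Hypothetical sketch: record the topology in the job configuration,
    // following the setMaxIteration() pattern. All keys are assumptions.
    @SuppressWarnings("rawtypes")
    public void addLayer(int size, Class<? extends Neuron> neuronClass,
        Class<? extends Function> squashingFunction) {
      int layer = conf.getInt("horn.layer.count", 0);
      conf.setInt("horn.layer." + layer + ".size", size);
      conf.setClass("horn.layer." + layer + ".neuron", neuronClass,
          Neuron.class);
      conf.setClass("horn.layer." + layer + ".squashing", squashingFunction,
          Function.class);
      conf.setInt("horn.layer.count", layer + 1);
    }

    public void setCostFunction(Class<? extends Function> costFunction) {
      conf.setClass("horn.cost.function", costFunction, Function.class);
    }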
diff --git a/src/main/java/org/apache/horn/bsp/NeuralNetwork.java b/src/main/java/org/apache/horn/bsp/NeuralNetwork.java
index c7f14de..5afe1d3 100644
--- a/src/main/java/org/apache/horn/bsp/NeuralNetwork.java
+++ b/src/main/java/org/apache/horn/bsp/NeuralNetwork.java
@@ -17,27 +17,26 @@
*/
package org.apache.horn.bsp;
-import com.google.common.base.Preconditions;
-import com.google.common.io.Closeables;
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.lang.reflect.Constructor;
+import java.lang.reflect.InvocationTargetException;
+import java.util.Map;
+
import org.apache.commons.lang.SerializationUtils;
-import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableUtils;
+import org.apache.hama.HamaConfiguration;
import org.apache.hama.ml.util.DefaultFeatureTransformer;
import org.apache.hama.ml.util.FeatureTransformer;
-import java.io.DataInput;
-import java.io.DataOutput;
-import java.io.IOException;
-import java.lang.reflect.Constructor;
-import java.lang.reflect.InvocationTargetException;
-import java.net.URI;
-import java.net.URISyntaxException;
-import java.util.Map;
+import com.google.common.base.Preconditions;
+import com.google.common.io.Closeables;
/**
* NeuralNetwork defines the general operations for all the derivative models.
@@ -47,6 +46,8 @@
*
*/
abstract class NeuralNetwork implements Writable {
+ protected HamaConfiguration conf;
+ protected FileSystem fs;
private static final double DEFAULT_LEARNING_RATE = 0.5;
@@ -67,12 +68,19 @@
}
public NeuralNetwork(String modelPath) {
+ }
+
+ public NeuralNetwork(HamaConfiguration conf, String modelPath) {
try {
+ this.conf = conf;
+ this.fs = FileSystem.get(conf);
this.modelPath = modelPath;
+
this.readFromModel();
} catch (IOException e) {
e.printStackTrace();
}
+
}
/**
@@ -107,12 +115,12 @@
* @param trainingParams The parameters for training.
* @throws IOException
*/
- public void train(Path dataInputPath, Map<String, String> trainingParams) {
+ public void train(HamaConfiguration hamaConf, Path dataInputPath, Map<String, String> trainingParams) {
Preconditions.checkArgument(this.modelPath != null,
"Please set the model path before training.");
// train with BSP job
try {
- trainInternal(dataInputPath, trainingParams);
+ trainInternal(hamaConf, dataInputPath, trainingParams);
// write the trained model back to model path
this.readFromModel();
} catch (IOException e) {
@@ -130,9 +138,9 @@
* @param dataInputPath
* @param trainingParams
*/
- protected abstract void trainInternal(Path dataInputPath,
- Map<String, String> trainingParams) throws IOException,
- InterruptedException, ClassNotFoundException;
+ protected abstract void trainInternal(HamaConfiguration hamaConf,
+ Path dataInputPath, Map<String, String> trainingParams)
+ throws IOException, InterruptedException, ClassNotFoundException;
/**
* Read the model meta-data from the specified location.
@@ -142,18 +150,9 @@
protected void readFromModel() throws IOException {
Preconditions.checkArgument(this.modelPath != null,
"Model path has not been set.");
- Configuration conf = new Configuration();
- FSDataInputStream is = null;
- try {
- URI uri = new URI(this.modelPath);
- FileSystem fs = FileSystem.get(uri, conf);
- is = new FSDataInputStream(fs.open(new Path(modelPath)));
- this.readFields(is);
- } catch (URISyntaxException e) {
- e.printStackTrace();
- } finally {
- Closeables.close(is, false);
- }
+ FSDataInputStream is = new FSDataInputStream(fs.open(new Path(modelPath)));
+ this.readFields(is);
+ Closeables.close(is, false);
}
/**
@@ -164,16 +163,9 @@
public void writeModelToFile() throws IOException {
Preconditions.checkArgument(this.modelPath != null,
"Model path has not been set.");
- Configuration conf = new Configuration();
- FSDataOutputStream is = null;
- try {
- URI uri = new URI(this.modelPath);
- FileSystem fs = FileSystem.get(uri, conf);
- is = fs.create(new Path(this.modelPath), true);
- this.write(is);
- } catch (URISyntaxException e) {
- e.printStackTrace();
- }
+
+ FSDataOutputStream is = fs.create(new Path(this.modelPath), true);
+ this.write(is);
Closeables.close(is, false);
}
diff --git a/src/main/java/org/apache/horn/bsp/ParameterMerger.java b/src/main/java/org/apache/horn/bsp/ParameterMerger.java
index 709331b..6df719a 100644
--- a/src/main/java/org/apache/horn/bsp/ParameterMerger.java
+++ b/src/main/java/org/apache/horn/bsp/ParameterMerger.java
@@ -1,10 +1,27 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
package org.apache.horn.bsp;
-import org.apache.hama.commons.math.DoubleMatrix;
import org.apache.hama.ipc.VersionedProtocol;
public interface ParameterMerger extends VersionedProtocol {
- long versionID = 1L;
+ long versionID = 1L;
- SmallLayeredNeuralNetworkMessage merge(double trainingError, DoubleMatrix[] weightUpdates, DoubleMatrix[] prevWeightUpdates);
+ SmallLayeredNeuralNetworkMessage merge(SmallLayeredNeuralNetworkMessage msg);
+
}
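Note: the merge RPC now carries the training error and both matrix sets in a single SmallLayeredNeuralNetworkMessage instead of three separate arguments. A slave-side round trip, mirroring what calculateUpdates() in SmallLayeredNeuralNetworkTrainer does, looks roughly like the sketch below (proxy is assumed to be the ParameterMerger proxy returned by RPC.getProxy).

    // Sketch of the slave-side exchange: pack local state into one message,
    // merge on the master, and unpack the averaged weights that come back.
    SmallLayeredNeuralNetworkMessage out = new SmallLayeredNeuralNetworkMessage(
        avgTrainingError, false, weightUpdates, prevWeightUpdates);
    SmallLayeredNeuralNetworkMessage in = proxy.merge(out);
    DoubleMatrix[] newWeights = in.getCurMatrices();   // merged weights
    DoubleMatrix[] prevUpdates = in.getPrevMatrices(); // momentum state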
diff --git a/src/main/java/org/apache/horn/bsp/ParameterMergerServer.java b/src/main/java/org/apache/horn/bsp/ParameterMergerServer.java
index 54caf2b..47aab84 100644
--- a/src/main/java/org/apache/horn/bsp/ParameterMergerServer.java
+++ b/src/main/java/org/apache/horn/bsp/ParameterMergerServer.java
@@ -1,97 +1,132 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
package org.apache.horn.bsp;
-import com.google.common.base.Preconditions;
-
-import org.apache.hama.commons.math.DoubleMatrix;
-import org.mortbay.log.Log;
-
import java.io.IOException;
import java.util.concurrent.atomic.AtomicBoolean;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hama.commons.math.DoubleMatrix;
+
+import com.google.common.base.Preconditions;
+
public class ParameterMergerServer implements ParameterMerger {
- /* The parameter merge base. */
- protected SmallLayeredNeuralNetwork inMemoryModel;
- /* To terminate or not to terminate. */
- protected AtomicBoolean isConverge;
+ private static final Log LOG = LogFactory.getLog(ParameterMergerServer.class);
- /* The number of slave works that request commits. */
- protected int SlaveCount;
+ /* The parameter merge base. */
+ protected SmallLayeredNeuralNetwork inMemoryModel;
- /* After mergeLimit, terminate whether the result is converging or not. */
- protected int mergeLimit;
+ /* To terminate or not to terminate. */
+ protected AtomicBoolean isConverge;
- /* last n training errors. converging is decided based on the average value of these errors. */
- protected double[] trainingErrors;
+ /* The number of slave workers that request commits. */
+ protected int SlaveCount;
- /* If the average of last n training errors is smaller than this value, it is converging. */
- protected double prevAvgTrainingError = Double.MAX_VALUE;
+ /* After mergeLimit, terminate whether the result is converging or not. */
+ protected int mergeLimit;
- /* current index for trainingErrors. */
- protected int curTrainingError = 0;
+ /*
+ * Last n training errors. Convergence is decided based on the average
+ * value of these errors.
+ */
+ protected double[] trainingErrors;
- /* how many merges have been conducted? */
- protected int mergeCount = 0;
+ /*
+ * Average error of the previous window. Training is considered converged
+ * once a new window's average error is no longer smaller than this value.
+ */
+ protected double prevAvgTrainingError = Double.MAX_VALUE;
- public ParameterMergerServer(SmallLayeredNeuralNetwork inMemoryModel, AtomicBoolean isConverge,
- int slaveCount, int mergeLimit, int convergenceCheckInterval) {
- this.inMemoryModel = inMemoryModel;
- this.isConverge = isConverge;
- this.SlaveCount = slaveCount;
- this.mergeLimit = mergeLimit;
- this.trainingErrors = new double[convergenceCheckInterval];
- }
+ /* current index for trainingErrors. */
+ protected int curTrainingError = 0;
- @Override
- public long getProtocolVersion(String s, long l) throws IOException {
- return ParameterMerger.versionID;
- }
+ /* how many merges have been conducted? */
+ protected int mergeCount = 0;
- @Override
- public SmallLayeredNeuralNetworkMessage merge(double trainingError, DoubleMatrix[] weightUpdates,
- DoubleMatrix[] prevWeightUpdates) {
- Preconditions.checkArgument(weightUpdates.length == prevWeightUpdates.length);
+ public ParameterMergerServer(SmallLayeredNeuralNetwork inMemoryModel,
+ AtomicBoolean isConverge, int slaveCount, int mergeLimit,
+ int convergenceCheckInterval) {
+ this.inMemoryModel = inMemoryModel;
+ this.isConverge = isConverge;
+ this.SlaveCount = slaveCount;
+ this.mergeLimit = mergeLimit;
+ this.trainingErrors = new double[convergenceCheckInterval];
+ }
- Log.info(String.format("Start merging: %d.\n", this.mergeCount));
+ @Override
+ public long getProtocolVersion(String s, long l) throws IOException {
+ return ParameterMerger.versionID;
+ }
- if (!this.isConverge.get()) {
- for (int i = 0; i < weightUpdates.length; ++i) {
- weightUpdates[i] = weightUpdates[i].divide(this.SlaveCount);
- prevWeightUpdates[i] = prevWeightUpdates[i].divide(this.SlaveCount);
- }
+ @Override
+ public SmallLayeredNeuralNetworkMessage merge(
+ SmallLayeredNeuralNetworkMessage msg) {
- synchronized (inMemoryModel) {
- this.inMemoryModel.updateWeightMatrices(weightUpdates);
- this.inMemoryModel.setPrevWeightMatrices(prevWeightUpdates);
+ double trainingError = msg.getTrainingError();
+ DoubleMatrix[] weightUpdates = msg.getCurMatrices();
+ DoubleMatrix[] prevWeightUpdates = msg.getPrevMatrices();
- // add trainingError to trainingErrors
- this.trainingErrors[this.curTrainingError++] = trainingError;
+ Preconditions
+ .checkArgument(weightUpdates.length == prevWeightUpdates.length);
- // check convergence
- if (this.trainingErrors.length == this.curTrainingError) {
- double curAvgTrainingError = 0.0;
- for (int i = 0; i < this.curTrainingError; ++i) {
- curAvgTrainingError += this.trainingErrors[i];
- }
- curAvgTrainingError /= this.trainingErrors.length;
+ LOG.info("Start merging: " + this.mergeCount);
- if (prevAvgTrainingError < curAvgTrainingError) {
- this.isConverge.set(true);
- } else {
- // update
- prevAvgTrainingError = curAvgTrainingError;
- this.curTrainingError = 0;
- }
- }
+ if (!this.isConverge.get()) {
+ for (int i = 0; i < weightUpdates.length; ++i) {
+ weightUpdates[i] = weightUpdates[i].divide(this.SlaveCount);
+ prevWeightUpdates[i] = prevWeightUpdates[i].divide(this.SlaveCount);
+ }
- if (++this.mergeCount == this.mergeLimit) {
- this.isConverge.set(true);
- }
- }
- }
+ synchronized (inMemoryModel) {
+ this.inMemoryModel.updateWeightMatrices(weightUpdates);
+ this.inMemoryModel.setPrevWeightMatrices(prevWeightUpdates);
- return new SmallLayeredNeuralNetworkMessage(
- 0, this.isConverge.get(), this.inMemoryModel.getWeightMatrices(),
- this.inMemoryModel.getPrevMatricesUpdates());
- }
+ // add trainingError to trainingErrors
+ this.trainingErrors[this.curTrainingError++] = trainingError;
+
+ // check convergence
+ if (this.trainingErrors.length == this.curTrainingError) {
+ double curAvgTrainingError = 0.0;
+ for (int i = 0; i < this.curTrainingError; ++i) {
+ curAvgTrainingError += this.trainingErrors[i];
+ }
+ curAvgTrainingError /= this.trainingErrors.length;
+
+ if (prevAvgTrainingError < curAvgTrainingError) {
+ this.isConverge.set(true);
+ } else {
+ // update
+ prevAvgTrainingError = curAvgTrainingError;
+ this.curTrainingError = 0;
+ }
+ }
+
+ if (++this.mergeCount == this.mergeLimit) {
+ this.isConverge.set(true);
+ }
+ }
+ }
+
+ return new SmallLayeredNeuralNetworkMessage(0, this.isConverge.get(),
+ this.inMemoryModel.getWeightMatrices(),
+ this.inMemoryModel.getPrevMatricesUpdates());
+ }
+
}
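Note: the convergence test inside merge() collects training errors into a fixed-size window; once the window fills, the window average is compared with the previous average, and training is declared converged as soon as the average error stops falling (or once mergeLimit merges have happened). The same windowed check, isolated as a standalone sketch:

    // Standalone sketch of the windowed convergence test used by merge().
    class ConvergenceCheck {
      private final double[] window; // last n training errors
      private int next = 0;
      private double prevAvg = Double.MAX_VALUE;

      ConvergenceCheck(int windowSize) {
        this.window = new double[windowSize];
      }

      // Returns true once a full window's average error stops improving.
      boolean offer(double trainingError) {
        window[next++] = trainingError;
        if (next < window.length) {
          return false; // window not full yet
        }
        double avg = 0.0;
        for (double e : window) {
          avg += e;
        }
        avg /= window.length;
        if (prevAvg < avg) {
          return true; // average error went back up: stop
        }
        prevAvg = avg; // still improving; start a new window
        next = 0;
        return false;
      }
    }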
diff --git a/src/main/java/org/apache/horn/bsp/SmallLayeredNeuralNetwork.java b/src/main/java/org/apache/horn/bsp/SmallLayeredNeuralNetwork.java
index 4aee4ce..bd0d103 100644
--- a/src/main/java/org/apache/horn/bsp/SmallLayeredNeuralNetwork.java
+++ b/src/main/java/org/apache/horn/bsp/SmallLayeredNeuralNetwork.java
@@ -26,7 +26,9 @@
import java.util.Map;
import org.apache.commons.lang.math.RandomUtils;
-import org.apache.hadoop.conf.Configuration;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
@@ -41,7 +43,6 @@
import org.apache.hama.commons.math.DoubleMatrix;
import org.apache.hama.commons.math.DoubleVector;
import org.apache.hama.commons.math.FunctionFactory;
-import org.mortbay.log.Log;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
@@ -60,6 +61,9 @@
*/
public class SmallLayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
+ private static final Log LOG = LogFactory
+ .getLog(SmallLayeredNeuralNetwork.class);
+
/* Weights between neurons at adjacent layers */
protected List<DoubleMatrix> weightMatrixList;
@@ -78,8 +82,8 @@
this.squashingFunctionList = Lists.newArrayList();
}
- public SmallLayeredNeuralNetwork(String modelPath) {
- super(modelPath);
+ public SmallLayeredNeuralNetwork(HamaConfiguration conf, String modelPath) {
+ super(conf, modelPath);
}
@Override
@@ -94,6 +98,7 @@
size += 1;
}
+ LOG.info("Add Layer: " + size);
this.layerSizeList.add(size);
int layerIdx = this.layerSizeList.size() - 1;
if (isFinalLayer) {
@@ -497,11 +502,13 @@
}
@Override
- protected void trainInternal(Path dataInputPath,
+ protected void trainInternal(HamaConfiguration hamaConf, Path dataInputPath,
Map<String, String> trainingParams) throws IOException,
InterruptedException, ClassNotFoundException {
// add all training parameters to configuration
- Configuration conf = new Configuration();
+ this.conf = hamaConf;
+ this.fs = FileSystem.get(conf);
+
for (Map.Entry<String, String> entry : trainingParams.entrySet()) {
conf.set(entry.getKey(), entry.getValue());
}
@@ -521,10 +528,8 @@
conf.set("modelPath", this.modelPath);
this.writeModelToFile();
- HamaConfiguration hamaConf = new HamaConfiguration(conf);
-
// create job
- BSPJob job = new BSPJob(hamaConf, SmallLayeredNeuralNetworkTrainer.class);
+ BSPJob job = new BSPJob(conf, SmallLayeredNeuralNetworkTrainer.class);
job.setJobName("Small scale Neural Network training");
job.setJarByClass(SmallLayeredNeuralNetworkTrainer.class);
job.setBspClass(SmallLayeredNeuralNetworkTrainer.class);
@@ -537,12 +542,12 @@
job.setOutputFormat(org.apache.hama.bsp.NullOutputFormat.class);
int numTasks = conf.getInt("tasks", 1);
- Log.info(String.format("Number of tasks: %d\n", numTasks));
+ LOG.info(String.format("Number of tasks: %d\n", numTasks));
job.setNumBspTask(numTasks);
job.waitForCompletion(true);
// reload learned model
- Log.info(String.format("Reload model from %s.", this.modelPath));
+ LOG.info(String.format("Reload model from %s.", this.modelPath));
this.readFromModel();
}
diff --git a/src/main/java/org/apache/horn/bsp/SmallLayeredNeuralNetworkTrainer.java b/src/main/java/org/apache/horn/bsp/SmallLayeredNeuralNetworkTrainer.java
index 9e3d02f..696d56c 100644
--- a/src/main/java/org/apache/horn/bsp/SmallLayeredNeuralNetworkTrainer.java
+++ b/src/main/java/org/apache/horn/bsp/SmallLayeredNeuralNetworkTrainer.java
@@ -17,11 +17,15 @@
*/
package org.apache.horn.bsp;
-import com.google.common.base.Preconditions;
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.util.concurrent.atomic.AtomicBoolean;
-import org.apache.hadoop.conf.Configuration;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
+import org.apache.hama.HamaConfiguration;
import org.apache.hama.bsp.BSP;
import org.apache.hama.bsp.BSPPeer;
import org.apache.hama.bsp.sync.SyncException;
@@ -30,11 +34,8 @@
import org.apache.hama.commons.math.DoubleMatrix;
import org.apache.hama.commons.math.DoubleVector;
import org.apache.hama.ipc.RPC;
-import org.mortbay.log.Log;
-import java.io.IOException;
-import java.net.InetSocketAddress;
-import java.util.concurrent.atomic.AtomicBoolean;
+import com.google.common.base.Preconditions;
/**
* The trainer that train the {@link SmallLayeredNeuralNetwork} based on BSP
@@ -44,12 +45,16 @@
public final class SmallLayeredNeuralNetworkTrainer
extends
BSP<LongWritable, VectorWritable, NullWritable, NullWritable, SmallLayeredNeuralNetworkMessage> {
+
+ private static final Log LOG = LogFactory
+ .getLog(SmallLayeredNeuralNetworkTrainer.class);
+
/* When given peer is master worker: base of parameter merge */
/* When given peer is slave worker: neural network for training */
private SmallLayeredNeuralNetwork inMemoryModel;
/* Job configuration */
- private Configuration conf;
+ private HamaConfiguration conf;
/* Default batch size */
private int batchSize;
@@ -72,7 +77,7 @@
* */
private boolean isMaster(
BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, SmallLayeredNeuralNetworkMessage> peer) {
- return peer.getPeerIndex() == 0;
+ return peer.getPeerIndex() == peer.getNumPeers() - 1;
}
@Override
@@ -84,33 +89,37 @@
// At least one master & slave worker exist.
Preconditions.checkArgument(peer.getNumPeers() >= 2);
- String modelPath = conf.get("modelPath");
- this.inMemoryModel = new SmallLayeredNeuralNetwork(modelPath);
this.conf = peer.getConfiguration();
+
+ String modelPath = conf.get("modelPath");
+ this.inMemoryModel = new SmallLayeredNeuralNetwork(conf, modelPath);
+
this.batchSize = conf.getInt("training.batch.size", 50);
this.isConverge = new AtomicBoolean(false);
int slaveCount = peer.getNumPeers() - 1;
int mergeLimit = conf.getInt("training.max.iterations", 100000);
- int convergenceCheckInterval = peer.getNumPeers() * conf.getInt("convergence.check.interval",
- 2000);
+ int convergenceCheckInterval = peer.getNumPeers()
+ * conf.getInt("convergence.check.interval", 2000);
String master = peer.getPeerName();
String masterAddr = master.substring(0, master.indexOf(':'));
- int port = conf.getInt("sync.server.port", 40042);
+ int port = conf.getInt("sync.server.port", 40052);
if (isMaster(peer)) {
try {
- this.merger = RPC.getServer(new ParameterMergerServer(inMemoryModel, isConverge, slaveCount,
- mergeLimit, convergenceCheckInterval), masterAddr, port, conf);
+ this.merger = RPC.getServer(new ParameterMergerServer(inMemoryModel,
+ isConverge, slaveCount, mergeLimit, convergenceCheckInterval),
+ masterAddr, port, conf);
merger.start();
} catch (IOException e) {
e.printStackTrace();
}
- Log.info("Begin to train");
+ LOG.info("Begin to train");
} else {
InetSocketAddress addr = new InetSocketAddress(masterAddr, port);
try {
- this.proxy = (ParameterMerger) RPC.getProxy(ParameterMerger.class, ParameterMerger.versionID, addr, conf);
+ this.proxy = (ParameterMerger) RPC.getProxy(ParameterMerger.class,
+ ParameterMerger.versionID, addr, conf);
} catch (IOException e) {
e.printStackTrace();
}
@@ -126,7 +135,7 @@
// write model to modelPath
if (isMaster(peer)) {
try {
- Log.info(String.format("Write model back to %s\n",
+ LOG.info(String.format("Write model back to %s\n",
inMemoryModel.getModelPath()));
this.inMemoryModel.writeModelToFile();
} catch (IOException e) {
@@ -139,13 +148,19 @@
public void bsp(
BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, SmallLayeredNeuralNetworkMessage> peer)
throws IOException, SyncException, InterruptedException {
- if (!isMaster(peer)) {
- while (!this.isConverge.get()) {
- // each slave-worker calculate the matrices updates according to local data
- // and merge them with master
+ while (!this.isConverge.get()) {
+ // each slave worker calculates the matrix updates
+ // according to its local data
+ // and merges them with the master
+ if (!isMaster(peer)) {
calculateUpdates(peer);
}
}
+
+ if (isMaster(peer)) {
+ merger.stop();
+ }
+ peer.sync(); // finalize the bsp program.
}
/**
@@ -157,6 +172,7 @@
private void calculateUpdates(
BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, SmallLayeredNeuralNetworkMessage> peer)
throws IOException {
+
DoubleMatrix[] weightUpdates = new DoubleMatrix[this.inMemoryModel.weightMatrixList
.size()];
for (int i = 0; i < weightUpdates.length; ++i) {
@@ -187,8 +203,11 @@
}
// exchange parameter update with master
- SmallLayeredNeuralNetworkMessage inMessage = proxy.merge(avgTrainingError, weightUpdates,
- this.inMemoryModel.getWeightMatrices());
+ SmallLayeredNeuralNetworkMessage msg = new SmallLayeredNeuralNetworkMessage(
+ avgTrainingError, false, weightUpdates,
+ this.inMemoryModel.getPrevMatricesUpdates());
+
+ SmallLayeredNeuralNetworkMessage inMessage = proxy.merge(msg);
DoubleMatrix[] newWeights = inMessage.getCurMatrices();
DoubleMatrix[] preWeightUpdates = inMessage.getPrevMatrices();
this.inMemoryModel.setWeightMatrices(newWeights);
diff --git a/src/main/java/org/apache/horn/examples/MultiLayerPerceptron.java b/src/main/java/org/apache/horn/examples/MultiLayerPerceptron.java
new file mode 100644
index 0000000..26402cc
--- /dev/null
+++ b/src/main/java/org/apache/horn/examples/MultiLayerPerceptron.java
@@ -0,0 +1,112 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.examples;
+
+import java.io.IOException;
+
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hama.HamaConfiguration;
+import org.apache.hama.bsp.TextInputFormat;
+import org.apache.horn.bsp.HornJob;
+import org.apache.horn.funcs.CrossEntropy;
+import org.apache.horn.funcs.Sigmoid;
+import org.apache.horn.trainer.Neuron;
+import org.apache.horn.trainer.PropMessage;
+
+public class MultiLayerPerceptron {
+
+ public static class StandardNeuron extends
+ Neuron<PropMessage<DoubleWritable, DoubleWritable>> {
+
+ private double learningRate;
+ private double lambda;
+ private double momentum;
+ private static double bias = -1;
+
+ @Override
+ public void setup(HamaConfiguration conf) {
+ this.learningRate = conf.getDouble("mlp.learning.rate", 0.1);
+ this.lambda = conf.getDouble("mlp.regularization.weight", 0.01);
+ this.momentum = conf.getDouble("mlp.momentum.weight", 0.2);
+ }
+
+ @Override
+ public void forward(
+ Iterable<PropMessage<DoubleWritable, DoubleWritable>> messages)
+ throws IOException {
+ double sum = 0;
+
+ for (PropMessage<DoubleWritable, DoubleWritable> m : messages) {
+ sum += m.getInput() * m.getWeight();
+ }
+ sum += bias * this.getTheta(); // add bias feature
+ feedforward(activation(sum));
+ }
+
+ @Override
+ public void backward(
+ Iterable<PropMessage<DoubleWritable, DoubleWritable>> messages)
+ throws IOException {
+ for (PropMessage<DoubleWritable, DoubleWritable> m : messages) {
+ // Calculates error gradient for each neuron
+ double gradient = this.getOutput() * (1 - this.getOutput())
+ * m.getDelta() * m.getWeight();
+ backpropagate(gradient);
+
+ // Weight corrections
+ double weight = -learningRate * this.getOutput() * m.getDelta()
+ + momentum * this.getPreviousWeight();
+ this.push(weight);
+ }
+ }
+
+ }
+
+ public static void main(String[] args) throws IOException,
+ InterruptedException, ClassNotFoundException {
+ HamaConfiguration conf = new HamaConfiguration();
+ HornJob job = new HornJob(conf, MultiLayerPerceptron.class);
+
+ job.setDouble("mlp.learning.rate", 0.1);
+ job.setDouble("mlp.regularization.weight", 0.01);
+ job.setDouble("mlp.momentum.weight", 0.2);
+
+ // initialize the topology of the model.
+ // a three-layer model is created in this example
+ job.addLayer(1000, StandardNeuron.class, Sigmoid.class); // 1st layer
+ job.addLayer(800, StandardNeuron.class, Sigmoid.class); // 2nd layer
+ job.addLayer(300, StandardNeuron.class, Sigmoid.class); // total classes
+
+ // set the cost function to evaluate the error
+ job.setCostFunction(CrossEntropy.class);
+
+ // set I/O and others
+ job.setInputFormat(TextInputFormat.class);
+ job.setOutputPath(new Path("/tmp/"));
+ job.setMaxIteration(10000);
+ job.setNumBspTask(3);
+
+ long startTime = System.currentTimeMillis();
+
+ if (job.waitForCompletion(true)) {
+ System.out.println("Job Finished in "
+ + (System.currentTimeMillis() - startTime) / 1000.0 + " seconds");
+ }
+ }
+}
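Note: in update terms, the backward() method above implements the usual sigmoid delta rule with momentum. Per incoming message, with neuron output o_j, upstream delta delta_k, and connecting weight w_jk, where eta is mlp.learning.rate and mu is mlp.momentum.weight:

    \delta_j = o_j (1 - o_j) \, \delta_k \, w_{jk}
    \Delta w = -\eta \, o_j \, \delta_k + \mu \, \Delta w_{\mathrm{prev}}

The regularization weight lambda (mlp.regularization.weight) is read in setup() but not yet applied in this example.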
diff --git a/src/main/java/org/apache/horn/examples/NeuralNetwork.java b/src/main/java/org/apache/horn/examples/NeuralNetwork.java
index c056b2f..737412b 100644
--- a/src/main/java/org/apache/horn/examples/NeuralNetwork.java
+++ b/src/main/java/org/apache/horn/examples/NeuralNetwork.java
@@ -44,19 +44,20 @@
printUsage();
return;
}
+ HamaConfiguration conf = new HamaConfiguration();
String mode = args[0];
+
if (mode.equalsIgnoreCase("label")) {
if (args.length < 4) {
printUsage();
return;
}
- HamaConfiguration conf = new HamaConfiguration();
String featureDataPath = args[1];
String resultDataPath = args[2];
String modelPath = args[3];
- SmallLayeredNeuralNetwork ann = new SmallLayeredNeuralNetwork(modelPath);
+ SmallLayeredNeuralNetwork ann = new SmallLayeredNeuralNetwork(conf, modelPath);
// process data in streaming approach
FileSystem fs = FileSystem.get(new URI(featureDataPath), conf);
@@ -173,11 +174,11 @@
ann.setModelPath(trainedModelPath);
Map<String, String> trainingParameters = new HashMap<String, String>();
- trainingParameters.put("tasks", "5");
+ trainingParameters.put("tasks", "2");
trainingParameters.put("training.max.iterations", "" + iteration);
trainingParameters.put("training.batch.size", "300");
trainingParameters.put("convergence.check.interval", "1000");
- ann.train(new Path(trainingDataPath), trainingParameters);
+ ann.train(conf, new Path(trainingDataPath), trainingParameters);
}
}
diff --git a/src/main/java/org/apache/horn/trainer/Neuron.java b/src/main/java/org/apache/horn/trainer/Neuron.java
index 1ae473b..d1c35d1 100644
--- a/src/main/java/org/apache/horn/trainer/Neuron.java
+++ b/src/main/java/org/apache/horn/trainer/Neuron.java
@@ -18,13 +18,34 @@
package org.apache.horn.trainer;
import org.apache.hadoop.io.Writable;
+import org.apache.horn.funcs.Sigmoid;
public abstract class Neuron<M extends Writable> implements NeuronInterface<M> {
double output;
double weight;
- public void propagate(double gradient) {
+ /**
+ * @return the theta value of this neuron.
+ */
+ public double getTheta() {
// TODO Auto-generated method stub
+ return 0; // stub: theta is not stored yet
+ }
+
+ public void feedforward(double sum) {
+ // TODO: forward the already-squashed output to the
+ // neurons of the next layer
+ }
+
+ public void backpropagate(double gradient) {
+ // TODO: propagate the error gradient to the lower layer
+
+ }
+
+ public double activation(double sum) {
+ // squash the weighted sum; the squashing function is hard-coded for now
+ this.output = new Sigmoid().apply(sum);
+ return output;
}
public void setOutput(double output) {
@@ -35,6 +56,12 @@
return output;
}
+ // The methods below communicate with the parameter server.
+
+ public double getPreviousWeight() {
+ return weight;
+ }
+
public void push(double weight) {
// TODO Auto-generated method stub
this.weight = weight;
diff --git a/src/main/java/org/apache/horn/trainer/NeuronInterface.java b/src/main/java/org/apache/horn/trainer/NeuronInterface.java
index 4921c15..c96931e 100644
--- a/src/main/java/org/apache/horn/trainer/NeuronInterface.java
+++ b/src/main/java/org/apache/horn/trainer/NeuronInterface.java
@@ -20,9 +20,12 @@
import java.io.IOException;
import org.apache.hadoop.io.Writable;
+import org.apache.hama.HamaConfiguration;
public interface NeuronInterface<M extends Writable> {
+ public void setup(HamaConfiguration conf);
+
/**
* This method is called when the messages are propagated from the lower
* layer. It can be used to determine if the neuron would activate, or fire.
@@ -30,7 +33,7 @@
* @param messages
* @throws IOException
*/
- public void upward(Iterable<M> messages) throws IOException;
+ public void forward(Iterable<M> messages) throws IOException;
/**
* This method is called when the errors are propagated from the upper layer.
@@ -40,6 +43,6 @@
* @param messages
* @throws IOException
*/
- public void downward(Iterable<M> messages) throws IOException;
+ public void backward(Iterable<M> messages) throws IOException;
}
diff --git a/src/main/java/org/apache/horn/trainer/PropMessage.java b/src/main/java/org/apache/horn/trainer/PropMessage.java
index 74b2434..5724943 100644
--- a/src/main/java/org/apache/horn/trainer/PropMessage.java
+++ b/src/main/java/org/apache/horn/trainer/PropMessage.java
@@ -21,6 +21,7 @@
import java.io.DataOutput;
import java.io.IOException;
+import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.Writable;
/**
@@ -29,10 +30,10 @@
public class PropMessage<M extends Writable, W extends Writable> implements
Writable {
- M message;
- W weight;
+ DoubleWritable message;
+ DoubleWritable weight;
- public PropMessage(M message, W weight) {
+ public PropMessage(DoubleWritable message, DoubleWritable weight) {
this.message = message;
this.weight = weight;
}
@@ -40,12 +41,22 @@
/**
* @return the activation or error message
*/
- public M getMessage() {
- return message;
+ public double getMessage() {
+ return message.get();
}
- public W getWeight() {
- return weight;
+ public double getInput() {
+ // returns the input
+ return message.get();
+ }
+
+ public double getDelta() {
+ // returns the delta
+ return message.get();
+ }
+
+ public double getWeight() {
+ return weight.get();
}
@Override
@@ -60,4 +71,4 @@
weight.write(out);
}
-}
+}
\ No newline at end of file
diff --git a/src/main/java/org/apache/horn/trainer/Trainer.java b/src/main/java/org/apache/horn/trainer/Trainer.java
index 94309c9..4f903f0 100644
--- a/src/main/java/org/apache/horn/trainer/Trainer.java
+++ b/src/main/java/org/apache/horn/trainer/Trainer.java
@@ -53,9 +53,6 @@
this.iterations = 0;
this.maxIterations = peer.getConfiguration()
.getInt("horn.max.iteration", 1);
- this.batchSize = peer.getConfiguration()
- .getInt("horn.minibatch.size", 1000);
-
LOG.info("max iteration: " + this.maxIterations);
// loads subset of neural network model replica into memory
@@ -71,7 +68,7 @@
// Fetch latest parameters
fetchParameters(peer);
// Perform the batch
- doMinibatch(peer);
+ performBatch(peer);
// Push parameters
pushParameters(peer);
@@ -90,7 +87,7 @@
* @throws InterruptedException
* @throws SyncException
*/
- private void doMinibatch(BSPPeer peer) throws IOException, SyncException, InterruptedException {
+ private void performBatch(BSPPeer peer) throws IOException, SyncException, InterruptedException {
double avgTrainingError = 0.0;
int trains = 0;
diff --git a/src/test/java/org/apache/horn/bsp/TestAutoEncoder.java b/src/test/java/org/apache/horn/bsp/TestAutoEncoder.java
index 0aaa926..9d5c0b9 100644
--- a/src/test/java/org/apache/horn/bsp/TestAutoEncoder.java
+++ b/src/test/java/org/apache/horn/bsp/TestAutoEncoder.java
@@ -31,11 +31,11 @@
import java.util.Map;
import java.util.Random;
-import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.SequenceFile;
+import org.apache.hama.HamaConfiguration;
import org.apache.hama.commons.io.VectorWritable;
import org.apache.hama.commons.math.DenseDoubleVector;
import org.apache.hama.commons.math.DoubleVector;
@@ -55,12 +55,13 @@
AutoEncoder encoder = new AutoEncoder(4, 2);
encoder.setLearningRate(0.5);
encoder.setMomemtumWeight(0.2);
-
+
int maxIteration = 2000;
Random rnd = new Random();
for (int iteration = 0; iteration < maxIteration; ++iteration) {
for (int i = 0; i < instances.length; ++i) {
- encoder.trainOnline(new DenseDoubleVector(instances[rnd.nextInt(instances.length)]));
+ encoder.trainOnline(new DenseDoubleVector(instances[rnd
+ .nextInt(instances.length)]));
}
}
@@ -74,12 +75,13 @@
}
}
-
+
@Test
public void testAutoEncoderSwissRollDataset() {
List<double[]> instanceList = new ArrayList<double[]>();
try {
- BufferedReader br = new BufferedReader(new FileReader("src/test/resources/dimensional_reduction.txt"));
+ BufferedReader br = new BufferedReader(new FileReader(
+ "src/test/resources/dimensional_reduction.txt"));
String line = null;
while ((line = br.readLine()) != null) {
String[] tokens = line.split("\t");
@@ -99,7 +101,7 @@
} catch (IOException e) {
e.printStackTrace();
}
-
+
List<DoubleVector> vecInstanceList = new ArrayList<DoubleVector>();
for (double[] instance : instanceList) {
vecInstanceList.add(new DenseDoubleVector(instance));
@@ -123,24 +125,26 @@
++errorInstance;
}
}
- Log.info(String.format("Autoecoder error rate: %f%%\n", errorInstance * 100 / vecInstanceList.size()));
-
+ Log.info(String.format("Autoencoder error rate: %f%%\n", errorInstance * 100
+ / vecInstanceList.size()));
+
}
-
+
@Test
public void testAutoEncoderSwissRollDatasetDistributed() {
+ HamaConfiguration conf = new HamaConfiguration();
String strDataPath = "/tmp/dimensional_reduction.txt";
Path path = new Path(strDataPath);
List<double[]> instanceList = new ArrayList<double[]>();
try {
- Configuration conf = new Configuration();
FileSystem fs = FileSystem.get(new URI(strDataPath), conf);
if (fs.exists(path)) {
fs.delete(path, true);
}
-
+
String line = null;
- BufferedReader br = new BufferedReader(new FileReader("src/test/resources/dimensional_reduction.txt"));
+ BufferedReader br = new BufferedReader(new FileReader(
+ "src/test/resources/dimensional_reduction.txt"));
while ((line = br.readLine()) != null) {
String[] tokens = line.split("\t");
double[] instance = new double[tokens.length];
@@ -152,13 +156,14 @@
br.close();
// normalize instances
zeroOneNormalization(instanceList, instanceList.get(0).length);
-
- SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, path, LongWritable.class, VectorWritable.class);
+
+ SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, path,
+ LongWritable.class, VectorWritable.class);
for (int i = 0; i < instanceList.size(); ++i) {
DoubleVector vector = new DenseDoubleVector(instanceList.get(i));
writer.append(new LongWritable(i), new VectorWritable(vector));
}
-
+
writer.close();
} catch (FileNotFoundException e) {
e.printStackTrace();
@@ -167,7 +172,7 @@
} catch (URISyntaxException e) {
e.printStackTrace();
}
-
+
AutoEncoder encoder = new AutoEncoder(3, 2);
String modelPath = "/tmp/autoencoder-modelpath";
encoder.setModelPath(modelPath);
@@ -176,8 +181,8 @@
trainingParams.put("tasks", "5");
trainingParams.put("training.max.iterations", "3000");
trainingParams.put("training.batch.size", "200");
- encoder.train(path, trainingParams);
-
+ encoder.train(conf, path, trainingParams);
+
double errorInstance = 0;
for (double[] instance : instanceList) {
DoubleVector vector = new DenseDoubleVector(instance);
@@ -188,7 +193,8 @@
++errorInstance;
}
}
- Log.info(String.format("Autoecoder error rate: %f%%\n", errorInstance * 100 / instanceList.size()));
+ Log.info(String.format("Autoencoder error rate: %f%%\n", errorInstance * 100
+ / instanceList.size()));
}
}
diff --git a/src/test/java/org/apache/horn/bsp/TestSmallLayeredNeuralNetwork.java b/src/test/java/org/apache/horn/bsp/TestSmallLayeredNeuralNetwork.java
index 85c4b7a..2f3a5b2 100644
--- a/src/test/java/org/apache/horn/bsp/TestSmallLayeredNeuralNetwork.java
+++ b/src/test/java/org/apache/horn/bsp/TestSmallLayeredNeuralNetwork.java
@@ -38,16 +38,17 @@
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.SequenceFile;
+import org.apache.hama.HamaConfiguration;
import org.apache.hama.commons.io.VectorWritable;
import org.apache.hama.commons.math.DenseDoubleMatrix;
import org.apache.hama.commons.math.DenseDoubleVector;
import org.apache.hama.commons.math.DoubleMatrix;
import org.apache.hama.commons.math.DoubleVector;
import org.apache.hama.commons.math.FunctionFactory;
-import org.apache.horn.bsp.AbstractLayeredNeuralNetwork.LearningStyle;
-import org.apache.horn.bsp.AbstractLayeredNeuralNetwork.TrainingMethod;
import org.apache.hama.ml.util.DefaultFeatureTransformer;
import org.apache.hama.ml.util.FeatureTransformer;
+import org.apache.horn.bsp.AbstractLayeredNeuralNetwork.LearningStyle;
+import org.apache.horn.bsp.AbstractLayeredNeuralNetwork.TrainingMethod;
import org.junit.Test;
import org.mortbay.log.Log;
@@ -95,7 +96,7 @@
}
// read from file
- SmallLayeredNeuralNetwork annCopy = new SmallLayeredNeuralNetwork(modelPath);
+ SmallLayeredNeuralNetwork annCopy = new SmallLayeredNeuralNetwork(new HamaConfiguration(), modelPath);
assertEquals(annCopy.getClass().getSimpleName(), annCopy.getModelType());
assertEquals(modelPath, annCopy.getModelPath());
assertEquals(learningRate, annCopy.getLearningRate(), 0.000001);
@@ -227,7 +228,7 @@
} catch (IOException e) {
e.printStackTrace();
}
- SmallLayeredNeuralNetwork annCopy = new SmallLayeredNeuralNetwork(modelPath);
+ SmallLayeredNeuralNetwork annCopy = new SmallLayeredNeuralNetwork(new HamaConfiguration(), modelPath);
// test on instances
for (int i = 0; i < instances.length; ++i) {
DoubleVector input = new DenseDoubleVector(instances[i]).slice(2);
@@ -277,7 +278,7 @@
} catch (IOException e) {
e.printStackTrace();
}
- SmallLayeredNeuralNetwork annCopy = new SmallLayeredNeuralNetwork(modelPath);
+ SmallLayeredNeuralNetwork annCopy = new SmallLayeredNeuralNetwork(new HamaConfiguration(), modelPath);
// test on instances
for (int i = 0; i < instances.length; ++i) {
DoubleVector input = new DenseDoubleVector(instances[i]).slice(2);
@@ -328,7 +329,7 @@
} catch (IOException e) {
e.printStackTrace();
}
- SmallLayeredNeuralNetwork annCopy = new SmallLayeredNeuralNetwork(modelPath);
+ SmallLayeredNeuralNetwork annCopy = new SmallLayeredNeuralNetwork(new HamaConfiguration(), modelPath);
// test on instances
for (int i = 0; i < instances.length; ++i) {
DoubleVector input = new DenseDoubleVector(instances[i]).slice(2);
@@ -505,7 +506,7 @@
trainingParameters.put("training.max.iterations", "2000");
trainingParameters.put("training.batch.size", "300");
trainingParameters.put("convergence.check.interval", "1000");
- ann.train(tmpDatasetPath, trainingParameters);
+ ann.train(new HamaConfiguration(), tmpDatasetPath, trainingParameters);
long end = new Date().getTime();
@@ -614,7 +615,7 @@
trainingParameters.put("training.max.iterations", "2000");
trainingParameters.put("training.batch.size", "300");
trainingParameters.put("convergence.check.interval", "1000");
- ann.train(tmpDatasetPath, trainingParameters);
+ ann.train(new HamaConfiguration(), tmpDatasetPath, trainingParameters);
long end = new Date().getTime();
diff --git a/src/test/java/org/apache/horn/examples/NeuralNetworkTest.java b/src/test/java/org/apache/horn/examples/NeuralNetworkTest.java
index 462140c..4f44c94 100644
--- a/src/test/java/org/apache/horn/examples/NeuralNetworkTest.java
+++ b/src/test/java/org/apache/horn/examples/NeuralNetworkTest.java
@@ -18,53 +18,121 @@
package org.apache.horn.examples;
import java.io.BufferedReader;
+import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.IOException;
+import java.io.InputStreamReader;
+import java.io.OutputStreamWriter;
+import java.net.URI;
import java.util.ArrayList;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
-import junit.framework.TestCase;
-
-import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.SequenceFile;
+import org.apache.hama.Constants;
+import org.apache.hama.HamaCluster;
import org.apache.hama.HamaConfiguration;
import org.apache.hama.commons.io.VectorWritable;
import org.apache.hama.commons.math.DenseDoubleVector;
+import org.apache.hama.commons.math.DoubleVector;
+import org.apache.hama.commons.math.FunctionFactory;
+import org.apache.horn.bsp.SmallLayeredNeuralNetwork;
/**
* Test the functionality of NeuralNetwork Example.
*
*/
-public class NeuralNetworkTest extends TestCase {
- private Configuration conf = new HamaConfiguration();
+public class NeuralNetworkTest extends HamaCluster {
+ private HamaConfiguration conf;
private FileSystem fs;
private String MODEL_PATH = "/tmp/neuralnets.model";
private String RESULT_PATH = "/tmp/neuralnets.txt";
private String SEQTRAIN_DATA = "/tmp/test-neuralnets.data";
-
+
+ public NeuralNetworkTest() {
+ conf = new HamaConfiguration();
+ conf.set("bsp.master.address", "localhost");
+ conf.setBoolean("hama.child.redirect.log.console", true);
+ conf.setBoolean("hama.messenger.runtime.compression", true);
+ assertEquals("Make sure master addr is set to localhost:", "localhost",
+ conf.get("bsp.master.address"));
+ conf.set("bsp.local.dir", "/tmp/hama-test");
+ conf.set(Constants.ZOOKEEPER_QUORUM, "localhost");
+ conf.setBoolean(Constants.FORCE_SET_BSP_TASKS, true);
+ conf.setInt(Constants.ZOOKEEPER_CLIENT_PORT, 21810);
+ conf.set("hama.sync.client.class",
+ org.apache.hama.bsp.sync.ZooKeeperSyncClientImpl.class
+ .getCanonicalName());
+ }
+
@Override
protected void setUp() throws Exception {
super.setUp();
fs = FileSystem.get(conf);
}
+ @Override
+ public void tearDown() throws Exception {
+ super.tearDown();
+ }
+
public void testNeuralnetsLabeling() throws IOException {
this.neuralNetworkTraining();
- String dataPath = "src/test/resources/neuralnets_classification_test.txt";
- String mode = "label";
+ String featureDataPath = "src/test/resources/neuralnets_classification_test.txt";
try {
- NeuralNetwork
- .main(new String[] { mode, dataPath, RESULT_PATH, MODEL_PATH });
+ SmallLayeredNeuralNetwork ann = new SmallLayeredNeuralNetwork(conf,
+ MODEL_PATH);
+
+ // process data in streaming approach
+ FileSystem fs = FileSystem.get(new URI(featureDataPath), conf);
+ BufferedReader br = new BufferedReader(new InputStreamReader(
+ fs.open(new Path(featureDataPath))));
+ Path outputPath = new Path(RESULT_PATH);
+ if (fs.exists(outputPath)) {
+ fs.delete(outputPath, true);
+ }
+ BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(
+ fs.create(outputPath)));
+
+ String line = null;
+
+ while ((line = br.readLine()) != null) {
+ if (line.trim().length() == 0) {
+ continue;
+ }
+ String[] tokens = line.trim().split(",");
+ double[] vals = new double[tokens.length];
+ for (int i = 0; i < tokens.length; ++i) {
+ vals[i] = Double.parseDouble(tokens[i]);
+ }
+ DoubleVector instance = new DenseDoubleVector(vals);
+ DoubleVector result = ann.getOutput(instance);
+ double[] arrResult = result.toArray();
+ StringBuilder sb = new StringBuilder();
+ for (int i = 0; i < arrResult.length; ++i) {
+ sb.append(arrResult[i]);
+ if (i != arrResult.length - 1) {
+ sb.append(",");
+ } else {
+ sb.append("\n");
+ }
+ }
+ bw.write(sb.toString());
+ }
+
+ br.close();
+ bw.close();
// compare results with ground-truth
BufferedReader groundTruthReader = new BufferedReader(new FileReader(
"src/test/resources/neuralnets_classification_label.txt"));
List<Double> groundTruthList = new ArrayList<Double>();
- String line = null;
+ line = null;
while ((line = groundTruthReader.readLine()) != null) {
groundTruthList.add(Double.parseDouble(line));
}
@@ -82,11 +150,14 @@
for (int i = 0; i < groundTruthList.size(); ++i) {
double actual = resultList.get(i);
double expected = groundTruthList.get(i);
+ LOG.info("evaluated: " + actual + ", expected: " + expected);
if (actual < 0.5 && expected < 0.5 || actual >= 0.5 && expected >= 0.5) {
++correct;
}
}
- System.out.printf("Precision: %f\n", correct / total);
+
+ LOG.info("## Precision: " + (correct / total));
+ assertTrue((correct / total) > 0.5);
} catch (Exception e) {
e.printStackTrace();
@@ -97,17 +168,14 @@
}
}
+ @SuppressWarnings("deprecation")
private void neuralNetworkTraining() {
- String mode = "train";
String strTrainingDataPath = "src/test/resources/neuralnets_classification_training.txt";
int featureDimension = 8;
int labelDimension = 1;
Path sequenceTrainingDataPath = new Path(SEQTRAIN_DATA);
- Configuration conf = new Configuration();
- FileSystem fs;
try {
- fs = FileSystem.get(conf);
SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf,
sequenceTrainingDataPath, LongWritable.class, VectorWritable.class);
BufferedReader br = new BufferedReader(
@@ -130,11 +198,35 @@
}
try {
- NeuralNetwork.main(new String[] { mode, SEQTRAIN_DATA,
- MODEL_PATH, "" + featureDimension, "" + labelDimension });
+ int iteration = 1000;
+ double learningRate = 0.4;
+ double momemtumWeight = 0.2;
+ double regularizationWeight = 0.01;
+
+ // train the model
+ SmallLayeredNeuralNetwork ann = new SmallLayeredNeuralNetwork();
+ ann.setLearningRate(learningRate);
+ ann.setMomemtumWeight(momemtumWeight);
+ ann.setRegularizationWeight(regularizationWeight);
+ ann.addLayer(featureDimension, false,
+ FunctionFactory.createDoubleFunction("Sigmoid"));
+ ann.addLayer(featureDimension, false,
+ FunctionFactory.createDoubleFunction("Sigmoid"));
+ ann.addLayer(labelDimension, true,
+ FunctionFactory.createDoubleFunction("Sigmoid"));
+ ann.setCostFunction(FunctionFactory
+ .createDoubleDoubleFunction("CrossEntropy"));
+ ann.setModelPath(MODEL_PATH);
+
+ Map<String, String> trainingParameters = new HashMap<String, String>();
+ trainingParameters.put("tasks", "2");
+ trainingParameters.put("training.max.iterations", "" + iteration);
+ trainingParameters.put("training.batch.size", "300");
+ trainingParameters.put("convergence.check.interval", "1000");
+ ann.train(conf, sequenceTrainingDataPath, trainingParameters);
+
} catch (Exception e) {
e.printStackTrace();
}
}
-
}
diff --git a/src/test/java/org/apache/horn/trainer/TestNeuron.java b/src/test/java/org/apache/horn/trainer/TestNeuron.java
index 823be51..d5042a1 100644
--- a/src/test/java/org/apache/horn/trainer/TestNeuron.java
+++ b/src/test/java/org/apache/horn/trainer/TestNeuron.java
@@ -24,7 +24,7 @@
import junit.framework.TestCase;
import org.apache.hadoop.io.DoubleWritable;
-import org.apache.hama.commons.math.Sigmoid;
+import org.apache.hama.HamaConfiguration;
public class TestNeuron extends TestCase {
private static double learningRate = 0.1;
@@ -35,34 +35,35 @@
Neuron<PropMessage<DoubleWritable, DoubleWritable>> {
@Override
- public void upward(
+ public void setup(HamaConfiguration conf) {
+ }
+
+ @Override
+ public void forward(
Iterable<PropMessage<DoubleWritable, DoubleWritable>> messages)
throws IOException {
double sum = 0;
for (PropMessage<DoubleWritable, DoubleWritable> m : messages) {
- sum += m.getMessage().get() * m.getWeight().get();
+ sum += m.getInput() * m.getWeight();
}
sum += (bias * theta);
-
- double output = new Sigmoid().apply(sum);
- this.setOutput(output);
- this.propagate(output);
+ feedforward(activation(sum));
}
@Override
- public void downward(
+ public void backward(
Iterable<PropMessage<DoubleWritable, DoubleWritable>> messages)
throws IOException {
for (PropMessage<DoubleWritable, DoubleWritable> m : messages) {
// Calculates error gradient for each neuron
double gradient = this.getOutput() * (1 - this.getOutput())
- * m.getMessage().get() * m.getWeight().get();
+ * m.getDelta() * m.getWeight();
// Propagates to lower layer
- this.propagate(gradient);
+ backpropagate(gradient);
// Weight corrections
- double weight = learningRate * this.getOutput() * m.getMessage().get();
+ double weight = learningRate * this.getOutput() * m.getDelta();
this.push(weight);
}
}
@@ -77,14 +78,14 @@
1.0), new DoubleWritable(0.4)));
MyNeuron n = new MyNeuron();
- n.upward(x);
+ n.forward(x);
assertEquals(0.5249791874789399, n.getOutput());
x.clear();
x.add(new PropMessage<DoubleWritable, DoubleWritable>(new DoubleWritable(
-0.1274), new DoubleWritable(-1.2)));
- n.downward(x);
+ n.backward(x);
assertEquals(-0.006688234848481696, n.getUpdate());
}
-
+
}
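Note: the expected value in the forward test is simply the sigmoid of the weighted input sum:

    \sigma(s) = \frac{1}{1 + e^{-s}}, \qquad \sigma(0.1) = 0.5249791874789399

which matches the asserted output exactly, so the two input-times-weight products plus the bias contribution must total s = 0.1 before squashing.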