HORN-34: Generalize HornJob Class
diff --git a/conf/horn-env.sh b/conf/horn-env.sh
index c60c2aa..279a8c6 100644
--- a/conf/horn-env.sh
+++ b/conf/horn-env.sh
@@ -22,4 +22,4 @@
# Set environment variables here.
# The java implementation to use. Required.
-export JAVA_HOME=/usr/lib/jvm/java-8-oracle
+# export JAVA_HOME=/usr/lib/jvm/java-8-oracle
diff --git a/src/main/java/org/apache/horn/core/AbstractLayeredNeuralNetwork.java b/src/main/java/org/apache/horn/core/AbstractLayeredNeuralNetwork.java
index 26e815f..85ad443 100644
--- a/src/main/java/org/apache/horn/core/AbstractLayeredNeuralNetwork.java
+++ b/src/main/java/org/apache/horn/core/AbstractLayeredNeuralNetwork.java
@@ -70,6 +70,8 @@
protected LearningStyle learningStyle;
+ protected float dropRate;
+
public AbstractLayeredNeuralNetwork() {
this.regularizationWeight = DEFAULT_REGULARIZATION_WEIGHT;
this.momentumWeight = DEFAULT_MOMENTUM_WEIGHT;
@@ -259,4 +261,8 @@
WritableUtils.writeEnum(output, this.learningStyle);
}
+ public void setDropRateOfInputLayer(float dropRate) {
+ this.dropRate = dropRate;
+ }
+
}
diff --git a/src/main/java/org/apache/horn/core/HornJob.java b/src/main/java/org/apache/horn/core/HornJob.java
index 3912b67..82343fe 100644
--- a/src/main/java/org/apache/horn/core/HornJob.java
+++ b/src/main/java/org/apache/horn/core/HornJob.java
@@ -28,8 +28,9 @@
public class HornJob extends BSPJob {
- LayeredNeuralNetwork neuralNetwork;
+ AbstractLayeredNeuralNetwork neuralNetwork;
+ @Deprecated
public HornJob(HamaConfiguration conf, Class<?> exampleClass)
throws IOException {
super(conf);
@@ -40,6 +41,17 @@
neuralNetwork = new LayeredNeuralNetwork();
}
+ public HornJob(HamaConfiguration conf,
+ Class<? extends AbstractLayeredNeuralNetwork> neuralNetworkClass,
+ Class<?> exampleClass)
+ throws IOException, InstantiationException, IllegalAccessException {
+ this.setJarByClass(exampleClass);
+
+ // default local file block size 10mb
+ this.getConfiguration().set("fs.local.block.size", "10358951");
+ neuralNetwork = neuralNetworkClass.newInstance();
+ }
+
public void inputLayer(int featureDimension) {
addLayer(featureDimension, null, null);
neuralNetwork.setDropRateOfInputLayer(1);
@@ -106,7 +118,7 @@
this.neuralNetwork.setRegularizationWeight(regularizationWeight);
}
- public LayeredNeuralNetwork getNeuralNetwork() {
+ public AbstractLayeredNeuralNetwork getNeuralNetwork() {
return neuralNetwork;
}
diff --git a/src/main/java/org/apache/horn/core/LayeredNeuralNetwork.java b/src/main/java/org/apache/horn/core/LayeredNeuralNetwork.java
index 6f7aa70..0e389ca 100644
--- a/src/main/java/org/apache/horn/core/LayeredNeuralNetwork.java
+++ b/src/main/java/org/apache/horn/core/LayeredNeuralNetwork.java
@@ -87,7 +87,6 @@
private List<Neuron<?>[]> neurons = new ArrayList<Neuron<?>[]>();
- private float dropRate;
private long iterations;
public LayeredNeuralNetwork() {
diff --git a/src/main/java/org/apache/horn/examples/MultiLayerPerceptron.java b/src/main/java/org/apache/horn/examples/MultiLayerPerceptron.java
index 5f3403b..2c45673 100644
--- a/src/main/java/org/apache/horn/examples/MultiLayerPerceptron.java
+++ b/src/main/java/org/apache/horn/examples/MultiLayerPerceptron.java
@@ -23,6 +23,7 @@
import org.apache.hama.HamaConfiguration;
import org.apache.horn.core.Constants.TrainingMethod;
import org.apache.horn.core.HornJob;
+import org.apache.horn.core.LayeredNeuralNetwork;
import org.apache.horn.core.Neuron;
import org.apache.horn.core.Synapse;
import org.apache.horn.funcs.CrossEntropy;
@@ -69,9 +70,9 @@
public static HornJob createJob(HamaConfiguration conf, String modelPath,
String inputPath, float learningRate, float momemtumWeight,
float regularizationWeight, int features, int hu, int labels,
- int miniBatch, int maxIteration) throws IOException {
+ int miniBatch, int maxIteration) throws IOException, InstantiationException, IllegalAccessException {
- HornJob job = new HornJob(conf, MultiLayerPerceptron.class);
+ HornJob job = new HornJob(conf, LayeredNeuralNetwork.class, MultiLayerPerceptron.class);
job.setTrainingSetPath(inputPath);
job.setModelPath(modelPath);
@@ -95,7 +96,7 @@
}
public static void main(String[] args) throws IOException,
- InterruptedException, ClassNotFoundException {
+ InterruptedException, ClassNotFoundException, NumberFormatException, InstantiationException, IllegalAccessException {
if (args.length < 9) {
System.out.println("Usage: <MODEL_PATH> <INPUT_PATH> "
+ "<LEARNING_RATE> <MOMEMTUM_WEIGHT> <REGULARIZATION_WEIGHT> "