OPENNLP-1009 - minor improvements / fixes

Introduce a 'batch' parameter on RNN and StackedRNN so the Adagrad weight
update runs only every 'batch' iterations (batch = 1 preserves the old
per-iteration behaviour). Also initialize dhNext / dh2Next from hPrev /
hPrev2 instead of a row of hs so their shapes match the backward pass,
drop the unused getHyperparamsString() method and the stale SetRange
import, regroup imports, trim the parameterized tests to a single lighter
configuration, and remove noisy println calls from the tests.
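For illustration, a minimal sketch of the extended constructor (the
hyperparameter values below are placeholders in the spirit of the tests,
not recommendations; the training entry point is unchanged and not shown
in this patch):

    // seven-argument form: learningRate, seqLength, hiddenLayerSize,
    // epochs, text, batch, useChars
    RNN rnn = new RNN(1e-1f, 25, 100, 20, text, 5, true);
    String sample = rnn.sample(rnn.getVocabSize() - 1); // sample(int) as exercised in RNNTest
    rnn.serialize("target/crnn-weights-");              // weights file prefix, as in the tests

The six-argument constructor delegates with batch = 1, and n % 1 == 0
always holds, so existing callers keep the previous behaviour of updating
on every iteration. The update rule itself is unchanged Adagrad: each
memory matrix accumulates the squared gradient, and the weights move by
learningRate * gradient / sqrt(memory + reg).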
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/RNN.java b/opennlp-dl/src/main/java/opennlp/tools/dl/RNN.java
index 155ec03..2fabecd 100644
--- a/opennlp-dl/src/main/java/opennlp/tools/dl/RNN.java
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/RNN.java
@@ -55,6 +55,7 @@
protected final int hiddenLayerSize;
protected final int epochs;
protected final boolean useChars;
+ protected final int batch;
protected final int vocabSize;
protected final Map<String, Integer> charToIx;
protected final Map<Integer, String> ixToChar;
@@ -71,14 +72,15 @@
private INDArray hPrev = null; // memory state
public RNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text) {
- this(learningRate, seqLength, hiddenLayerSize, epochs, text, true);
+ this(learningRate, seqLength, hiddenLayerSize, epochs, text, 1, true);
}
- public RNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text, boolean useChars) {
+ public RNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text, int batch, boolean useChars) {
this.learningRate = learningRate;
this.seqLength = seqLength;
this.hiddenLayerSize = hiddenLayerSize;
this.epochs = epochs;
+ this.batch = batch;
this.useChars = useChars;
String[] textTokens = useChars ? toStrings(text.toCharArray()) : text.split(" ");
@@ -169,21 +171,24 @@
System.out.printf("iter %d, loss: %f\n", n, smoothLoss); // print progress
}
- // perform parameter update with Adagrad
- mWxh.addi(dWxh.mul(dWxh));
- wxh.subi((dWxh.mul(learningRate)).div(Transforms.sqrt(mWxh.add(reg))));
-
- mWhh.addi(dWhh.mul(dWhh));
- whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.add(reg))));
-
- mWhy.addi(dWhy.mul(dWhy));
- why.subi(dWhy.mul(learningRate).div(Transforms.sqrt(mWhy.add(reg))));
-
- mbh.addi(dbh.mul(dbh));
- bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh.add(reg))));
-
- mby.addi(dby.mul(dby));
- by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby.add(reg))));
+ // apply the Adagrad update only every 'batch' iterations; batch == 1 preserves the old per-iteration behaviour
+ if (n % batch == 0) {
+ // perform parameter update with Adagrad
+ mWxh.addi(dWxh.mul(dWxh));
+ wxh.subi((dWxh.mul(learningRate)).div(Transforms.sqrt(mWxh.add(reg))));
+
+ mWhh.addi(dWhh.mul(dWhh));
+ whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.add(reg))));
+
+ mWhy.addi(dWhy.mul(dWhy));
+ why.subi(dWhy.mul(learningRate).div(Transforms.sqrt(mWhy.add(reg))));
+
+ mbh.addi(dbh.mul(dbh));
+ bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh.add(reg))));
+
+ mby.addi(dby.mul(dby));
+ by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby.add(reg))));
+ }
p += seqLength; // move data pointer
n++; // iteration counter
@@ -244,7 +249,7 @@
}
// backward pass: compute gradients going backwards
- INDArray dhNext = Nd4j.zerosLike(hs.getRow(0));
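+ // initialize from hPrev (a column vector) so dhNext's shape matches the hidden-state deltas produced in the backward pass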
+ INDArray dhNext = Nd4j.zerosLike(hPrev);
for (int t = inputs.length() - 1; t >= 0; t--) {
INDArray dy = ps.getRow(t);
dy.putRow(targets.getInt(t), dy.getRow(targets.getInt(t)).sub(1)); // backprop into y
@@ -334,17 +339,7 @@
", epochs=" + epochs +
", vocabSize=" + vocabSize +
", useChars=" + useChars +
- '}';
- }
-
-
- public String getHyperparamsString() {
- return getClass().getName() + "{" +
- "wxh=" + wxh +
- ", whh=" + whh +
- ", why=" + why +
- ", bh=" + bh +
- ", by=" + by +
+ ", batch=" + batch +
'}';
}
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java b/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java
index e7a49d7..e6ceb9b 100644
--- a/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java
@@ -18,14 +18,6 @@
*/
package opennlp.tools.dl;
-import org.apache.commons.math3.distribution.EnumeratedDistribution;
-import org.apache.commons.math3.util.Pair;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.impl.transforms.SetRange;
-import org.nd4j.linalg.api.ops.impl.transforms.SoftMax;
-import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.ops.transforms.Transforms;
-
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
@@ -34,6 +26,13 @@
import java.util.LinkedList;
import java.util.List;
+import org.apache.commons.math3.distribution.EnumeratedDistribution;
+import org.apache.commons.math3.util.Pair;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.impl.transforms.SoftMax;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.ops.transforms.Transforms;
+
/**
* A basic char/word-level stacked RNN model (2 hidden recurrent layers), based on Stacked RNN architecture from ICLR 2014's
* "How to Construct Deep Recurrent Neural Networks" by Razvan Pascanu, Caglar Gulcehre, Kyunghyun Cho and Yoshua Bengio
@@ -61,11 +60,11 @@
private INDArray hPrev2 = null; // memory state
public StackedRNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text) {
- this(learningRate, seqLength, hiddenLayerSize, epochs, text, true);
+ this(learningRate, seqLength, hiddenLayerSize, epochs, text, 1, true);
}
- public StackedRNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text, boolean useChars) {
- super(learningRate, seqLength, hiddenLayerSize, epochs, text, useChars);
+ public StackedRNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text, int batch, boolean useChars) {
+ super(learningRate, seqLength, hiddenLayerSize, epochs, text, batch, useChars);
wxh = Nd4j.randn(hiddenLayerSize, vocabSize).div(Math.sqrt(hiddenLayerSize));
whh = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).div(Math.sqrt(hiddenLayerSize));
@@ -141,30 +140,32 @@
System.out.printf("iter %d, loss: %f\n", n, smoothLoss); // print progress
}
- // perform parameter update with Adagrad
- mWxh.addi(dWxh.mul(dWxh));
- wxh.subi(dWxh.mul(learningRate).div(Transforms.sqrt(mWxh.add(reg))));
-
- mWxh2.addi(dWxh2.mul(dWxh2));
- wxh2.subi(dWxh2.mul(learningRate).div(Transforms.sqrt(mWxh2.add(reg))));
-
- mWhh.addi(dWhh.mul(dWhh));
- whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.add(reg))));
-
- mWhh2.addi(dWhh2.mul(dWhh2));
- whh2.subi(dWhh2.mul(learningRate).div(Transforms.sqrt(mWhh2.add(reg))));
-
- mbh2.addi(dbh2.mul(dbh2));
- bh2.subi(dbh2.mul(learningRate).div(Transforms.sqrt(mbh2.add(reg))));
-
- mWh2y.addi(dWh2y.mul(dWh2y));
- wh2y.subi(dWh2y.mul(learningRate).div(Transforms.sqrt(mWh2y.add(reg))));
-
- mbh.addi(dbh.mul(dbh));
- bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh.add(reg))));
-
- mby.addi(dby.mul(dby));
- by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby.add(reg))));
+ if (n % batch == 0) {
+ // perform parameter update with Adagrad
+ mWxh.addi(dWxh.mul(dWxh));
+ wxh.subi(dWxh.mul(learningRate).div(Transforms.sqrt(mWxh.add(reg))));
+
+ mWxh2.addi(dWxh2.mul(dWxh2));
+ wxh2.subi(dWxh2.mul(learningRate).div(Transforms.sqrt(mWxh2.add(reg))));
+
+ mWhh.addi(dWhh.mul(dWhh));
+ whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.add(reg))));
+
+ mWhh2.addi(dWhh2.mul(dWhh2));
+ whh2.subi(dWhh2.mul(learningRate).div(Transforms.sqrt(mWhh2.add(reg))));
+
+ mbh2.addi(dbh2.mul(dbh2));
+ bh2.subi(dbh2.mul(learningRate).div(Transforms.sqrt(mbh2.add(reg))));
+
+ mWh2y.addi(dWh2y.mul(dWh2y));
+ wh2y.subi(dWh2y.mul(learningRate).div(Transforms.sqrt(mWh2y.add(reg))));
+
+ mbh.addi(dbh.mul(dbh));
+ bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh.add(reg))));
+
+ mby.addi(dby.mul(dby));
+ by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby.add(reg))));
+ }
p += seqLength; // move data pointer
n++; // iteration counter
@@ -176,7 +177,7 @@
* hprev is Hx1 array of initial hidden state
* returns the loss, gradients on model parameters and last hidden state
*/
- private double lossFun(INDArray inputs, INDArray targets, INDArray dWxh, INDArray dWhh, INDArray dWxh2, INDArray dWhh2, INDArray dWh2y,
+ private double lossFun(INDArray inputs, INDArray targets, INDArray dWxh, INDArray dWhh, INDArray dWxh2, INDArray dWhh2, INDArray dWh2y,
INDArray dbh, INDArray dbh2, INDArray dby) {
INDArray xs = Nd4j.zeros(seqLength, vocabSize);
@@ -222,8 +223,9 @@
}
// backward pass: compute gradients going backwards
- INDArray dhNext = Nd4j.zerosLike(hs.getRow(0));
- INDArray dh2Next = Nd4j.zerosLike(hs2.getRow(0));
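+ // as in RNN#lossFun, initialize the deltas from hPrev/hPrev2 so their shapes match the backward-pass products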
+ INDArray dhNext = Nd4j.zerosLike(hPrev);
+ INDArray dh2Next = Nd4j.zerosLike(hPrev2);
+
for (int t = seqLength - 1; t >= 0; t--) {
INDArray dy = ps.getRow(t);
dy.getRow(targets.getInt(t)).subi(1); // backprop into y
@@ -249,7 +251,6 @@
INDArray hsRow = t == 0 ? hPrev : hs.getRow(t - 1);
dWhh.addi(dhraw.mmul(hsRow.transpose()));
dhNext = whh.transpose().mmul(dhraw);
-
}
this.hPrev = hs.getRow(seqLength - 1);
diff --git a/opennlp-dl/src/test/java/opennlp/tools/dl/RNNTest.java b/opennlp-dl/src/test/java/opennlp/tools/dl/RNNTest.java
index 2808f4d..57f7682 100644
--- a/opennlp-dl/src/test/java/opennlp/tools/dl/RNNTest.java
+++ b/opennlp-dl/src/test/java/opennlp/tools/dl/RNNTest.java
@@ -18,7 +18,6 @@
*/
package opennlp.tools.dl;
-import java.io.FileInputStream;
import java.io.InputStream;
import java.util.Arrays;
import java.util.Collection;
@@ -40,44 +39,44 @@
private float learningRate;
private int seqLength;
private int hiddenLayerSize;
+ private int epochs;
+
private Random r = new Random();
private String text;
- private final int epochs = 20;
private List<String> words;
- public RNNTest(float learningRate, int seqLength, int hiddenLayerSize) {
+ public RNNTest(float learningRate, int seqLength, int hiddenLayerSize, int epochs) {
this.learningRate = learningRate;
this.seqLength = seqLength;
this.hiddenLayerSize = hiddenLayerSize;
+ this.epochs = epochs;
}
@Before
public void setUp() throws Exception {
InputStream stream = getClass().getResourceAsStream("/text/sentences.txt");
text = IOUtils.toString(stream);
- words = Arrays.asList(text.split(" "));
+ words = Arrays.asList(text.split("\\s"));
stream.close();
}
@Parameterized.Parameters
public static Collection<Object[]> data() {
return Arrays.asList(new Object[][] {
- {1e-1f, 25, 20},
- {1e-1f, 25, 40},
- {1e-1f, 25, 60},
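+ // one lighter configuration (shorter sequences, fewer epochs) replaces the three heavier ones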
+ {1e-1f, 15, 20, 5},
});
}
@Test
public void testVanillaCharRNNLearn() throws Exception {
- RNN rnn = new RNN(learningRate, seqLength, hiddenLayerSize, epochs, text);
+ RNN rnn = new RNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 5, true);
evaluate(rnn, true);
rnn.serialize("target/crnn-weights-");
}
@Test
public void testVanillaWordRNNLearn() throws Exception {
- RNN rnn = new RNN(learningRate, seqLength, hiddenLayerSize, epochs * 2, text, false);
+ RNN rnn = new RNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 1, false);
evaluate(rnn, true);
rnn.serialize("target/wrnn-weights-");
}
@@ -89,7 +88,6 @@
for (int i = 0; i < 2; i++) {
int seed = r.nextInt(rnn.getVocabSize());
String sample = rnn.sample(seed);
- System.out.println(sample);
if (checkRatio && rnn.useChars) {
String[] sampleWords = sample.split(" ");
for (String sw : sampleWords) {
diff --git a/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java b/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java
index ac0434c..686d603 100644
--- a/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java
+++ b/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java
@@ -39,44 +39,44 @@
private float learningRate;
private int seqLength;
private int hiddenLayerSize;
+ private int epochs;
+
private Random r = new Random();
private String text;
- private final int epochs = 20;
private List<String> words;
- public StackedRNNTest(float learningRate, int seqLength, int hiddenLayerSize) {
+ public StackedRNNTest(float learningRate, int seqLength, int hiddenLayerSize, int epochs) {
this.learningRate = learningRate;
this.seqLength = seqLength;
this.hiddenLayerSize = hiddenLayerSize;
+ this.epochs = epochs;
}
@Before
public void setUp() throws Exception {
InputStream stream = getClass().getResourceAsStream("/text/sentences.txt");
text = IOUtils.toString(stream);
- words = Arrays.asList(text.split(" "));
+ words = Arrays.asList(text.split("\\s"));
stream.close();
}
@Parameterized.Parameters
public static Collection<Object[]> data() {
return Arrays.asList(new Object[][] {
- {1e-1f, 25, 20},
- {1e-1f, 25, 40},
- {1e-1f, 25, 60},
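+ // one lighter configuration (shorter sequences, fewer epochs) replaces the three heavier ones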
+ {1e-1f, 15, 20, 5},
});
}
@Test
public void testStackedCharRNNLearn() throws Exception {
- RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text);
+ RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 5, true);
evaluate(rnn, true);
rnn.serialize("target/scrnn-weights-");
}
@Test
public void testStackedWordRNNLearn() throws Exception {
- RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, false);
+ RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 1, false);
evaluate(rnn, true);
rnn.serialize("target/swrnn-weights-");
}
@@ -88,7 +88,6 @@
for (int i = 0; i < 2; i++) {
int seed = r.nextInt(rnn.getVocabSize());
String sample = rnn.sample(seed);
- System.out.println(sample);
if (checkRatio && rnn.useChars) {
String[] sampleWords = sample.split(" ");
for (String sw : sampleWords) {