Fix Adagrad update in RNN and StackedRNN (apply epsilon outside the square root); add optional RMSProp update to StackedRNN
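
This patch touches two per-parameter update rules. Before the change, the Adagrad step divided by sqrt(m + reg), i.e. the stability term sat under the square root; the fix moves it outside, dividing by sqrt(m) + eps as in the usual formulation. StackedRNN additionally gains an opt-in RMSProp update, which replaces Adagrad's monotonically growing accumulator with an exponential moving average of squared gradients. Below is a minimal ND4J sketch of the two intended rules; the class and variable names (UpdateSketch, w, dw, m, lr) are illustrative and do not appear in the patch.

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.ops.transforms.Transforms;

    // Illustrative sketch of the two update rules; names are hypothetical.
    class UpdateSketch {
        static final double eps = 1e-4;   // stability term (named reg before this patch)
        static final double decay = 0.9;  // RMSProp decay rate

        // Adagrad: m += dw^2 ; w -= lr * dw / (sqrt(m) + eps)
        static void adagrad(INDArray w, INDArray dw, INDArray m, double lr) {
            m.addi(dw.mul(dw));
            w.subi(dw.mul(lr).div(Transforms.sqrt(m).add(eps)));
        }

        // RMSProp: m = decay * m + (1 - decay) * dw^2 ; w -= lr * dw / (sqrt(m) + eps)
        static INDArray rmsProp(INDArray w, INDArray dw, INDArray m, double lr) {
            INDArray mNew = m.mul(decay).add(dw.mul(dw).mul(1 - decay));
            w.subi(dw.mul(lr).div(Transforms.sqrt(mNew).add(eps)));
            return mNew; // caller keeps the updated accumulator
        }
    }
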
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/RNN.java b/opennlp-dl/src/main/java/opennlp/tools/dl/RNN.java
index 2fabecd..417b98c 100644
--- a/opennlp-dl/src/main/java/opennlp/tools/dl/RNN.java
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/RNN.java
@@ -175,19 +175,19 @@
// perform parameter update with Adagrad
mWxh.addi(dWxh.mul(dWxh));
- wxh.subi((dWxh.mul(learningRate)).div(Transforms.sqrt(mWxh.add(reg))));
+ wxh.subi((dWxh.mul(learningRate)).div(Transforms.sqrt(mWxh).add(reg)));
mWhh.addi(dWhh.mul(dWhh));
- whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.add(reg))));
+ whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh).add(reg)));
mWhy.addi(dWhy.mul(dWhy));
- why.subi(dWhy.mul(learningRate).div(Transforms.sqrt(mWhy.add(reg))));
+ why.subi(dWhy.mul(learningRate).div(Transforms.sqrt(mWhy).add(reg)));
mbh.addi(dbh.mul(dbh));
- bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh.add(reg))));
+ bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh).add(reg)));
mby.addi(dby.mul(dby));
- by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby.add(reg))));
+ by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby).add(reg)));
}
p += seqLength; // move data pointer
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java b/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java
index e6ceb9b..889fac1 100644
--- a/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java
@@ -54,18 +54,21 @@
private final INDArray bh2; // hidden2 bias
private final INDArray by; // output bias
- private final double reg = 1e-8;
+ private final double eps = 1e-4;
+ private final double decay = 0.9;
+ private final boolean rmsProp;
private INDArray hPrev = null; // memory state
private INDArray hPrev2 = null; // memory state
public StackedRNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text) {
- this(learningRate, seqLength, hiddenLayerSize, epochs, text, 1, true);
+ this(learningRate, seqLength, hiddenLayerSize, epochs, text, 1, true, false);
}
- public StackedRNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text, int batch, boolean useChars) {
+ public StackedRNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text, int batch, boolean useChars, boolean rmsProp) {
super(learningRate, seqLength, hiddenLayerSize, epochs, text, batch, useChars);
+ this.rmsProp = rmsProp;
wxh = Nd4j.randn(hiddenLayerSize, vocabSize).div(Math.sqrt(hiddenLayerSize));
whh = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).div(Math.sqrt(hiddenLayerSize));
whh2 = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).div(Math.sqrt(hiddenLayerSize));
@@ -141,30 +144,58 @@
}
if (n % batch == 0) {
- // perform parameter update with Adagrad
- mWxh.addi(dWxh.mul(dWxh));
- wxh.subi(dWxh.mul(learningRate).div(Transforms.sqrt(mWxh.add(reg))));
+ if (rmsProp) {
+ // perform parameter update with RMSprop
+ mWxh = mWxh.mul(decay).add(dWxh.mul(dWxh).mul(1 - decay));
+ wxh.subi(dWxh.mul(learningRate).div(Transforms.sqrt(mWxh).add(eps)));
- mWxh2.addi(dWxh2.mul(dWxh2));
- wxh2.subi(dWxh2.mul(learningRate).div(Transforms.sqrt(mWxh2.add(reg))));
+ mWxh2 = mWxh2.mul(decay).add(dWxh2.mul(dWxh2).mul(1 - decay));
+ wxh2.subi(dWxh2.mul(learningRate).div(Transforms.sqrt(mWxh2).add(eps)));
- mWhh.addi(dWhh.mul(dWhh));
- whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.add(reg))));
+ mWhh = mWhh.mul(decay).add(dWhh.mul(dWhh).mul(1 - decay));
+ whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh).add(eps)));
- mWhh2.addi(dWhh2.mul(dWhh2));
- whh2.subi(dWhh2.mul(learningRate).div(Transforms.sqrt(mWhh2.add(reg))));
+ mWhh2 = mWhh2.mul(decay).add(dWhh2.mul(dWhh2).mul(1 - decay));
+ whh2.subi(dWhh2.mul(learningRate).div(Transforms.sqrt(mWhh2).add(eps)));
- mbh2.addi(dbh2.mul(dbh2));
- bh2.subi(dbh2.mul(learningRate).div(Transforms.sqrt(mbh2.add(reg))));
+ mbh2 = mbh2.mul(decay).add(dbh2.mul(dbh2).mul(1 - decay));
+ bh2.subi(dbh2.mul(learningRate).div(Transforms.sqrt(mbh2).add(eps)));
- mWh2y.addi(dWh2y.mul(dWh2y));
- wh2y.subi(dWh2y.mul(learningRate).div(Transforms.sqrt(mWh2y.add(reg))));
+ mWh2y = mWh2y.mul(decay).add(dWh2y.mul(dWh2y).mul(1 - decay));
+ wh2y.subi(dWh2y.mul(learningRate).div(Transforms.sqrt(mWh2y).add(eps)));
- mbh.addi(dbh.mul(dbh));
- bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh.add(reg))));
+ mbh = mbh.mul(decay).add(dbh.mul(dbh).mul(1 - decay));
+ bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh).add(eps)));
- mby.addi(dby.mul(dby));
- by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby.add(reg))));
+ mby = mby.mul(decay).add(dby.mul(dby).mul(1 - decay));
+ by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby).add(eps)));
+ } else {
+ // perform parameter update with Adagrad
+
+ mWxh.addi(dWxh.mul(dWxh));
+ wxh.subi(dWxh.mul(learningRate).div(Transforms.sqrt(mWxh).add(eps)));
+
+ mWxh2.addi(dWxh2.mul(dWxh2));
+ wxh2.subi(dWxh2.mul(learningRate).div(Transforms.sqrt(mWxh2).add(eps)));
+
+ mWhh.addi(dWhh.mul(dWhh));
+ whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh).add(eps)));
+
+ mWhh2.addi(dWhh2.mul(dWhh2));
+ whh2.subi(dWhh2.mul(learningRate).div(Transforms.sqrt(mWhh2).add(eps)));
+
+ mbh2.addi(dbh2.mul(dbh2));
+ bh2.subi(dbh2.mul(learningRate).div(Transforms.sqrt(mbh2).add(eps)));
+
+ mWh2y.addi(dWh2y.mul(dWh2y));
+ wh2y.subi(dWh2y.mul(learningRate).div(Transforms.sqrt(mWh2y).add(eps)));
+
+ mbh.addi(dbh.mul(dbh));
+ bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh).add(eps)));
+
+ mby.addi(dby.mul(dby));
+ by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby).add(eps)));
+ }
}
p += seqLength; // move data pointer
diff --git a/opennlp-dl/src/test/java/opennlp/tools/dl/RNNTest.java b/opennlp-dl/src/test/java/opennlp/tools/dl/RNNTest.java
index 57f7682..88a9413 100644
--- a/opennlp-dl/src/test/java/opennlp/tools/dl/RNNTest.java
+++ b/opennlp-dl/src/test/java/opennlp/tools/dl/RNNTest.java
@@ -18,6 +18,7 @@
*/
package opennlp.tools.dl;
+import java.io.FileInputStream;
import java.io.InputStream;
import java.util.Arrays;
import java.util.Collection;
diff --git a/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java b/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java
index 686d603..265426f 100644
--- a/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java
+++ b/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java
@@ -18,6 +18,7 @@
*/
package opennlp.tools.dl;
+import java.io.FileInputStream;
import java.io.InputStream;
import java.util.Arrays;
import java.util.Collection;
@@ -69,14 +70,14 @@
@Test
public void testStackedCharRNNLearn() throws Exception {
- RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 5, true);
+ RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 5, true, true);
evaluate(rnn, true);
rnn.serialize("target/scrnn-weights-");
}
@Test
public void testStackedWordRNNLearn() throws Exception {
- RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 1, false);
+ RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 1, false, false);
evaluate(rnn, true);
rnn.serialize("target/swrnn-weights-");
}