OPENNLP-1009 - minor updates to (s)rnn parameters, rnn now using rmsprop
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/RNN.java b/opennlp-dl/src/main/java/opennlp/tools/dl/RNN.java
index 417b98c..e297cc5 100644
--- a/opennlp-dl/src/main/java/opennlp/tools/dl/RNN.java
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/RNN.java
@@ -50,7 +50,7 @@
public class RNN {
// hyperparameters
- protected final float learningRate; // size of hidden layer of neurons
+ protected float learningRate;
protected final int seqLength; // no. of steps to unroll the RNN for
protected final int hiddenLayerSize;
protected final int epochs;
@@ -60,7 +60,8 @@
protected final Map<String, Integer> charToIx;
protected final Map<Integer, String> ixToChar;
protected final List<String> data;
- private final static double reg = 1e-8;
+ private final static double eps = 1e-8;
+ private final static double decay = 0.9;
// model parameters
private final INDArray wxh; // input to hidden
@@ -171,23 +172,23 @@
System.out.printf("iter %d, loss: %f\n", n, smoothLoss); // print progress
}
- if (n% batch == 0) {
+ if (n % batch == 0) {
- // perform parameter update with Adagrad
- mWxh.addi(dWxh.mul(dWxh));
- wxh.subi((dWxh.mul(learningRate)).div(Transforms.sqrt(mWxh).add(reg)));
+ // perform parameter update with RMSprop
+ mWxh = mWxh.mul(decay).add(1 - decay).mul((dWxh).mul(dWxh));
+ wxh.subi(dWxh.mul(learningRate).div(Transforms.sqrt(mWxh).add(eps)));
- mWhh.addi(dWhh.mul(dWhh));
- whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh).add(reg)));
+ mWhh = mWhh.mul(decay).add(1 - decay).mul((dWhh).mul(dWhh));
+ whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh).add(eps)));
- mWhy.addi(dWhy.mul(dWhy));
- why.subi(dWhy.mul(learningRate).div(Transforms.sqrt(mWhy).add(reg)));
+ mWhy = mWhy.mul(decay).add(1 - decay).mul((dWhy).mul(dWhy));
+ why.subi(dWhy.mul(learningRate).div(Transforms.sqrt(mWhy).add(eps)));
- mbh.addi(dbh.mul(dbh));
- bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh).add(reg)));
+ mbh = mbh.mul(decay).add(1 - decay).mul((dbh).mul(dbh));
+ bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh).add(eps)));
- mby.addi(dby.mul(dby));
- by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby).add(reg)));
+ mby = mby.mul(decay).add(1 - decay).mul((dby).mul(dby));
+ by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby).add(eps)));
}
p += seqLength; // move data pointer
@@ -245,7 +246,7 @@
ps = init(inputs.length(), pst.shape());
}
ps.putRow(t, pst);
- loss += -Math.log(pst.getDouble(targets.getInt(t))); // softmax (cross-entropy loss)
+ loss += -Math.log(pst.getDouble(targets.getInt(t),0)); // softmax (cross-entropy loss)
}
// backward pass: compute gradients going backwards
@@ -286,7 +287,7 @@
INDArray x = Nd4j.zeros(vocabSize, 1);
x.putScalar(seedIx, 1);
- int sampleSize = 2 * seqLength;
+ int sampleSize = 144;
INDArray ixes = Nd4j.create(sampleSize);
INDArray h = hPrev.dup();
@@ -300,13 +301,16 @@
for (int pi = 0; pi < vocabSize; pi++) {
d.add(new Pair<>(pi, pm.getDouble(0, pi)));
}
- EnumeratedDistribution<Integer> distribution = new EnumeratedDistribution<>(d);
+ try {
+ EnumeratedDistribution<Integer> distribution = new EnumeratedDistribution<>(d);
- int ix = distribution.sample();
+ int ix = distribution.sample();
- x = Nd4j.zeros(vocabSize, 1);
- x.putScalar(ix, 1);
- ixes.putScalar(t, ix);
+ x = Nd4j.zeros(vocabSize, 1);
+ x.putScalar(ix, 1);
+ ixes.putScalar(t, ix);
+ } catch (Exception e) {
+ }
}
return getSampleString(ixes);
@@ -333,14 +337,14 @@
@Override
public String toString() {
return getClass().getName() + "{" +
- "learningRate=" + learningRate +
- ", seqLength=" + seqLength +
- ", hiddenLayerSize=" + hiddenLayerSize +
- ", epochs=" + epochs +
- ", vocabSize=" + vocabSize +
- ", useChars=" + useChars +
- ", batch=" + batch +
- '}';
+ "learningRate=" + learningRate +
+ ", seqLength=" + seqLength +
+ ", hiddenLayerSize=" + hiddenLayerSize +
+ ", epochs=" + epochs +
+ ", vocabSize=" + vocabSize +
+ ", useChars=" + useChars +
+ ", batch=" + batch +
+ '}';
}
public void serialize(String prefix) throws IOException {
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java b/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java
index e9a5f7e..fe56d8f 100644
--- a/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java
@@ -55,8 +55,8 @@
private final INDArray bh2; // hidden2 bias
private final INDArray by; // output bias
- private final double eps = 1e-4;
- private final double decay = 0.9;
+ private final double eps = 1e-8;
+ private final double decay = 0.95;
private final boolean rmsProp;
private INDArray hPrev = null; // memory state
@@ -137,9 +137,14 @@
// forward seqLength characters through the net and fetch gradient
double loss = lossFun(inputs, targets, dWxh, dWhh, dWxh2, dWhh2, dWh2y, dbh, dbh2, dby);
- smoothLoss = smoothLoss * 0.999 + loss * 0.001;
+ double newLoss = smoothLoss * 0.999 + loss * 0.001;
+
+ if (newLoss > smoothLoss) {
+ learningRate *= 0.999 ;
+ }
+ smoothLoss = newLoss;
if (Double.isNaN(smoothLoss) || Double.isInfinite(smoothLoss)) {
- System.out.println("loss is " + smoothLoss + " (over/underflow occured, try adjusting hyperparameters)");
+ System.out.println("loss is " + smoothLoss + "(" + loss + ") (over/underflow occurred, try adjusting hyperparameters)");
break;
}
if (n % 100 == 0) {
@@ -252,7 +257,7 @@
}
ps.putRow(t, pst);
- loss += -Math.log(pst.getDouble(targets.getInt(t),0)); // softmax (cross-entropy loss)
+ loss += -Math.log(pst.getDouble(targets.getInt(t), 0)); // softmax (cross-entropy loss)
}
// backward pass: compute gradients going backwards
diff --git a/opennlp-dl/src/test/java/opennlp/tools/dl/RNNTest.java b/opennlp-dl/src/test/java/opennlp/tools/dl/RNNTest.java
index 88a9413..09a4b48 100644
--- a/opennlp-dl/src/test/java/opennlp/tools/dl/RNNTest.java
+++ b/opennlp-dl/src/test/java/opennlp/tools/dl/RNNTest.java
@@ -18,7 +18,6 @@
*/
package opennlp.tools.dl;
-import java.io.FileInputStream;
import java.io.InputStream;
import java.util.Arrays;
import java.util.Collection;
@@ -64,24 +63,17 @@
@Parameterized.Parameters
public static Collection<Object[]> data() {
return Arrays.asList(new Object[][] {
- {1e-1f, 15, 20, 5},
+ {1e-3f, 100, 300, 500},
});
}
@Test
public void testVanillaCharRNNLearn() throws Exception {
- RNN rnn = new RNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 5, true);
+ RNN rnn = new RNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 10, true);
evaluate(rnn, true);
rnn.serialize("target/crnn-weights-");
}
- @Test
- public void testVanillaWordRNNLearn() throws Exception {
- RNN rnn = new RNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 1, false);
- evaluate(rnn, true);
- rnn.serialize("target/wrnn-weights-");
- }
-
private void evaluate(RNN rnn, boolean checkRatio) {
System.out.println(rnn);
rnn.learn();
diff --git a/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java b/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java
index 265426f..dbf5a4b 100644
--- a/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java
+++ b/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java
@@ -18,7 +18,6 @@
*/
package opennlp.tools.dl;
-import java.io.FileInputStream;
import java.io.InputStream;
import java.util.Arrays;
import java.util.Collection;
@@ -55,7 +54,7 @@
@Before
public void setUp() throws Exception {
- InputStream stream = getClass().getResourceAsStream("/text/sentences.txt");
+ InputStream stream = getClass().getResourceAsStream("/text/queries.txt");
text = IOUtils.toString(stream);
words = Arrays.asList(text.split("\\s"));
stream.close();
@@ -64,24 +63,17 @@
@Parameterized.Parameters
public static Collection<Object[]> data() {
return Arrays.asList(new Object[][] {
- {1e-1f, 15, 20, 5},
+ {1e-3f, 100, 300, 500},
});
}
@Test
public void testStackedCharRNNLearn() throws Exception {
- RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 5, true, true);
+ RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 20, true, true);
evaluate(rnn, true);
rnn.serialize("target/scrnn-weights-");
}
- @Test
- public void testStackedWordRNNLearn() throws Exception {
- RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 1, false, false);
- evaluate(rnn, true);
- rnn.serialize("target/swrnn-weights-");
- }
-
private void evaluate(RNN rnn, boolean checkRatio) {
System.out.println(rnn);
rnn.learn();