OPENNLP-1009 - minor updates to (S)RNN hyperparameters; RNN now uses RMSprop instead of Adagrad
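
The vanilla RNN previously accumulated squared gradients Adagrad-style (m += g^2), so its effective step size only ever shrinks. It now keeps a leaky moving average of the squared gradients instead. As a rough sketch of the intended RMSprop rule (m, g and w are placeholder names for a per-parameter memory, its gradient and the weights; decay, eps and learningRate are the constants/fields touched in this change):

    // leaky average of squared gradients, then a scaled SGD step
    m = m.mul(decay).add(g.mul(g).mul(1 - decay));                 // m = decay * m + (1 - decay) * g^2
    w.subi(g.mul(learningRate).div(Transforms.sqrt(m).add(eps))); // w -= lr * g / (sqrt(m) + eps)

The stacked RNN keeps its existing RMSprop path but tightens eps and decay, and now anneals the learning rate by 0.1% whenever the smoothed loss stops improving; learningRate in RNN is no longer final to allow that.
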
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/RNN.java b/opennlp-dl/src/main/java/opennlp/tools/dl/RNN.java
index 417b98c..e297cc5 100644
--- a/opennlp-dl/src/main/java/opennlp/tools/dl/RNN.java
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/RNN.java
@@ -50,7 +50,7 @@
 public class RNN {
 
   // hyperparameters
-  protected final float learningRate; // size of hidden layer of neurons
+  protected float learningRate;
   protected final int seqLength; // no. of steps to unroll the RNN for
   protected final int hiddenLayerSize;
   protected final int epochs;
@@ -60,7 +60,8 @@
   protected final Map<String, Integer> charToIx;
   protected final Map<Integer, String> ixToChar;
   protected final List<String> data;
-  private final static double reg = 1e-8;
+  private final static double eps = 1e-8;
+  private final static double decay = 0.9;
 
   // model parameters
   private final INDArray wxh; // input to hidden
@@ -171,23 +172,23 @@
         System.out.printf("iter %d, loss: %f\n", n, smoothLoss); // print progress
       }
 
-      if (n% batch == 0) {
+      if (n % batch == 0) {
 
-        // perform parameter update with Adagrad
-        mWxh.addi(dWxh.mul(dWxh));
-        wxh.subi((dWxh.mul(learningRate)).div(Transforms.sqrt(mWxh).add(reg)));
+        // perform parameter update with RMSprop
+        mWxh = mWxh.mul(decay).add(dWxh.mul(dWxh).mul(1 - decay));
+        wxh.subi(dWxh.mul(learningRate).div(Transforms.sqrt(mWxh).add(eps)));
 
-        mWhh.addi(dWhh.mul(dWhh));
-        whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh).add(reg)));
+        mWhh = mWhh.mul(decay).add(dWhh.mul(dWhh).mul(1 - decay));
+        whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh).add(eps)));
 
-        mWhy.addi(dWhy.mul(dWhy));
-        why.subi(dWhy.mul(learningRate).div(Transforms.sqrt(mWhy).add(reg)));
+        mWhy = mWhy.mul(decay).add(dWhy.mul(dWhy).mul(1 - decay));
+        why.subi(dWhy.mul(learningRate).div(Transforms.sqrt(mWhy).add(eps)));
 
-        mbh.addi(dbh.mul(dbh));
-        bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh).add(reg)));
+        mbh = mbh.mul(decay).add(dbh.mul(dbh).mul(1 - decay));
+        bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh).add(eps)));
 
-        mby.addi(dby.mul(dby));
-        by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby).add(reg)));
+        mby = mby.mul(decay).add(dby.mul(dby).mul(1 - decay));
+        by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby).add(eps)));
       }
 
       p += seqLength; // move data pointer
@@ -245,7 +246,7 @@
         ps = init(inputs.length(), pst.shape());
       }
       ps.putRow(t, pst);
-      loss += -Math.log(pst.getDouble(targets.getInt(t))); // softmax (cross-entropy loss)
+      loss += -Math.log(pst.getDouble(targets.getInt(t), 0)); // softmax (cross-entropy loss)
     }
 
     // backward pass: compute gradients going backwards
@@ -286,7 +287,7 @@
 
     INDArray x = Nd4j.zeros(vocabSize, 1);
     x.putScalar(seedIx, 1);
-    int sampleSize = 2 * seqLength;
+    int sampleSize = 144; // fixed sample length, no longer tied to seqLength
     INDArray ixes = Nd4j.create(sampleSize);
 
     INDArray h = hPrev.dup();
@@ -300,13 +301,16 @@
       for (int pi = 0; pi < vocabSize; pi++) {
         d.add(new Pair<>(pi, pm.getDouble(0, pi)));
       }
-      EnumeratedDistribution<Integer> distribution = new EnumeratedDistribution<>(d);
+      try {
+        EnumeratedDistribution<Integer> distribution = new EnumeratedDistribution<>(d);
 
-      int ix = distribution.sample();
+        int ix = distribution.sample();
 
-      x = Nd4j.zeros(vocabSize, 1);
-      x.putScalar(ix, 1);
-      ixes.putScalar(t, ix);
+        x = Nd4j.zeros(vocabSize, 1);
+        x.putScalar(ix, 1);
+        ixes.putScalar(t, ix);
+      } catch (Exception e) { // skip this step if the distribution is invalid (e.g. probabilities sum to 0 or are NaN)
+      }
     }
 
     return getSampleString(ixes);
@@ -333,14 +337,14 @@
   @Override
   public String toString() {
     return getClass().getName() + "{" +
-            "learningRate=" + learningRate +
-            ", seqLength=" + seqLength +
-            ", hiddenLayerSize=" + hiddenLayerSize +
-            ", epochs=" + epochs +
-            ", vocabSize=" + vocabSize +
-            ", useChars=" + useChars +
-            ", batch=" + batch +
-            '}';
+        "learningRate=" + learningRate +
+        ", seqLength=" + seqLength +
+        ", hiddenLayerSize=" + hiddenLayerSize +
+        ", epochs=" + epochs +
+        ", vocabSize=" + vocabSize +
+        ", useChars=" + useChars +
+        ", batch=" + batch +
+        '}';
   }
 
   public void serialize(String prefix) throws IOException {
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java b/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java
index e9a5f7e..fe56d8f 100644
--- a/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java
@@ -55,8 +55,8 @@
   private final INDArray bh2; // hidden2 bias
   private final INDArray by; // output bias
 
-  private final double eps = 1e-4;
-  private final double decay = 0.9;
+  private final double eps = 1e-8;
+  private final double decay = 0.95;
   private final boolean rmsProp;
 
   private INDArray hPrev = null; // memory state
@@ -137,9 +137,14 @@
 
       // forward seqLength characters through the net and fetch gradient
       double loss = lossFun(inputs, targets, dWxh, dWhh, dWxh2, dWhh2, dWh2y, dbh, dbh2, dby);
-      smoothLoss = smoothLoss * 0.999 + loss * 0.001;
+      double newLoss = smoothLoss * 0.999 + loss * 0.001;
+
+      if (newLoss > smoothLoss) {
+        learningRate *= 0.999;
+      }
+      smoothLoss = newLoss;
       if (Double.isNaN(smoothLoss) || Double.isInfinite(smoothLoss)) {
-        System.out.println("loss is " + smoothLoss + " (over/underflow occured, try adjusting hyperparameters)");
+        System.out.println("loss is " + smoothLoss + "(" + loss + ") (over/underflow occurred, try adjusting hyperparameters)");
         break;
       }
       if (n % 100 == 0) {
@@ -252,7 +257,7 @@
       }
       ps.putRow(t, pst);
 
-      loss += -Math.log(pst.getDouble(targets.getInt(t),0)); // softmax (cross-entropy loss)
+      loss += -Math.log(pst.getDouble(targets.getInt(t), 0)); // softmax (cross-entropy loss)
     }
 
     // backward pass: compute gradients going backwards
diff --git a/opennlp-dl/src/test/java/opennlp/tools/dl/RNNTest.java b/opennlp-dl/src/test/java/opennlp/tools/dl/RNNTest.java
index 88a9413..09a4b48 100644
--- a/opennlp-dl/src/test/java/opennlp/tools/dl/RNNTest.java
+++ b/opennlp-dl/src/test/java/opennlp/tools/dl/RNNTest.java
@@ -18,7 +18,6 @@
  */
 package opennlp.tools.dl;
 
-import java.io.FileInputStream;
 import java.io.InputStream;
 import java.util.Arrays;
 import java.util.Collection;
@@ -64,24 +63,17 @@
   @Parameterized.Parameters
   public static Collection<Object[]> data() {
     return Arrays.asList(new Object[][] {
-        {1e-1f, 15, 20, 5},
+        {1e-3f, 100, 300, 500},
     });
   }
 
   @Test
   public void testVanillaCharRNNLearn() throws Exception {
-    RNN rnn = new RNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 5, true);
+    RNN rnn = new RNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 10, true);
     evaluate(rnn, true);
     rnn.serialize("target/crnn-weights-");
   }
 
-  @Test
-  public void testVanillaWordRNNLearn() throws Exception {
-    RNN rnn = new RNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 1, false);
-    evaluate(rnn, true);
-    rnn.serialize("target/wrnn-weights-");
-  }
-
   private void evaluate(RNN rnn, boolean checkRatio) {
     System.out.println(rnn);
     rnn.learn();
diff --git a/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java b/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java
index 265426f..dbf5a4b 100644
--- a/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java
+++ b/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java
@@ -18,7 +18,6 @@
  */
 package opennlp.tools.dl;
 
-import java.io.FileInputStream;
 import java.io.InputStream;
 import java.util.Arrays;
 import java.util.Collection;
@@ -55,7 +54,7 @@
 
   @Before
   public void setUp() throws Exception {
-    InputStream stream = getClass().getResourceAsStream("/text/sentences.txt");
+    InputStream stream = getClass().getResourceAsStream("/text/queries.txt");
     text = IOUtils.toString(stream);
     words = Arrays.asList(text.split("\\s"));
     stream.close();
@@ -64,24 +63,17 @@
   @Parameterized.Parameters
   public static Collection<Object[]> data() {
     return Arrays.asList(new Object[][] {
-        {1e-1f, 15, 20, 5},
+        {1e-3f, 100, 300, 500},
     });
   }
 
   @Test
   public void testStackedCharRNNLearn() throws Exception {
-    RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 5, true, true);
+    RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 20, true, true);
     evaluate(rnn, true);
     rnn.serialize("target/scrnn-weights-");
   }
 
-  @Test
-  public void testStackedWordRNNLearn() throws Exception {
-    RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 1, false, false);
-    evaluate(rnn, true);
-    rnn.serialize("target/swrnn-weights-");
-  }
-
   private void evaluate(RNN rnn, boolean checkRatio) {
     System.out.println(rnn);
     rnn.learn();