OPENNLP-1009 - minor improvements / fixes
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/RNN.java b/opennlp-dl/src/main/java/opennlp/tools/dl/RNN.java
index 155ec03..2fabecd 100644
--- a/opennlp-dl/src/main/java/opennlp/tools/dl/RNN.java
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/RNN.java
@@ -55,6 +55,7 @@
   protected final int hiddenLayerSize;
   protected final int epochs;
   protected final boolean useChars;
+  protected final int batch;
   protected final int vocabSize;
   protected final Map<String, Integer> charToIx;
   protected final Map<Integer, String> ixToChar;
@@ -71,14 +72,15 @@
   private INDArray hPrev = null; // memory state
 
   public RNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text) {
-    this(learningRate, seqLength, hiddenLayerSize, epochs, text, true);
+    this(learningRate, seqLength, hiddenLayerSize, epochs, text, 1, true);
   }
 
-  public RNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text, boolean useChars) {
+  public RNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text, int batch, boolean useChars) {
     this.learningRate = learningRate;
     this.seqLength = seqLength;
     this.hiddenLayerSize = hiddenLayerSize;
     this.epochs = epochs;
+    this.batch = batch;
     this.useChars = useChars;
 
     String[] textTokens = useChars ? toStrings(text.toCharArray()) : text.split(" ");
@@ -169,21 +171,24 @@
         System.out.printf("iter %d, loss: %f\n", n, smoothLoss); // print progress
       }
 
-      // perform parameter update with Adagrad
-      mWxh.addi(dWxh.mul(dWxh));
-      wxh.subi((dWxh.mul(learningRate)).div(Transforms.sqrt(mWxh.add(reg))));
+      if (n% batch == 0) {
 
-      mWhh.addi(dWhh.mul(dWhh));
-      whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.add(reg))));
+        // perform parameter update with Adagrad
+        mWxh.addi(dWxh.mul(dWxh));
+        wxh.subi((dWxh.mul(learningRate)).div(Transforms.sqrt(mWxh.add(reg))));
 
-      mWhy.addi(dWhy.mul(dWhy));
-      why.subi(dWhy.mul(learningRate).div(Transforms.sqrt(mWhy.add(reg))));
+        mWhh.addi(dWhh.mul(dWhh));
+        whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.add(reg))));
 
-      mbh.addi(dbh.mul(dbh));
-      bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh.add(reg))));
+        mWhy.addi(dWhy.mul(dWhy));
+        why.subi(dWhy.mul(learningRate).div(Transforms.sqrt(mWhy.add(reg))));
 
-      mby.addi(dby.mul(dby));
-      by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby.add(reg))));
+        mbh.addi(dbh.mul(dbh));
+        bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh.add(reg))));
+
+        mby.addi(dby.mul(dby));
+        by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby.add(reg))));
+      }
 
       p += seqLength; // move data pointer
       n++; // iteration counter
@@ -244,7 +249,7 @@
     }
 
     // backward pass: compute gradients going backwards
-    INDArray dhNext = Nd4j.zerosLike(hs.getRow(0));
+    INDArray dhNext = Nd4j.zerosLike(hPrev);
     for (int t = inputs.length() - 1; t >= 0; t--) {
       INDArray dy = ps.getRow(t);
       dy.putRow(targets.getInt(t), dy.getRow(targets.getInt(t)).sub(1)); // backprop into y
@@ -334,17 +339,7 @@
             ", epochs=" + epochs +
             ", vocabSize=" + vocabSize +
             ", useChars=" + useChars +
-            '}';
-  }
-
-
-  public String getHyperparamsString() {
-    return getClass().getName() + "{" +
-            "wxh=" + wxh +
-            ", whh=" + whh +
-            ", why=" + why +
-            ", bh=" + bh +
-            ", by=" + by +
+            ", batch=" + batch +
             '}';
   }
 
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java b/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java
index e7a49d7..e6ceb9b 100644
--- a/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java
@@ -18,14 +18,6 @@
  */
 package opennlp.tools.dl;
 
-import org.apache.commons.math3.distribution.EnumeratedDistribution;
-import org.apache.commons.math3.util.Pair;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.impl.transforms.SetRange;
-import org.nd4j.linalg.api.ops.impl.transforms.SoftMax;
-import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.ops.transforms.Transforms;
-
 import java.io.BufferedWriter;
 import java.io.File;
 import java.io.FileWriter;
@@ -34,6 +26,13 @@
 import java.util.LinkedList;
 import java.util.List;
 
+import org.apache.commons.math3.distribution.EnumeratedDistribution;
+import org.apache.commons.math3.util.Pair;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.impl.transforms.SoftMax;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.ops.transforms.Transforms;
+
 /**
  * A basic char/word-level stacked RNN model (2 hidden recurrent layers), based on Stacked RNN architecture from ICLR 2014's
  * "How to Construct Deep Recurrent Neural Networks" by Razvan Pascanu, Caglar Gulcehre, Kyunghyun Cho and Yoshua Bengio
@@ -61,11 +60,11 @@
   private INDArray hPrev2 = null; // memory state
 
   public StackedRNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text) {
-    this(learningRate, seqLength, hiddenLayerSize, epochs, text, true);
+    this(learningRate, seqLength, hiddenLayerSize, epochs, text, 1, true);
   }
 
-  public StackedRNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text, boolean useChars) {
-    super(learningRate, seqLength, hiddenLayerSize, epochs, text, useChars);
+  public StackedRNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text, int batch, boolean useChars) {
+    super(learningRate, seqLength, hiddenLayerSize, epochs, text, batch, useChars);
 
     wxh = Nd4j.randn(hiddenLayerSize, vocabSize).div(Math.sqrt(hiddenLayerSize));
     whh = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).div(Math.sqrt(hiddenLayerSize));
@@ -141,30 +140,32 @@
         System.out.printf("iter %d, loss: %f\n", n, smoothLoss); // print progress
       }
 
-      // perform parameter update with Adagrad
-      mWxh.addi(dWxh.mul(dWxh));
-      wxh.subi(dWxh.mul(learningRate).div(Transforms.sqrt(mWxh.add(reg))));
+      if (n % batch == 0) {
+        // perform parameter update with Adagrad
+        mWxh.addi(dWxh.mul(dWxh));
+        wxh.subi(dWxh.mul(learningRate).div(Transforms.sqrt(mWxh.add(reg))));
 
-      mWxh2.addi(dWxh2.mul(dWxh2));
-      wxh2.subi(dWxh2.mul(learningRate).div(Transforms.sqrt(mWxh2.add(reg))));
+        mWxh2.addi(dWxh2.mul(dWxh2));
+        wxh2.subi(dWxh2.mul(learningRate).div(Transforms.sqrt(mWxh2.add(reg))));
 
-      mWhh.addi(dWhh.mul(dWhh));
-      whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.add(reg))));
+        mWhh.addi(dWhh.mul(dWhh));
+        whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.add(reg))));
 
-      mWhh2.addi(dWhh2.mul(dWhh2));
-      whh2.subi(dWhh2.mul(learningRate).div(Transforms.sqrt(mWhh2.add(reg))));
+        mWhh2.addi(dWhh2.mul(dWhh2));
+        whh2.subi(dWhh2.mul(learningRate).div(Transforms.sqrt(mWhh2.add(reg))));
 
-      mbh2.addi(dbh2.mul(dbh2));
-      bh2.subi(dbh2.mul(learningRate).div(Transforms.sqrt(mbh2.add(reg))));
+        mbh2.addi(dbh2.mul(dbh2));
+        bh2.subi(dbh2.mul(learningRate).div(Transforms.sqrt(mbh2.add(reg))));
 
-      mWh2y.addi(dWh2y.mul(dWh2y));
-      wh2y.subi(dWh2y.mul(learningRate).div(Transforms.sqrt(mWh2y.add(reg))));
+        mWh2y.addi(dWh2y.mul(dWh2y));
+        wh2y.subi(dWh2y.mul(learningRate).div(Transforms.sqrt(mWh2y.add(reg))));
 
-      mbh.addi(dbh.mul(dbh));
-      bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh.add(reg))));
+        mbh.addi(dbh.mul(dbh));
+        bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh.add(reg))));
 
-      mby.addi(dby.mul(dby));
-      by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby.add(reg))));
+        mby.addi(dby.mul(dby));
+        by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby.add(reg))));
+      }
 
       p += seqLength; // move data pointer
       n++; // iteration counter
@@ -176,7 +177,7 @@
    * hprev is Hx1 array of initial hidden state
    * returns the loss, gradients on model parameters and last hidden state
    */
-  private double lossFun(INDArray inputs, INDArray targets, INDArray dWxh, INDArray dWhh,  INDArray dWxh2, INDArray dWhh2, INDArray dWh2y,
+  private double lossFun(INDArray inputs, INDArray targets, INDArray dWxh, INDArray dWhh, INDArray dWxh2, INDArray dWhh2, INDArray dWh2y,
                          INDArray dbh, INDArray dbh2, INDArray dby) {
 
     INDArray xs = Nd4j.zeros(seqLength, vocabSize);
@@ -222,8 +223,9 @@
     }
 
     // backward pass: compute gradients going backwards
-    INDArray dhNext = Nd4j.zerosLike(hs.getRow(0));
-    INDArray dh2Next = Nd4j.zerosLike(hs2.getRow(0));
+    INDArray dhNext = Nd4j.zerosLike(hPrev);
+    INDArray dh2Next = Nd4j.zerosLike(hPrev2);
+
     for (int t = seqLength - 1; t >= 0; t--) {
       INDArray dy = ps.getRow(t);
       dy.getRow(targets.getInt(t)).subi(1); // backprop into y
@@ -249,7 +251,6 @@
       INDArray hsRow = t == 0 ? hPrev : hs.getRow(t - 1);
       dWhh.addi(dhraw.mmul(hsRow.transpose()));
       dhNext = whh.transpose().mmul(dhraw);
-
     }
 
     this.hPrev = hs.getRow(seqLength - 1);
diff --git a/opennlp-dl/src/test/java/opennlp/tools/dl/RNNTest.java b/opennlp-dl/src/test/java/opennlp/tools/dl/RNNTest.java
index 2808f4d..57f7682 100644
--- a/opennlp-dl/src/test/java/opennlp/tools/dl/RNNTest.java
+++ b/opennlp-dl/src/test/java/opennlp/tools/dl/RNNTest.java
@@ -18,7 +18,6 @@
  */
 package opennlp.tools.dl;
 
-import java.io.FileInputStream;
 import java.io.InputStream;
 import java.util.Arrays;
 import java.util.Collection;
@@ -40,44 +39,44 @@
   private float learningRate;
   private int seqLength;
   private int hiddenLayerSize;
+  private int epochs;
+
   private Random r = new Random();
   private String text;
-  private final int epochs = 20;
   private List<String> words;
 
-  public RNNTest(float learningRate, int seqLength, int hiddenLayerSize) {
+  public RNNTest(float learningRate, int seqLength, int hiddenLayerSize, int epochs) {
     this.learningRate = learningRate;
     this.seqLength = seqLength;
     this.hiddenLayerSize = hiddenLayerSize;
+    this.epochs = epochs;
   }
 
   @Before
   public void setUp() throws Exception {
     InputStream stream = getClass().getResourceAsStream("/text/sentences.txt");
     text = IOUtils.toString(stream);
-    words = Arrays.asList(text.split(" "));
+    words = Arrays.asList(text.split("\\s"));
     stream.close();
   }
 
   @Parameterized.Parameters
   public static Collection<Object[]> data() {
     return Arrays.asList(new Object[][] {
-        {1e-1f, 25, 20},
-        {1e-1f, 25, 40},
-        {1e-1f, 25, 60},
+        {1e-1f, 15, 20, 5},
     });
   }
 
   @Test
   public void testVanillaCharRNNLearn() throws Exception {
-    RNN rnn = new RNN(learningRate, seqLength, hiddenLayerSize, epochs, text);
+    RNN rnn = new RNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 5, true);
     evaluate(rnn, true);
     rnn.serialize("target/crnn-weights-");
   }
 
   @Test
   public void testVanillaWordRNNLearn() throws Exception {
-    RNN rnn = new RNN(learningRate, seqLength, hiddenLayerSize, epochs * 2, text, false);
+    RNN rnn = new RNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 1, false);
     evaluate(rnn, true);
     rnn.serialize("target/wrnn-weights-");
   }
@@ -89,7 +88,6 @@
     for (int i = 0; i < 2; i++) {
       int seed = r.nextInt(rnn.getVocabSize());
       String sample = rnn.sample(seed);
-      System.out.println(sample);
       if (checkRatio && rnn.useChars) {
         String[] sampleWords = sample.split(" ");
         for (String sw : sampleWords) {
diff --git a/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java b/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java
index ac0434c..686d603 100644
--- a/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java
+++ b/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java
@@ -39,44 +39,44 @@
   private float learningRate;
   private int seqLength;
   private int hiddenLayerSize;
+  private int epochs;
+
   private Random r = new Random();
   private String text;
-  private final int epochs = 20;
   private List<String> words;
 
-  public StackedRNNTest(float learningRate, int seqLength, int hiddenLayerSize) {
+  public StackedRNNTest(float learningRate, int seqLength, int hiddenLayerSize, int epochs) {
     this.learningRate = learningRate;
     this.seqLength = seqLength;
     this.hiddenLayerSize = hiddenLayerSize;
+    this.epochs = epochs;
   }
 
   @Before
   public void setUp() throws Exception {
     InputStream stream = getClass().getResourceAsStream("/text/sentences.txt");
     text = IOUtils.toString(stream);
-    words = Arrays.asList(text.split(" "));
+    words = Arrays.asList(text.split("\\s"));
     stream.close();
   }
 
   @Parameterized.Parameters
   public static Collection<Object[]> data() {
     return Arrays.asList(new Object[][] {
-        {1e-1f, 25, 20},
-        {1e-1f, 25, 40},
-        {1e-1f, 25, 60},
+        {1e-1f, 15, 20, 5},
     });
   }
 
   @Test
   public void testStackedCharRNNLearn() throws Exception {
-    RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text);
+    RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 5, true);
     evaluate(rnn, true);
     rnn.serialize("target/scrnn-weights-");
   }
 
   @Test
   public void testStackedWordRNNLearn() throws Exception {
-    RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, false);
+    RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 1, false);
     evaluate(rnn, true);
     rnn.serialize("target/swrnn-weights-");
   }
@@ -88,7 +88,6 @@
     for (int i = 0; i < 2; i++) {
       int seed = r.nextInt(rnn.getVocabSize());
       String sample = rnn.sample(seed);
-      System.out.println(sample);
       if (checkRatio && rnn.useChars) {
         String[] sampleWords = sample.split(" ");
         for (String sw : sampleWords) {