OPENNLP-1009 - less epochs for (s)RNNs tests
diff --git a/opennlp-dl/src/test/java/opennlp/tools/dl/RNNTest.java b/opennlp-dl/src/test/java/opennlp/tools/dl/RNNTest.java
index 09a4b48..bc3904f 100644
--- a/opennlp-dl/src/test/java/opennlp/tools/dl/RNNTest.java
+++ b/opennlp-dl/src/test/java/opennlp/tools/dl/RNNTest.java
@@ -63,7 +63,7 @@
@Parameterized.Parameters
public static Collection<Object[]> data() {
return Arrays.asList(new Object[][] {
- {1e-3f, 100, 300, 500},
+ {1e-3f, 25, 50, 5},
});
}
diff --git a/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java b/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java
index c716f9d..6a61642 100644
--- a/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java
+++ b/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java
@@ -63,13 +63,13 @@
@Parameterized.Parameters
public static Collection<Object[]> data() {
return Arrays.asList(new Object[][] {
- {1e-3f, 100, 300, 500},
+ {1e-2f, 25, 50, 4},
});
}
@Test
public void testStackedCharRNNLearn() throws Exception {
- RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 20, true, true);
+ RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 10, true, true);
evaluate(rnn, true);
rnn.serialize("target/scrnn-weights-");
}