Merge pull request #20 from tteofili/OPENNLP-1009a
OPENNLP-1009 - upgrade to dl4j 1.0.0-beta2
diff --git a/opennlp-dl/pom.xml b/opennlp-dl/pom.xml
index cfb1a1b..829cf6a 100644
--- a/opennlp-dl/pom.xml
+++ b/opennlp-dl/pom.xml
@@ -26,7 +26,7 @@
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
- <nd4j.version>0.9.1</nd4j.version>
+ <nd4j.version>1.0.0-beta2</nd4j.version>
</properties>
<dependencies>
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/DataReader.java b/opennlp-dl/src/main/java/opennlp/tools/dl/DataReader.java
index 86af123..4f7b5c3 100644
--- a/opennlp-dl/src/main/java/opennlp/tools/dl/DataReader.java
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/DataReader.java
@@ -233,11 +233,6 @@
}
@Override
- public int totalExamples() {
- return this.records.size();
- }
-
- @Override
public int inputColumns() {
return this.embedder.getVectorSize();
}
@@ -272,16 +267,6 @@
}
@Override
- public int cursor() {
- return this.cursor;
- }
-
- @Override
- public int numExamples() {
- return totalExamples();
- }
-
- @Override
public void setPreProcessor(DataSetPreProcessor preProcessor) {
throw new UnsupportedOperationException();
}
@@ -298,7 +283,7 @@
@Override
public boolean hasNext() {
- return cursor < totalExamples() - 1;
+ return cursor < this.records.size() - 1;
}
@Override
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/NameFinderDL.java b/opennlp-dl/src/main/java/opennlp/tools/dl/NameFinderDL.java
index 7547196..3a0ad54 100644
--- a/opennlp-dl/src/main/java/opennlp/tools/dl/NameFinderDL.java
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/NameFinderDL.java
@@ -42,6 +42,7 @@
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
+import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import opennlp.tools.namefind.BioCodec;
@@ -159,12 +160,9 @@
int layerSize = 256;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
- .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
- .updater(Updater.RMSPROP)
- .regularization(true).l2(0.001)
+ .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
+ .updater(new RmsProp(0.01)).l2(0.001)
.weightInit(WeightInit.XAVIER)
- // .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue).gradientNormalizationThreshold(1.0)
- .learningRate(0.01)
.list()
.layer(0, new GravesLSTM.Builder().nIn(vectorSize).nOut(layerSize)
.activation(Activation.TANH).build())
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/NameSampleDataSetIterator.java b/opennlp-dl/src/main/java/opennlp/tools/dl/NameSampleDataSetIterator.java
index a420220..d6d171a 100644
--- a/opennlp-dl/src/main/java/opennlp/tools/dl/NameSampleDataSetIterator.java
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/NameSampleDataSetIterator.java
@@ -140,7 +140,7 @@
}
if (sample != null) {
- INDArray feature = sample.getFeatureMatrix();
+ INDArray feature = sample.getFeatures();
features.put(new INDArrayIndex[] {NDArrayIndex.point(i)}, feature.get(NDArrayIndex.point(0)));
feature.get(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.all(),
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCat.java b/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCat.java
index 299a742..9e91484 100644
--- a/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCat.java
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCat.java
@@ -60,7 +60,7 @@
INDArray seqFeatures = this.model.getGloves().embed(text, this.model.getMaxSeqLen());
INDArray networkOutput = this.model.getNetwork().output(seqFeatures);
- int timeSeriesLength = networkOutput.size(2);
+ long timeSeriesLength = networkOutput.size(2);
INDArray probsAtLastWord = networkOutput.get(NDArrayIndex.point(0),
NDArrayIndex.all(), NDArrayIndex.point(timeSeriesLength - 1));
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCatTrainer.java b/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCatTrainer.java
index 9ce3a3f..697bff0 100644
--- a/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCatTrainer.java
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCatTrainer.java
@@ -135,12 +135,11 @@
//TODO: the below network params should be configurable from CLI or settings file
//Set up network configuration
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
- .updater(new RmsProp(0.9)) // ADAM .adamMeanDecay(0.9).adamVarDecay(0.999)
- .regularization(true).l2(1e-5)
+ .updater(new RmsProp(args.learningRate)) // ADAM .adamMeanDecay(0.9).adamVarDecay(0.999)
+ .l2(1e-5)
.weightInit(WeightInit.XAVIER)
.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
.gradientNormalizationThreshold(1.0)
- .learningRate(args.learningRate)
.list()
.layer(0, new GravesLSTM.Builder()
.nIn(vectorSize)
@@ -177,8 +176,8 @@
public void train(int nEpochs, DataReader train, DataReader validation) {
assert model != null;
assert train != null;
- LOG.info("Starting training...\nTotal epochs={}, Training Size={}, Validation Size={}", nEpochs,
- train.totalExamples(), validation == null ? null : validation.totalExamples());
+// LOG.info("Starting training...\nTotal epochs={}, Training Size={}, Validation Size={}", nEpochs,
+//     train.totalExamples(), validation == null ? null : validation.totalExamples());
for (int i = 0; i < nEpochs; i++) {
model.getNetwork().fit(train);
train.reset();
@@ -190,7 +189,7 @@
Evaluation evaluation = new Evaluation();
while (validation.hasNext()) {
DataSet t = validation.next();
- INDArray features = t.getFeatureMatrix();
+ INDArray features = t.getFeatures();
INDArray labels = t.getLabels();
INDArray inMask = t.getFeaturesMaskArray();
INDArray outMask = t.getLabelsMaskArray();
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/RNN.java b/opennlp-dl/src/main/java/opennlp/tools/dl/RNN.java
index e297cc5..7547cce 100644
--- a/opennlp-dl/src/main/java/opennlp/tools/dl/RNN.java
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/RNN.java
@@ -35,6 +35,7 @@
import org.apache.commons.math3.util.Pair;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.impl.transforms.OldSoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.SetRange;
import org.nd4j.linalg.api.ops.impl.transforms.SoftMax;
import org.nd4j.linalg.factory.Nd4j;
@@ -241,7 +242,7 @@
ys = init(inputs.length(), yst.shape());
}
ys.putRow(t, yst);
- INDArray pst = Nd4j.getExecutioner().execAndReturn(new SoftMax(yst)); // probabilities for next chars
+ INDArray pst = Nd4j.getExecutioner().execAndReturn(new OldSoftMax(yst)); // probabilities for next chars
if (ps == null) {
ps = init(inputs.length(), pst.shape());
}
@@ -251,7 +252,7 @@
// backward pass: compute gradients going backwards
INDArray dhNext = Nd4j.zerosLike(hPrev);
- for (int t = inputs.length() - 1; t >= 0; t--) {
+ for (int t = (int) (inputs.length() - 1); t >= 0; t--) {
INDArray dy = ps.getRow(t);
dy.putRow(targets.getInt(t), dy.getRow(targets.getInt(t)).sub(1)); // backprop into y
INDArray hst = hs.getRow(t);
@@ -271,9 +272,9 @@
return loss;
}
- protected INDArray init(int t, int[] aShape) {
+ protected INDArray init(long t, long[] aShape) {
INDArray as;
- int[] shape = new int[1 + aShape.length];
+ long[] shape = new long[1 + aShape.length];
shape[0] = t;
System.arraycopy(aShape, 0, shape, 1, aShape.length);
as = Nd4j.create(shape);
@@ -295,7 +296,7 @@
for (int t = 0; t < sampleSize; t++) {
h = Transforms.tanh(wxh.mmul(x).add(whh.mmul(h)).add(bh));
INDArray y = (why.mmul(h)).add(by);
- INDArray pm = Nd4j.getExecutioner().execAndReturn(new SoftMax(y)).ravel();
+ INDArray pm = Nd4j.getExecutioner().execAndReturn(new OldSoftMax(y)).ravel();
List<Pair<Integer, Double>> d = new LinkedList<>();
for (int pi = 0; pi < vocabSize; pi++) {
@@ -321,11 +322,12 @@
NdIndexIterator ndIndexIterator = new NdIndexIterator(ixes.shape());
while (ndIndexIterator.hasNext()) {
- int[] next = ndIndexIterator.next();
+ long[] next = ndIndexIterator.next();
if (!useChars && txt.length() > 0) {
txt.append(' ');
}
- txt.append(ixToChar.get(ixes.getInt(next)));
+ int aDouble = (int) ixes.getDouble(next);
+ txt.append(ixToChar.get(aDouble));
}
return txt.toString();
}
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java b/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java
index fe56d8f..6a187c2 100644
--- a/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java
@@ -29,6 +29,7 @@
import org.apache.commons.math3.distribution.EnumeratedDistribution;
import org.apache.commons.math3.util.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.impl.transforms.OldSoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.ReplaceNans;
import org.nd4j.linalg.api.ops.impl.transforms.SoftMax;
import org.nd4j.linalg.factory.Nd4j;
@@ -251,7 +252,7 @@
}
ys.putRow(t, yst);
- INDArray pst = Nd4j.getExecutioner().execAndReturn(new ReplaceNans(Nd4j.getExecutioner().execAndReturn(new SoftMax(yst)), 0d)); // probabilities for next chars
+ INDArray pst = Nd4j.getExecutioner().execAndReturn(new ReplaceNans(Nd4j.getExecutioner().execAndReturn(new OldSoftMax(yst)), 0d)); // probabilities for next chars
if (ps == null) {
ps = init(seqLength, pst.shape());
}
@@ -312,7 +313,7 @@
h = Transforms.tanh((wxh.mmul(x)).add(whh.mmul(h)).add(bh));
h2 = Transforms.tanh((wxh2.mmul(h)).add(whh2.mmul(h2)).add(bh2));
INDArray y = wh2y.mmul(h2).add(by);
- INDArray pm = Nd4j.getExecutioner().execAndReturn(new SoftMax(y)).ravel();
+ INDArray pm = Nd4j.getExecutioner().execAndReturn(new OldSoftMax(y)).ravel();
List<Pair<Integer, Double>> d = new LinkedList<>();
for (int pi = 0; pi < vocabSize; pi++) {
diff --git a/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java b/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java
index 6a61642..8c81565 100644
--- a/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java
+++ b/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java
@@ -63,7 +63,7 @@
@Parameterized.Parameters
public static Collection<Object[]> data() {
return Arrays.asList(new Object[][] {
- {1e-2f, 25, 50, 4},
+ {1e-3f, 25, 50, 4},
});
}