Add first draft of dl name finder
diff --git a/opennlp-dl/pom.xml b/opennlp-dl/pom.xml
new file mode 100644
index 0000000..f8a6679
--- /dev/null
+++ b/opennlp-dl/pom.xml
@@ -0,0 +1,56 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+
+ <!-- NOTE(review): placeholder coordinates from a scratch project; presumably
+      this module should use the org.apache.opennlp groupId and an
+      opennlp-dl artifactId matching its directory - confirm before merging. -->
+ <groupId>burn</groupId>
+ <artifactId>dl4jtest</artifactId>
+ <version>1.0-SNAPSHOT</version>
+
+ <dependencies>
+ <dependency>
+ <groupId>org.apache.opennlp</groupId>
+ <artifactId>opennlp-tools</artifactId>
+ <version>1.7.2</version>
+ </dependency>
+
+ <dependency>
+ <groupId>org.deeplearning4j</groupId>
+ <artifactId>deeplearning4j-core</artifactId>
+ <version>0.7.2</version>
+ </dependency>
+
+ <dependency>
+ <groupId>org.nd4j</groupId>
+ <artifactId>nd4j-native-platform</artifactId>
+ <!-- Swap in the CUDA backend below to train on GPU instead of CPU. -->
+ <!-- artifactId>nd4j-cuda-8.0-platform</artifactId -->
+ <version>0.7.2</version>
+ </dependency>
+
+ <dependency>
+ <groupId>org.deeplearning4j</groupId>
+ <artifactId>deeplearning4j-nlp</artifactId>
+ <version>0.7.2</version>
+ </dependency>
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-simple</artifactId>
+ <version>1.7.12</version>
+ </dependency>
+ </dependencies>
+
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-compiler-plugin</artifactId>
+ <version>3.5.1</version>
+ <configuration>
+ <source>1.8</source>
+ <target>1.8</target>
+ </configuration>
+ </plugin>
+ </plugins>
+ </build>
+</project>
diff --git a/opennlp-dl/src/main/java/NameFinderDL.java b/opennlp-dl/src/main/java/NameFinderDL.java
new file mode 100644
index 0000000..1184a06
--- /dev/null
+++ b/opennlp-dl/src/main/java/NameFinderDL.java
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
+import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
+import org.deeplearning4j.nn.api.OptimizationAlgorithm;
+import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.Updater;
+import org.deeplearning4j.nn.conf.layers.GravesLSTM;
+import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
+import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.INDArrayIndex;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+import org.nd4j.linalg.lossfunctions.LossFunctions;
+
+import opennlp.tools.namefind.BioCodec;
+import opennlp.tools.namefind.NameSample;
+import opennlp.tools.namefind.NameSampleDataStream;
+import opennlp.tools.namefind.TokenNameFinder;
+import opennlp.tools.namefind.TokenNameFinderEvaluator;
+import opennlp.tools.util.MarkableFileInputStreamFactory;
+import opennlp.tools.util.ObjectStream;
+import opennlp.tools.util.PlainTextByLineStream;
+import opennlp.tools.util.Span;
+
+// https://github.com/deeplearning4j/dl4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/recurrent/word2vecsentiment/Word2VecSentimentRNN.java
+public class NameFinderDL implements TokenNameFinder {
+
+  // Trained network used to predict a label for each token window.
+  private final MultiLayerNetwork network;
+  // Pre-trained word embeddings used to build feature vectors for tokens.
+  private final WordVectors wordVectors;
+  // Number of tokens fed to the network per prediction; assumed odd, with the
+  // token to classify in the middle of the window.
+  private int windowSize;
+  // Label strings indexed by network output unit.
+  private String[] labels;
+
+  /**
+   * Creates a name finder backed by a trained recurrent network.
+   *
+   * @param network trained network whose output layer matches {@code labels}
+   * @param wordVectors embeddings used to encode input tokens
+   * @param windowSize number of tokens per prediction window
+   * @param labels label strings, indexed by network output unit
+   */
+  public NameFinderDL(MultiLayerNetwork network, WordVectors wordVectors, int windowSize,
+      String[] labels) {
+    this.network = network;
+    this.wordVectors = wordVectors;
+    this.windowSize = windowSize;
+    this.labels = labels;
+  }
+
+  /**
+   * Builds one feature matrix per token: a window of word vectors centered on
+   * the token (assumes an odd {@code windowSize}). Window positions that fall
+   * outside the sentence, or whose token has no embedding, are left as zero
+   * vectors.
+   *
+   * @param wordVectors embeddings used to look up token vectors
+   * @param tokens the sentence tokens
+   * @param windowSize number of vectors per matrix
+   * @return one {@code (1, 300, windowSize)} matrix per input token
+   */
+  static List<INDArray> mapToFeatureMatrices(WordVectors wordVectors, String[] tokens, int windowSize) {
+
+    List<INDArray> matrices = new ArrayList<>();
+
+    // TODO: Don't hard code the word vector dimension (300); it must match the
+    // loaded embeddings.
+
+    for (int i = 0; i < tokens.length; i++) {
+      INDArray features = Nd4j.create(1, 300, windowSize);
+      for (int vectorIndex = 0; vectorIndex < windowSize; vectorIndex++) {
+        int tokenIndex = i + vectorIndex - ((windowSize - 1) / 2);
+        if (tokenIndex >= 0 && tokenIndex < tokens.length) {
+          // Single embedding lookup per token; the original looked every token
+          // up twice (getWordVector followed by getWordVectorMatrix).
+          INDArray vector = wordVectors.getWordVectorMatrix(tokens[tokenIndex]);
+          if (vector != null) {
+            features.put(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.all(),
+                NDArrayIndex.point(vectorIndex)}, vector);
+          }
+        }
+      }
+      matrices.add(features);
+    }
+
+    return matrices;
+  }
+
+  /**
+   * Builds one one-hot label matrix per token of the sample, using the BIO
+   * encoding of the sample's name spans. The hot entry is written at the last
+   * window position, matching the label mask used during training.
+   *
+   * @param sample the name sample to encode
+   * @param windowSize number of time steps per label matrix
+   * @param labelStrings label strings, indexed by network output unit
+   * @return one {@code (1, labelStrings.length, windowSize)} matrix per token
+   */
+  static List<INDArray> mapToLabelVectors(NameSample sample, int windowSize, String[] labelStrings) {
+
+    Map<String, Integer> labelToIndex = IntStream.range(0, labelStrings.length).boxed()
+        .collect(Collectors.toMap(i -> labelStrings[i], i -> i));
+
+    // Encode the sentence outcomes once; the original re-encoded the entire
+    // sentence on every loop iteration (loop-invariant work).
+    String[] outcomes =
+        new BioCodec().encode(sample.getNames(), sample.getSentence().length);
+
+    List<INDArray> vectors = new ArrayList<>();
+
+    for (int i = 0; i < sample.getSentence().length; i++) {
+      // one-hot representation of the outcome at the last time step
+      INDArray labels = Nd4j.create(1, labelStrings.length, windowSize);
+      labels.putScalar(new int[]{0, labelToIndex.get(outcomes[i]), windowSize - 1}, 1.0d);
+      vectors.add(labels);
+    }
+
+    return vectors;
+  }
+
+  /**
+   * Returns the index of the largest value along the first dimension of the
+   * given array (argmax). Ties keep the earliest index.
+   */
+  private static int max(INDArray array) {
+    int argMax = 0;
+    for (int i = 1; i < array.size(0); i++) {
+      if (array.getDouble(i) > array.getDouble(argMax)) {
+        argMax = i;
+      }
+    }
+    return argMax;
+  }
+
+  /**
+   * Predicts a BIO outcome for every token and decodes the outcome sequence
+   * into name spans.
+   */
+  @Override
+  public Span[] find(String[] tokens) {
+    List<INDArray> featureMatrices = mapToFeatureMatrices(wordVectors, tokens, windowSize);
+
+    // Classify each token from the network output at the last time step.
+    String[] outcomes = new String[tokens.length];
+    for (int i = 0; i < tokens.length; i++) {
+      INDArray predictionMatrix = network.output(featureMatrices.get(i), false);
+      INDArray outcomeVector = predictionMatrix.get(NDArrayIndex.point(0), NDArrayIndex.all(),
+          NDArrayIndex.point(windowSize - 1));
+
+      outcomes[i] = labels[max(outcomeVector)];
+    }
+
+    // Repair invalid BIO sequences: a continuation outcome with no preceding
+    // start/continuation is rewritten to "other".
+    for (int i = 0; i < outcomes.length; i++) {
+      boolean hasPredecessor = i > 0 && !"other".equals(outcomes[i - 1]);
+      if (outcomes[i].endsWith("cont") && !hasPredecessor) {
+        outcomes[i] = "other";
+      }
+    }
+
+    return new BioCodec().decode(Arrays.asList(outcomes));
+  }
+
+  @Override
+  public void clearAdaptiveData() {
+    // No document-level adaptive state is kept, so there is nothing to clear.
+  }
+
+  /**
+   * Trains an LSTM tagger on the given name samples.
+   *
+   * @param wordVectors embeddings used to encode the tokens
+   * @param samples training samples
+   * @param epochs number of passes over the training data
+   * @param windowSize number of tokens per training example
+   * @param labels label strings; determines the output layer size
+   * @return the trained network
+   * @throws IOException if reading the samples fails
+   */
+  public static MultiLayerNetwork train(WordVectors wordVectors, ObjectStream<NameSample> samples,
+      int epochs, int windowSize, String[] labels) throws IOException {
+    int vectorSize = 300; // TODO: must match the loaded word vector dimension
+    int layerSize = 256;
+
+    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
+        .updater(Updater.RMSPROP)
+        .regularization(true).l2(0.001)
+        .weightInit(WeightInit.XAVIER)
+        // .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue).gradientNormalizationThreshold(1.0)
+        .learningRate(0.01)
+        .list()
+        .layer(0, new GravesLSTM.Builder().nIn(vectorSize).nOut(layerSize)
+            .activation(Activation.TANH).build())
+        .layer(1, new RnnOutputLayer.Builder().activation(Activation.SOFTMAX)
+            .lossFunction(LossFunctions.LossFunction.MCXENT)
+            // Size the output layer by the label set instead of hard coding 3.
+            .nIn(layerSize).nOut(labels.length).build())
+        .pretrain(false).backprop(true).build();
+
+    MultiLayerNetwork net = new MultiLayerNetwork(conf);
+    net.init();
+    net.setListeners(new ScoreIterationListener(5));
+
+    // TODO: Extract labels on the fly from the data
+
+    DataSetIterator train = new NameSampleDataSetIterator(samples, wordVectors, windowSize, labels);
+
+    System.out.println("Starting training");
+
+    for (int i = 0; i < epochs; i++) {
+      net.fit(train);
+      train.reset();
+      System.out.println(String.format("Finished epoch %d", i));
+    }
+
+    return net;
+  }
+
+  /**
+   * Trains a name finder on {@code args[0]}, evaluates it on {@code args[1]}
+   * and prints the F-measure. {@code args[2]} is a GloVe embedding file in
+   * text format.
+   */
+  public static void main(String[] args) throws Exception {
+    if (args.length != 3) {
+      System.out.println("Usage: trainFile testFile gloveTxt");
+      return;
+    }
+
+    // BIO labels for a single "default" name type; order must match the
+    // network's output units.
+    String[] labels = new String[] {
+        "default-start", "default-cont", "other"
+    };
+
+    System.out.print("Loading vectors ... ");
+    WordVectors wordVectors = WordVectorSerializer.loadTxtVectors(
+        new File(args[2]));
+    System.out.println("Done");
+
+    int windowSize = 5;
+
+    // Train for a single epoch on the training file.
+    MultiLayerNetwork net = train(wordVectors, new NameSampleDataStream(new PlainTextByLineStream(
+        new MarkableFileInputStreamFactory(new File(args[0])), StandardCharsets.UTF_8)), 1, windowSize, labels);
+
+    ObjectStream<NameSample> evalStream = new NameSampleDataStream(new PlainTextByLineStream(
+        new MarkableFileInputStreamFactory(
+            new File(args[1])), StandardCharsets.UTF_8));
+
+    NameFinderDL nameFinder = new NameFinderDL(net, wordVectors, windowSize, labels);
+
+    System.out.print("Evaluating ... ");
+    TokenNameFinderEvaluator nameFinderEvaluator = new TokenNameFinderEvaluator(nameFinder);
+    nameFinderEvaluator.evaluate(evalStream);
+
+    System.out.println("Done");
+
+    System.out.println();
+    System.out.println();
+    System.out.println("Results");
+
+    System.out.println(nameFinderEvaluator.getFMeasure().toString());
+  }
+}
diff --git a/opennlp-dl/src/main/java/NameSampleDataSetIterator.java b/opennlp-dl/src/main/java/NameSampleDataSetIterator.java
new file mode 100644
index 0000000..f416a1d
--- /dev/null
+++ b/opennlp-dl/src/main/java/NameSampleDataSetIterator.java
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.NoSuchElementException;
+
+import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.DataSet;
+import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
+import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.INDArrayIndex;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+
+import opennlp.tools.namefind.NameSample;
+import opennlp.tools.util.FilterObjectStream;
+import opennlp.tools.util.ObjectStream;
+
+public class NameSampleDataSetIterator implements DataSetIterator {
+
+  /**
+   * Adapts a stream of {@link NameSample}s into a stream of per-token
+   * {@link DataSet}s (one data set per token of each sample).
+   */
+  private static class NameSampleToDataSetStream extends FilterObjectStream<NameSample, DataSet> {
+
+    private final WordVectors wordVectors;
+    private final String[] labels;
+    private int windowSize;
+
+    // Data sets produced from the most recently read sample, one per token.
+    private Iterator<DataSet> dataSets = Collections.emptyListIterator();
+
+    NameSampleToDataSetStream(ObjectStream<NameSample> samples, WordVectors wordVectors, int windowSize, String[] labels) {
+      super(samples);
+      this.wordVectors = wordVectors;
+      this.windowSize = windowSize;
+      this.labels = labels;
+    }
+
+    // Converts one sample into one (features, labels) data set per token.
+    private Iterator<DataSet> createDataSets(NameSample sample) {
+      List<INDArray> featureMatrices = NameFinderDL.mapToFeatureMatrices(wordVectors,
+          sample.getSentence(), windowSize);
+
+      List<INDArray> labelVectors = NameFinderDL.mapToLabelVectors(sample, windowSize, labels);
+
+      List<DataSet> dataSetList = new ArrayList<>();
+      for (int i = 0; i < featureMatrices.size(); i++) {
+        dataSetList.add(new DataSet(featureMatrices.get(i), labelVectors.get(i)));
+      }
+
+      return dataSetList.iterator();
+    }
+
+    @Override
+    public final DataSet read() throws IOException {
+      // Refill the per-sample iterator until it yields an element or the
+      // underlying sample stream is exhausted.
+      while (!dataSets.hasNext()) {
+        NameSample sample = samples.read();
+        if (sample == null) {
+          return null;
+        }
+        dataSets = createDataSets(sample);
+      }
+
+      return dataSets.next();
+    }
+  }
+
+  // Number of time steps (tokens) per example.
+  private final int windowSize;
+  // Label strings indexed by network output unit.
+  private final String[] labels;
+
+  private final int batchSize = 128;
+  // TODO: must match the dimension of the loaded word vectors
+  private final int vectorSize = 300;
+
+  // Total number of per-token examples, counted once in the constructor.
+  private final int totalSamples;
+
+  private int cursor = 0;
+
+  private final ObjectStream<DataSet> samples;
+
+  /**
+   * Creates an iterator over the given name samples. The stream is read once
+   * up front to count the examples and is then reset.
+   *
+   * @throws IOException if reading the sample stream fails
+   */
+  NameSampleDataSetIterator(ObjectStream<NameSample> samples, WordVectors wordVectors, int windowSize,
+      String labels[]) throws IOException {
+    this.windowSize = windowSize;
+    this.labels = labels;
+
+    this.samples = new NameSampleToDataSetStream(samples, wordVectors, windowSize, labels);
+
+    // Count the examples; the original bound each element to an unused local.
+    int total = 0;
+    while (this.samples.read() != null) {
+      total++;
+    }
+
+    totalSamples = total;
+
+    // Reset through the wrapper stream rather than the shadowed constructor
+    // parameter; FilterObjectStream delegates to the underlying stream.
+    this.samples.reset();
+  }
+
+  /**
+   * Assembles the next mini batch of up to {@code num} examples together with
+   * feature and label masks. If the stream runs out mid batch the remaining
+   * slots stay zero and their masks are left unset.
+   *
+   * @throws NoSuchElementException if the iterator is exhausted
+   */
+  public DataSet next(int num) {
+    if (cursor >= totalExamples()) throw new NoSuchElementException();
+
+    INDArray features = Nd4j.create(num, vectorSize, windowSize);
+    INDArray featuresMask = Nd4j.zeros(num, windowSize);
+
+    // Size the label array by the configured label set instead of the hard
+    // coded 3 of the original.
+    INDArray labels = Nd4j.create(num, this.labels.length, windowSize);
+    INDArray labelsMask = Nd4j.zeros(num, windowSize);
+
+    // iterate stream and copy to arrays
+    for (int i = 0; i < num; i++) {
+      DataSet sample;
+      try {
+        sample = samples.read();
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+
+      if (sample != null) {
+        INDArray feature = sample.getFeatureMatrix();
+        features.put(new INDArrayIndex[] {NDArrayIndex.point(i)}, feature.get(NDArrayIndex.point(0)));
+
+        // Note: a dead feature.get(...) statement whose result was discarded
+        // has been removed here.
+
+        // Every window position carries a word vector, so the whole feature
+        // row is unmasked.
+        for (int j = 0; j < windowSize; j++) {
+          featuresMask.putScalar(new int[] {i, j}, 1.0);
+        }
+
+        INDArray label = sample.getLabels();
+        labels.put(new INDArrayIndex[] {NDArrayIndex.point(i)}, label.get(NDArrayIndex.point(0)));
+        // Only the last time step carries the label (see mapToLabelVectors).
+        labelsMask.putScalar(new int[] {i, windowSize - 1}, 1.0);
+      }
+
+      cursor++;
+    }
+
+    return new DataSet(features, labels, featuresMask, labelsMask);
+  }
+
+  // Total number of per-token examples in the underlying stream.
+  @Override
+  public int totalExamples() {
+    return totalSamples;
+  }
+
+  // Dimension of a single word vector.
+  @Override
+  public int inputColumns() {
+    return vectorSize;
+  }
+
+  @Override
+  public int totalOutcomes() {
+    return getLabels().size();
+  }
+
+  @Override
+  public boolean resetSupported() {
+    return true;
+  }
+
+  @Override
+  public boolean asyncSupported() {
+    return false;
+  }
+
+  // Rewinds the underlying stream and the cursor to the beginning.
+  @Override
+  public void reset() {
+    cursor = 0;
+
+    try {
+      samples.reset();
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  @Override
+  public int batch() {
+    return batchSize;
+  }
+
+  @Override
+  public int cursor() {
+    return cursor;
+  }
+
+  @Override
+  public int numExamples() {
+    return totalExamples();
+  }
+
+  // Pre-processing is not supported by this iterator.
+  @Override
+  public void setPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public DataSetPreProcessor getPreProcessor() {
+    throw new UnsupportedOperationException();
+  }
+
+  /**
+   * Returns the label strings in network output order. The original returned
+   * a hard coded list ("start", "cont", "other") that did not match the
+   * configured labels; return the configured label set instead.
+   */
+  @Override
+  public List<String> getLabels() {
+    return Arrays.asList(labels);
+  }
+
+  @Override
+  public boolean hasNext() {
+    return cursor < numExamples();
+  }
+
+  @Override
+  public DataSet next() {
+    return next(batchSize);
+  }
+}