Merge pull request #3 from thammegowda/glove-rnn-classifier
Text sequence classification using GloVe and RNN/LSTMs
diff --git a/.gitignore b/.gitignore
index fe06e66..126d4a6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,3 +5,6 @@
nbactions.xml
nb-configuration.xml
*.DS_Store
+
+.idea
+*.iml
diff --git a/opennlp-dl/pom.xml b/opennlp-dl/pom.xml
index 3d15d8f..9899efe 100644
--- a/opennlp-dl/pom.xml
+++ b/opennlp-dl/pom.xml
@@ -26,7 +26,7 @@
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
- <nd4j.version>0.7.2</nd4j.version>
+ <nd4j.version>0.8.0</nd4j.version>
</properties>
<dependencies>
@@ -41,8 +41,6 @@
<artifactId>deeplearning4j-core</artifactId>
<version>${nd4j.version}</version>
</dependency>
-
-
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-nlp</artifactId>
@@ -64,6 +62,11 @@
<artifactId>nd4j-native-platform</artifactId>
<version>${nd4j.version}</version>
</dependency>
+ <dependency>
+ <groupId>args4j</groupId>
+ <artifactId>args4j</artifactId>
+ <version>2.33</version>
+ </dependency>
</dependencies>
<build>
<plugins>
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/DataReader.java b/opennlp-dl/src/main/java/opennlp/tools/dl/DataReader.java
new file mode 100644
index 0000000..86af123
--- /dev/null
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/DataReader.java
@@ -0,0 +1,308 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package opennlp.tools.dl;
+
+import org.apache.commons.io.FileUtils;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.DataSet;
+import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
+import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.INDArrayIndex;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Function;
+
+/**
+ * This class provides a reader capable of reading training and test datasets from the file system for text classifiers.
+ * In addition to reading the content, it
+ * (1) vectorizes the text using embeddings such as GloVe, and
+ * (2) divides the datasets into mini batches of the specified size.
+ *
+ * The data is expected to be organized as per the following convention:
+ * <pre>
+ * data-dir/
+ * +- label1 /
+ * | +- example11.txt
+ * | +- example12.txt
+ * | +- example13.txt
+ * | +- .....
+ * +- label2 /
+ * | +- example21.txt
+ * | +- .....
+ * +- labelN /
+ * +- exampleN1.txt
+ * +- .....
+ * </pre>
+ *
+ * In addition, the dataset is expected to be divided into training and test sets as follows:
+ * <pre>
+ * data-dir/
+ * + train/
+ * | +- label1 /
+ * | +- labelN /
+ * + test /
+ * +- label1 /
+ * +- labelN /
+ * </pre>
+ *
+ * <h2>Usage: </h2>
+ * <pre>
+ *     // label names should match the subdirectory names
+ *     List&lt;String&gt; labels = Arrays.asList("label1", "label2", ..., "labelN");
+ *     DataReader train = new DataReader("data-dir/train", labels, embeds, ....);
+ *     DataReader test = new DataReader("data-dir/test", labels, embeds, ....);
+ * </pre>
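+ *
+ * A minimal iteration sketch (variable names are illustrative):
+ * <pre>
+ *     while (train.hasNext()) {
+ *         DataSet miniBatch = train.next(); // features, labels, and the two mask arrays
+ *         // feed the mini batch to a network, e.g. network.fit(miniBatch)
+ *     }
+ *     train.reset(); // re-shuffle and rewind for the next epoch
+ * </pre>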
+ *
+ * @see GlobalVectors
+ * @see NeuralDocCat
+ * <br/>
+ * @author Thamme Gowda (thammegowda@apache.org)
+ *
+ */
+public class DataReader implements DataSetIterator {
+
+ private static final Logger LOG = LoggerFactory.getLogger(DataReader.class);
+
+ private File dataDir;
+ private List<File> records;
+ private List<Integer> labels;
+ private Map<String, Integer> labelToId;
+ private String extension = ".txt";
+ private GlobalVectors embedder;
+ private int cursor = 0;
+ private int batchSize;
+ private int vectorLen;
+ private int maxSeqLen;
+ private int numLabels;
+ // default tokenizer
+ private Function<String, String[]> tokenizer = s -> s.toLowerCase().split(" ");
+
+
+ /**
+ * Creates a reader with the specified arguments
+ * @param dataDirPath data directory
+ * @param labelNames list of labels (names should match sub directory names)
+ * @param embedder embeddings to convert words to vectors
+ * @param batchSize mini batch size for DL4j training
+ * @param maxSeqLength truncate sequences that are longer than this.
+     *                    If truncation is not desired, set {@code Integer.MAX_VALUE}
+ */
+ DataReader(String dataDirPath, List<String> labelNames, GlobalVectors embedder,
+ int batchSize, int maxSeqLength){
+ this.batchSize = batchSize;
+ this.embedder = embedder;
+ this.maxSeqLen = maxSeqLength;
+ this.vectorLen = embedder.getVectorSize();
+ this.numLabels = labelNames.size();
+ this.dataDir = new File(dataDirPath);
+ this.labelToId = new HashMap<>();
+ for (int i = 0; i < labelNames.size(); i++) {
+ labelToId.put(labelNames.get(i), i);
+ }
+ this.labelToId = Collections.unmodifiableMap(this.labelToId);
+ this.scanDir();
+ this.reset();
+ }
+
+ private void scanDir(){
+ assert dataDir.exists();
+ List<Integer> labels = new ArrayList<>();
+ List<File> files = new ArrayList<>();
+ for (String labelName: this.labelToId.keySet()) {
+ Integer labelId = this.labelToId.get(labelName);
+ assert labelId != null;
+ File labelDir = new File(dataDir, labelName);
+ if (!labelDir.exists()){
+                throw new IllegalStateException("Directory not found for label "
+                        + labelName + ". Looked at:" + labelDir);
+ }
+ File[] examples = labelDir.listFiles(f ->
+ f.isFile() && f.getName().endsWith(this.extension));
+ if (examples == null || examples.length == 0){
+ throw new IllegalStateException("No examples found for "
+ + labelName + ". Looked at:" + labelDir
+                        + " for files having extension: " + extension);
+ }
+ LOG.info("Found {} examples for label {}", examples.length, labelName);
+ for (File example: examples) {
+ files.add(example);
+ labels.add(labelId);
+ }
+ }
+ this.records = files;
+ this.labels = labels;
+ }
+
+ /**
+ * sets tokenizer for converting text to tokens
+ * @param tokenizer tokenizer to use for converting text to tokens
+ */
+ public void setTokenizer(Function<String, String[]> tokenizer) {
+ this.tokenizer = tokenizer;
+ }
+
+ /**
+ * @return Tokenizer function used for converting text into words
+ */
+ public Function<String, String[]> getTokenizer() {
+ return tokenizer;
+ }
+
+ @Override
+ public DataSet next(int batchSize) {
+ batchSize = Math.min(batchSize, records.size() - cursor);
+ INDArray features = Nd4j.create(batchSize, vectorLen, maxSeqLen);
+ INDArray labels = Nd4j.create(batchSize, numLabels, maxSeqLen);
+
+        //Because we are dealing with texts of different lengths and only one output at the final time step: use padding and mask arrays
+ //Mask arrays contain 1 if data is present at that time step for that example, or 0 if data is just padding
+ INDArray featuresMask = Nd4j.zeros(batchSize, maxSeqLen);
+ INDArray labelsMask = Nd4j.zeros(batchSize, maxSeqLen);
+
+ // Optimizations to speed up this code block by reusing memory
+        int[] _2dIndex = new int[2];
+        int[] _3dIndex = new int[3];
+        INDArrayIndex[] _3dNdIndex = new INDArrayIndex[]{null, NDArrayIndex.all(), null};
+
+ for (int i = 0; i < batchSize && cursor < records.size(); i++, cursor++) {
+ _2dIndex[0] = i;
+ _3dIndex[0] = i;
+ _3dNdIndex[0] = NDArrayIndex.point(i);
+
+ try {
+ // Read
+ File file = records.get(cursor);
+ int labelIdx = this.labels.get(cursor);
+ String text = FileUtils.readFileToString(file);
+ // Tokenize and Filter
+ String[] tokens = tokenizer.apply(text);
+ tokens = Arrays.stream(tokens).filter(embedder::hasWord).toArray(String[]::new);
+ //Get word vectors for each word in review, and put them in the training data
+ int j;
+ for(j = 0; j < tokens.length && j < maxSeqLen; j++ ){
+ String token = tokens[j];
+ INDArray vector = embedder.toVector(token);
+ _3dNdIndex[2] = NDArrayIndex.point(j);
+ features.put(_3dNdIndex, vector);
+ //Word is present (not padding) for this example + time step -> 1.0 in features mask
+ _2dIndex[1] = j;
+ featuresMask.putScalar(_2dIndex, 1.0);
+ }
+ int lastIdx = j - 1;
+ _2dIndex[1] = lastIdx;
+ _3dIndex[1] = labelIdx;
+ _3dIndex[2] = lastIdx;
+
+ labels.putScalar(_3dIndex,1.0); //Set label: one of k encoding
+ // Specify that an output exists at the final time step for this example
+ labelsMask.putScalar(_2dIndex,1.0);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ return new DataSet(features, labels, featuresMask, labelsMask);
+ }
+
+ @Override
+ public int totalExamples() {
+ return this.records.size();
+ }
+
+ @Override
+ public int inputColumns() {
+ return this.embedder.getVectorSize();
+ }
+
+ @Override
+ public int totalOutcomes() {
+ return this.numLabels;
+ }
+
+ @Override
+ public boolean resetSupported() {
+ return true;
+ }
+
+ @Override
+ public boolean asyncSupported() {
+ return false;
+ }
+
+ @Override
+ public void reset() {
+ assert this.records.size() == this.labels.size();
+ long seed = System.nanoTime(); // shuffle both the lists in the same order
+ Collections.shuffle(this.records, new Random(seed));
+ Collections.shuffle(this.labels, new Random(seed));
+ this.cursor = 0; // from beginning
+ }
+
+ @Override
+ public int batch() {
+ return this.batchSize;
+ }
+
+ @Override
+ public int cursor() {
+ return this.cursor;
+ }
+
+ @Override
+ public int numExamples() {
+ return totalExamples();
+ }
+
+ @Override
+ public void setPreProcessor(DataSetPreProcessor preProcessor) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public DataSetPreProcessor getPreProcessor() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public List<String> getLabels() {
+        // order labels by their ids so the list index matches the label id
+        List<String> names = new ArrayList<>(Collections.nCopies(labelToId.size(), null));
+        labelToId.forEach((name, id) -> names.set(id, name));
+        return names;
+ }
+
+ @Override
+ public boolean hasNext() {
+        return cursor < totalExamples();
+ }
+
+ @Override
+ public DataSet next() {
+ return next(this.batchSize);
+ }
+}
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/GlobalVectors.java b/opennlp-dl/src/main/java/opennlp/tools/dl/GlobalVectors.java
new file mode 100644
index 0000000..fdf3a95
--- /dev/null
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/GlobalVectors.java
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package opennlp.tools.dl;
+
+import org.apache.commons.io.IOUtils;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.DataSet;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.INDArrayIndex;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.*;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * GlobalVectors (GloVe) for projecting words into vector space.
+ * This tool utilizes word vectors pre-trained on large datasets.
+ *
+ * Visit https://nlp.stanford.edu/projects/glove/ for the full documentation of GloVe.
+ *
+ * <h2>Usage</h2>
+ * <pre>
+ *   String path = "work/datasets/glove.6B/glove.6B.100d.txt";
+ *   int vocabSize = 20000; // max number of words to use
+ * GlobalVectors glove;
+ * try (InputStream stream = new FileInputStream(path)) {
+ * glove = new GlobalVectors(stream, vocabSize);
+ * }
+ * </pre>
+ *
+ * @author Thamme Gowda (thammegowda@apache.org)
+ *
+ */
+public class GlobalVectors {
+
+ private static final Logger LOG = LoggerFactory.getLogger(GlobalVectors.class);
+
+ private final INDArray embeddings;
+ private final Map<String, Integer> wordToId;
+ private final List<String> idToWord;
+ private final int vectorSize;
+ private final int maxWords;
+
+ /**
+ * Reads Global Vectors from stream
+     * @param stream GloVe word vectors stream (plain text)
+     * @throws IOException when the stream cannot be read
+ */
+ public GlobalVectors(InputStream stream) throws IOException {
+ this(stream, Integer.MAX_VALUE);
+ }
+
+ /**
+     * Reads up to {@code maxWords} Global Vectors from stream
+     * @param stream vector stream
+     * @param maxWords maximum number of words to use, i.e. vocabulary size
+     * @throws IOException when the stream cannot be read
+ */
+ public GlobalVectors(InputStream stream, int maxWords) throws IOException {
+ List<String> words = new ArrayList<>();
+ List<INDArray> vectors = new ArrayList<>();
+ int vectorSize = -1;
+ try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream))){
+ String line;
+ while ((line = reader.readLine()) != null) {
+ String[] parts = line.split(" ");
+ if (vectorSize == -1) {
+ vectorSize = parts.length - 1;
+ } else {
+ assert vectorSize == parts.length - 1;
+ }
+ float[] vector = new float[vectorSize];
+ for (int i = 1; i < parts.length; i++) {
+ vector[i-1] = Float.parseFloat(parts[i]);
+ }
+ vectors.add(Nd4j.create(vector));
+ words.add(parts[0]);
+ if (words.size() >= maxWords) {
+                    LOG.info("Max words reached at {}, stopping read", words.size());
+ break;
+ }
+ }
+ LOG.info("Found {} words; Vector dimensions={}", words.size(), vectorSize);
+ this.vectorSize = vectorSize;
+ this.maxWords = Math.min(words.size(), maxWords);
+ this.embeddings = Nd4j.create(vectors, new int[]{vectors.size(), vectorSize});
+ this.idToWord = words;
+ this.wordToId = new HashMap<>();
+ for (int i = 0; i < words.size(); i++) {
+ wordToId.put(words.get(i), i);
+ }
+ }
+ }
+
+ /**
+ * @return size or dimensions of vectors
+ */
+ public int getVectorSize() {
+ return vectorSize;
+ }
+
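+    /**
+     * @return number of words in the vocabulary (at most the requested maximum)
+     */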
+ public int getMaxWords() {
+ return maxWords;
+ }
+
+ /**
+     * Checks if the word is in the vocabulary
+     * @param word word to look up
+     * @return {@code true} if the word is known; {@code false} otherwise
+ */
+ public boolean hasWord(String word){
+ return wordToId.containsKey(word);
+ }
+
+ /**
+     * Converts a word to its vector
+     * @param word word to be converted to a vector
+     * @return the vector if the word exists, or {@code null} otherwise
+ */
+ public INDArray toVector(String word){
+ if (wordToId.containsKey(word)){
+ return embeddings.getRow(wordToId.get(word));
+ }
+ return null;
+ }
+
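+    /**
+     * Embeds text into a sequence of word vectors.
+     * Note: tokenization here is a plain lowercase split on single spaces.
+     * @param text text to embed
+     * @param maxLen maximum sequence length; longer texts are truncated
+     * @return features of shape [1, vectorSize, seqLen]
+     */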
+ public INDArray embed(String text, int maxLen){
+ return embed(text.toLowerCase().split(" "), maxLen);
+ }
+
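+    /**
+     * Embeds a token sequence into word vectors, skipping out-of-vocabulary tokens.
+     * @param tokens tokens to embed
+     * @param maxLen maximum sequence length; longer sequences are truncated
+     * @return features of shape [1, vectorSize, seqLen]
+     */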
+ public INDArray embed(String[] tokens, int maxLen){
+ List<String> tokensFiltered = new ArrayList<>();
+ for(String t: tokens ){
+ if(hasWord(t)){
+ tokensFiltered.add(t);
+ }
+ }
+ int seqLen = Math.min(maxLen, tokensFiltered.size());
+
+ INDArray features = Nd4j.create(1, vectorSize, seqLen);
+
+ for( int j = 0; j < seqLen; j++ ){
+ String token = tokensFiltered.get(j);
+ INDArray vector = toVector(token);
+ features.put(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(j)}, vector);
+ }
+ return features;
+ }
+
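+    /**
+     * Writes the vectors to the stream in the GloVe plain-text format,
+     * with a default float precision of five decimal places.
+     * @param stream stream to write to
+     * @param closeStream when {@code true}, the stream is closed after writing
+     * @throws IOException when the stream cannot be written
+     */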
+ public void writeOut(OutputStream stream, boolean closeStream) throws IOException {
+ writeOut(stream, "%.5f", closeStream);
+ }
+
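+    /**
+     * Writes the vectors to the stream in the GloVe plain-text format.
+     * @param stream stream to write to
+     * @param floatPrecisionFormatString a printf-style format for each vector component, e.g. {@code "%.5f"}
+     * @param closeStream when {@code true}, the stream is closed after writing
+     * @throws IOException when the stream cannot be written
+     */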
+ public void writeOut(OutputStream stream,
+ String floatPrecisionFormatString, boolean closeStream) throws IOException {
+ if (!Character.isWhitespace(floatPrecisionFormatString.charAt(0))) {
+ floatPrecisionFormatString = " " + floatPrecisionFormatString;
+ }
+ LOG.info("Writing {} vectors out, float precision {}", idToWord.size(), floatPrecisionFormatString);
+
+ PrintWriter out = new PrintWriter(stream);
+ try {
+ for (int i = 0; i < idToWord.size(); i++) {
+ out.printf("%s", idToWord.get(i));
+ INDArray row = embeddings.getRow(i);
+ for (int j = 0; j < vectorSize; j++) {
+ out.printf(floatPrecisionFormatString, row.getDouble(j));
+ }
+ out.println();
+ }
+        } finally {
+            out.flush(); // flush buffered output so it reaches the underlying stream
+            if (closeStream){
+                IOUtils.closeQuietly(out);
+            } // else don't close, because closing the print writer also closes the inner stream
+        }
+ }
+}
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCat.java b/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCat.java
new file mode 100644
index 0000000..53bf530
--- /dev/null
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCat.java
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package opennlp.tools.dl;
+
+import opennlp.tools.doccat.DocumentCategorizer;
+import opennlp.tools.tokenize.Tokenizer;
+import opennlp.tools.tokenize.WhitespaceTokenizer;
+import org.apache.commons.io.FileUtils;
+import org.apache.commons.lang3.NotImplementedException;
+import org.kohsuke.args4j.CmdLineException;
+import org.kohsuke.args4j.CmdLineParser;
+import org.kohsuke.args4j.Option;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.*;
+
+/**
+ * An implementation of {@link DocumentCategorizer} using Neural Networks.
+ * This class provides prediction functionality from the model of {@link NeuralDocCatTrainer}.
+ *
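+ * <h2>Usage</h2>
+ * A minimal prediction sketch (assumes a model previously saved by {@link NeuralDocCatTrainer};
+ * the model path and tokens are illustrative):
+ * <pre>
+ *   NeuralDocCatModel model = NeuralDocCatModel.loadModel("sentiment-model.zip");
+ *   NeuralDocCat classifier = new NeuralDocCat(model);
+ *   double[] probs = classifier.categorize(new String[]{"a", "delightful", "film"});
+ *   String best = classifier.getBestCategory(probs);
+ * </pre>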
+ */
+public class NeuralDocCat implements DocumentCategorizer {
+
+ private static final Logger LOG = LoggerFactory.getLogger(NeuralDocCat.class);
+
+ private NeuralDocCatModel model;
+
+ public NeuralDocCat(NeuralDocCatModel model) {
+ this.model = model;
+ }
+
+ @Override
+ public double[] categorize(String[] tokens) {
+ return categorize(tokens, Collections.emptyMap());
+ }
+
+ @Override
+ public double[] categorize(String[] text, Map<String, Object> extraInformation) {
+ INDArray seqFeatures = this.model.getGloves().embed(text, this.model.getMaxSeqLen());
+
+ INDArray networkOutput = this.model.getNetwork().output(seqFeatures);
+ int timeSeriesLength = networkOutput.size(2);
+ INDArray probsAtLastWord = networkOutput.get(NDArrayIndex.point(0),
+ NDArrayIndex.all(), NDArrayIndex.point(timeSeriesLength - 1));
+
+ int nLabels = this.model.getLabels().size();
+ double[] probs = new double[nLabels];
+ for (int i = 0; i < nLabels; i++) {
+ probs[i] = probsAtLastWord.getDouble(i);
+ }
+ return probs;
+ }
+
+ @Override
+ public String getBestCategory(double[] outcome) {
+ int maxIdx = 0;
+ double maxProb = outcome[0];
+ for (int i = 1; i < outcome.length; i++) {
+ if (outcome[i] > maxProb) {
+ maxIdx = i;
+ maxProb = outcome[i];
+ }
+ }
+ return model.getLabels().get(maxIdx);
+ }
+
+ @Override
+ public int getIndex(String category) {
+ return model.getLabels().indexOf(category);
+ }
+
+ @Override
+ public String getCategory(int index) {
+ return model.getLabels().get(index);
+ }
+
+ @Override
+ public int getNumberOfCategories() {
+ return model.getLabels().size();
+ }
+
+
+ @Override
+ public String getAllResults(double[] results) {
+ throw new NotImplementedException("Not implemented");
+ }
+
+ @Override
+ public Map<String, Double> scoreMap(String[] text) {
+ double[] scores = categorize(text);
+ Map<String, Double> result = new HashMap<>();
+ for (int i = 0; i < scores.length; i++) {
+        for (int i = 0; i < scores.length; i++) {
+            result.put(model.getLabels().get(i), scores[i]);
+        }
+ return result;
+ }
+
+ @Override
+ public SortedMap<Double, Set<String>> sortedScoreMap(String[] text) {
+ throw new NotImplementedException("Not implemented");
+ }
+
+ @Override
+ @Deprecated
+ public double[] categorize(String documentText) {
+ throw new UnsupportedOperationException("Use the other categorize(..) method that accepts tokenized text");
+ }
+
+ @Override
+ @Deprecated
+ public Map<String, Double> scoreMap(String text) {
+ throw new UnsupportedOperationException("Use the other scoreMap(..) method that accepts tokenized text");
+ }
+
+ @Override
+ @Deprecated
+ public SortedMap<Double, Set<String>> sortedScoreMap(String text) {
+ throw new UnsupportedOperationException("Use the other sortedScoreMap(..) method that accepts tokenized text");
+    }
+
+    @Override
+ @Deprecated
+ public double[] categorize(String documentText, Map<String, Object> extraInformation) {
+ throw new UnsupportedOperationException("Use the other categorize(..) method that accepts tokenized text");
+ }
+
+
+ public static void main(String[] argss) throws CmdLineException, IOException {
+ class Args {
+
+ @Option(name = "-model", required = true, usage = "Path to NeuralDocCatModel stored file")
+ String modelPath;
+
+ @Option(name = "-files", required = true, usage = "One or more document paths whose category is " +
+ "to be predicted by the model")
+ List<File> files;
+ }
+
+ Args args = new Args();
+ CmdLineParser parser = new CmdLineParser(args);
+ try {
+ parser.parseArgument(argss);
+ } catch (CmdLineException e) {
+ System.out.println(e.getMessage());
+ e.getParser().printUsage(System.out);
+ System.exit(1);
+ }
+
+ NeuralDocCatModel model = NeuralDocCatModel.loadModel(args.modelPath);
+ NeuralDocCat classifier = new NeuralDocCat(model);
+
+ System.out.println("Labels:" + model.getLabels());
+ Tokenizer tokenizer = WhitespaceTokenizer.INSTANCE;
+
+ for (File file: args.files) {
+ String text = FileUtils.readFileToString(file);
+ String[] tokens = tokenizer.tokenize(text.toLowerCase());
+ double[] probs = classifier.categorize(tokens);
+ System.out.println(">>" + file);
+ System.out.println("Probabilities:" + Arrays.toString(probs));
+ }
+
+ }
+}
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCatModel.java b/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCatModel.java
new file mode 100644
index 0000000..f1b6247
--- /dev/null
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCatModel.java
@@ -0,0 +1,179 @@
+package opennlp.tools.dl;
+
+import org.apache.commons.io.IOUtils;
+import org.apache.commons.lang3.StringUtils;
+import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.*;
+import java.util.*;
+import java.util.zip.ZipEntry;
+import java.util.zip.ZipInputStream;
+import java.util.zip.ZipOutputStream;
+
+/**
+ * This class is a wrapper around DL4J's {@link MultiLayerNetwork} and {@link GlobalVectors}
+ * that can serialize and deserialize the necessary data to and from a zip file.
+ *
+ * This can be used by a neural trainer tool to serialize the network, and by a predictor tool
+ * to restore the same network with its weights.
+ *
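+ * A minimal save/load sketch (the file name is illustrative):
+ * <pre>
+ *   // after training
+ *   try (OutputStream out = new FileOutputStream("model.zip")) {
+ *       model.saveModel(out);
+ *   }
+ *   // later, for prediction
+ *   NeuralDocCatModel restored = NeuralDocCatModel.loadModel("model.zip");
+ * </pre>
+ *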
+ * <br/>
+ * @author Thamme Gowda (thammegowda@apache.org)
+ */
+public class NeuralDocCatModel {
+
+ public static final int VERSION = 1;
+ public static final String MODEL_NAME = NeuralDocCatModel.class.getName();
+ public static final String MANIFEST = "model.mf";
+ public static final String NETWORK = "network.json";
+ public static final String WEIGHTS = "weights.bin";
+ public static final String GLOVES = "gloves.tsv";
+ public static final String LABELS = "labels";
+ public static final String MAX_SEQ_LEN = "maxSeqLen";
+
+ private static final Logger LOG = LoggerFactory.getLogger(NeuralDocCatModel.class);
+
+ private final MultiLayerNetwork network;
+ private final GlobalVectors gloves;
+ private final Properties manifest;
+ private final List<String> labels;
+ private final int maxSeqLen;
+
+ /**
+     * Restores a model from a zip stream created by {@link #saveModel(OutputStream)}
+     * @param stream input stream of a zip file
+     * @throws IOException when the stream cannot be read
+ */
+ public NeuralDocCatModel(InputStream stream) throws IOException {
+ ZipInputStream zipIn = new ZipInputStream(stream);
+
+ Properties manifest = null;
+ MultiLayerNetwork model = null;
+ INDArray params = null;
+ GlobalVectors gloves = null;
+ ZipEntry entry;
+ while ((entry = zipIn.getNextEntry()) != null) {
+ String name = entry.getName();
+ switch (name) {
+ case MANIFEST:
+ manifest = new Properties();
+ manifest.load(zipIn);
+ break;
+ case NETWORK:
+ String json = IOUtils.toString(new UnclosableInputStream(zipIn));
+ model = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(json));
+ break;
+ case WEIGHTS:
+ params = Nd4j.read(new DataInputStream(new UnclosableInputStream(zipIn)));
+ break;
+ case GLOVES:
+ gloves = new GlobalVectors(new UnclosableInputStream(zipIn));
+ break;
+ default:
+ LOG.warn("Unexpected entry in the zip : {}", name);
+ }
+ }
+
+ assert model != null;
+ assert manifest != null;
+ model.init(params, false);
+ this.network = model;
+ this.manifest = manifest;
+ this.gloves = gloves;
+
+ assert manifest.containsKey(LABELS);
+ String[] labels = manifest.getProperty(LABELS).split(",");
+ this.labels = Collections.unmodifiableList(Arrays.asList(labels));
+
+ assert manifest.containsKey(MAX_SEQ_LEN);
+ this.maxSeqLen = Integer.parseInt(manifest.getProperty(MAX_SEQ_LEN));
+
+ }
+
+ /**
+ *
+ * @param network any compatible multi layer neural network
+ * @param vectors Global vectors
+ * @param labels list of labels
+ * @param maxSeqLen max sequence length
+ */
+ public NeuralDocCatModel(MultiLayerNetwork network, GlobalVectors vectors, List<String> labels, int maxSeqLen) {
+ this.network = network;
+ this.gloves = vectors;
+ this.manifest = new Properties();
+ this.manifest.setProperty(LABELS, StringUtils.join(labels, ","));
+ this.manifest.setProperty(MAX_SEQ_LEN, maxSeqLen + "");
+ this.labels = Collections.unmodifiableList(labels);
+ this.maxSeqLen = maxSeqLen;
+ }
+
+ public MultiLayerNetwork getNetwork() {
+ return network;
+ }
+
+ public GlobalVectors getGloves() {
+ return gloves;
+ }
+
+ public List<String> getLabels() {
+ return labels;
+ }
+
+ public int getMaxSeqLen() {
+ return this.maxSeqLen;
+ }
+
+ /**
+     * Zips the current state of the model and writes it to the stream
+     * @param stream stream to write to
+     * @throws IOException when the stream cannot be written
+ */
+ public void saveModel(OutputStream stream) throws IOException {
+ try (ZipOutputStream zipOut = new ZipOutputStream(new BufferedOutputStream(stream))) {
+ // Write out manifest
+ zipOut.putNextEntry(new ZipEntry(MANIFEST));
+
+ String comments = "Created-By:" + System.getenv("USER") + " at " + new Date().toString()
+ + "\nModel-Version: " + VERSION
+ + "\nModel-Schema:" + MODEL_NAME;
+
+ manifest.store(zipOut, comments);
+ zipOut.closeEntry();
+
+ // Write out the network
+ zipOut.putNextEntry(new ZipEntry(NETWORK));
+ byte[] jModel = network.getLayerWiseConfigurations().toJson().getBytes();
+ zipOut.write(jModel);
+ zipOut.closeEntry();
+
+ //Write out the network coefficients
+ zipOut.putNextEntry(new ZipEntry(WEIGHTS));
+ Nd4j.write(network.params(), new DataOutputStream(zipOut));
+ zipOut.closeEntry();
+
+ // Write out vectors
+ zipOut.putNextEntry(new ZipEntry(GLOVES));
+ gloves.writeOut(zipOut, false);
+ zipOut.closeEntry();
+
+ zipOut.finish();
+ }
+ }
+
+ /**
+ * creates a model from file on the local file system
+ * @param modelPath path to model file
+ * @return an instance of this class
+ * @throws IOException
+ */
+ public static NeuralDocCatModel loadModel(String modelPath) throws IOException {
+ try (InputStream modelStream = new FileInputStream(modelPath)) {
+ return new NeuralDocCatModel(modelStream);
+ }
+ }
+}
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCatTrainer.java b/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCatTrainer.java
new file mode 100644
index 0000000..4099b65
--- /dev/null
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCatTrainer.java
@@ -0,0 +1,253 @@
+package opennlp.tools.dl;
+
+import org.deeplearning4j.eval.Evaluation;
+import org.deeplearning4j.nn.conf.GradientNormalization;
+import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.Updater;
+import org.deeplearning4j.nn.conf.layers.GravesLSTM;
+import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
+import org.kohsuke.args4j.CmdLineException;
+import org.kohsuke.args4j.CmdLineParser;
+import org.kohsuke.args4j.Option;
+import org.kohsuke.args4j.spi.StringArrayOptionHandler;
+import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.DataSet;
+import org.nd4j.linalg.lossfunctions.LossFunctions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.*;
+import java.util.List;
+
+
+/**
+ * This class provides functionality to construct and train neural networks that can be used by
+ * a {@link opennlp.tools.doccat.DocumentCategorizer}.
+ *
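+ * A minimal programmatic training sketch (paths are illustrative; the same options are
+ * exposed on the command line via the {@code main} method):
+ * <pre>
+ *   Args args = new Args();
+ *   args.glovesPath = "glove.6B/glove.6B.100d.txt";
+ *   args.trainDir = "aclImdb/train";
+ *   args.labels = Arrays.asList("pos", "neg");
+ *   args.modelPath = "imdb-sentiment-neural-model.zip";
+ *   NeuralDocCatTrainer trainer = new NeuralDocCatTrainer(args);
+ *   trainer.train();
+ *   trainer.saveModel(args.modelPath);
+ * </pre>
+ *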
+ * @see NeuralDocCat
+ * @see NeuralDocCatModel
+ * @author Thamme Gowda (thammegowda@apache.org)
+ */
+public class NeuralDocCatTrainer {
+
+ public static class Args {
+
+ @Option(name = "-batchSize", usage = "Number of examples in minibatch")
+ int batchSize = 128;
+
+ @Option(name = "-nEpochs", usage = "Number of epochs (i.e. full passes over the training data) to train on." +
+ " Applicable for training only.")
+ int nEpochs = 2;
+
+ @Option(name = "-maxSeqLen", usage = "Max Sequence Length. Sequences longer than this will be truncated")
+ int maxSeqLen = 256; //Truncate text with length (# words) greater than this
+
+ @Option(name = "-vocabSize", usage = "Vocabulary Size.")
+ int vocabSize = 20000; //vocabulary size
+
+ @Option(name = "-nRNNUnits", usage = "Number of RNN cells to use.")
+ int nRNNUnits = 128;
+
+ @Option(name = "-lr", aliases = "-learnRate", usage = "Learning Rate." +
+ " Adjust it when the scores bounce to NaN or Infinity.")
+ double learningRate = 2e-3;
+
+ @Option(name = "-glovesPath", required = true, usage = "Path to GloVe vectors file." +
+ " Download and unzip from https://nlp.stanford.edu/projects/glove/")
+ String glovesPath = null;
+
+ @Option(name = "-modelPath", required = true, usage = "Path to model file. " +
+ "This will be used for serializing the model after the training phase." )
+ String modelPath = null;
+
+ @Option(name = "-trainDir", required = true, usage = "Path to train data directory." +
+ " Setting this value will take the system to training mode. ")
+ String trainDir = null;
+
+ @Option(name = "-validDir", usage = "Path to validation data directory. Optional.")
+ String validDir = null;
+
+ @Option(name = "-labels", required = true, handler = StringArrayOptionHandler.class,
+ usage = "Names of targets or labels separated by spaces. " +
+ "The order of labels matters. Make sure to use the same sequence for training and predicting. " +
+ "Also, these names should match subdirectory names of -trainDir and -validDir when those are " +
+ "applicable. \n Example -labels pos neg")
+ List<String> labels = null;
+
+ @Override
+ public String toString() {
+ return "Args{" +
+ "batchSize=" + batchSize +
+ ", nEpochs=" + nEpochs +
+ ", maxSeqLen=" + maxSeqLen +
+ ", vocabSize=" + vocabSize +
+ ", learningRate=" + learningRate +
+ ", nRNNUnits=" + nRNNUnits +
+ ", glovesPath='" + glovesPath + '\'' +
+ ", modelPath='" + modelPath + '\'' +
+ ", trainDir='" + trainDir + '\'' +
+ ", validDir='" + validDir + '\'' +
+ ", labels=" + labels +
+ '}';
+ }
+ }
+
+ private static final Logger LOG = LoggerFactory.getLogger(NeuralDocCatTrainer.class);
+
+ private NeuralDocCatModel model;
+ private Args args;
+ private DataReader trainSet;
+ private DataReader validSet;
+
+
+ public NeuralDocCatTrainer(Args args) throws IOException {
+ this.args = args;
+ GlobalVectors gloves;
+ MultiLayerNetwork network;
+
+ try (InputStream stream = new FileInputStream(args.glovesPath)) {
+ gloves = new GlobalVectors(stream, args.vocabSize);
+ }
+
+ LOG.info("Training data from {}", args.trainDir);
+ this.trainSet = new DataReader(args.trainDir, args.labels, gloves, args.batchSize, args.maxSeqLen);
+ if (args.validDir != null) {
+ LOG.info("Validation data from {}", args.validDir);
+ this.validSet = new DataReader(args.validDir, args.labels, gloves, args.batchSize, args.maxSeqLen);
+ }
+
+ //create network
+ network = this.createNetwork(gloves.getVectorSize());
+ this.model = new NeuralDocCatModel(network, gloves, args.labels, args.maxSeqLen);
+ }
+
+ public MultiLayerNetwork createNetwork(int vectorSize) {
+ int totalOutcomes = this.trainSet.totalOutcomes();
+ assert totalOutcomes >= 2;
+ LOG.info("Number of classes " + totalOutcomes);
+
+ //TODO: the below network params should be configurable from CLI or settings file
+ //Set up network configuration
+ MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+ .updater(Updater.RMSPROP) // ADAM .adamMeanDecay(0.9).adamVarDecay(0.999)
+ .rmsDecay(0.9)
+ .regularization(true).l2(1e-5)
+ .weightInit(WeightInit.XAVIER)
+ .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
+ .gradientNormalizationThreshold(1.0)
+ .learningRate(args.learningRate)
+ .list()
+ .layer(0, new GravesLSTM.Builder()
+ .nIn(vectorSize)
+ .nOut(args.nRNNUnits)
+ .activation(Activation.RELU).build())
+ .layer(1, new RnnOutputLayer.Builder()
+ .nIn(args.nRNNUnits)
+ .nOut(totalOutcomes)
+ .activation(Activation.SOFTMAX)
+ .lossFunction(LossFunctions.LossFunction.MCXENT)
+ .build())
+ .pretrain(false)
+ .backprop(true)
+ .build();
+
+ MultiLayerNetwork net = new MultiLayerNetwork(conf);
+ net.init();
+ net.setListeners(new ScoreIterationListener(1));
+ return net;
+ }
+
+ public void train() {
+ train(args.nEpochs, this.trainSet, this.validSet);
+ }
+
+ /**
+ * Trains model
+ *
+ * @param nEpochs number of epochs (i.e. iterations over the training dataset)
+ * @param train training data set
+ * @param validation validation data set for evaluation after each epoch.
+ * Setting this to null will skip the evaluation
+ */
+ public void train(int nEpochs, DataReader train, DataReader validation) {
+ assert model != null;
+ assert train != null;
+ LOG.info("Starting training...\nTotal epochs={}, Training Size={}, Validation Size={}", nEpochs,
+ train.totalExamples(), validation == null ? null : validation.totalExamples());
+ for (int i = 0; i < nEpochs; i++) {
+ model.getNetwork().fit(train);
+ train.reset();
+ LOG.info("Epoch {} complete", i);
+
+ if (validation != null) {
+ LOG.info("Starting evaluation");
+                //Run evaluation on the validation set; this can take some time
+ Evaluation evaluation = new Evaluation();
+ while (validation.hasNext()) {
+ DataSet t = validation.next();
+ INDArray features = t.getFeatureMatrix();
+ INDArray labels = t.getLabels();
+ INDArray inMask = t.getFeaturesMaskArray();
+ INDArray outMask = t.getLabelsMaskArray();
+ INDArray predicted = this.model.getNetwork().output(features, false, inMask, outMask);
+ evaluation.evalTimeSeries(labels, predicted, outMask);
+ }
+ validation.reset();
+ LOG.info(evaluation.stats());
+ }
+ }
+ }
+
+ /**
+ * Saves the model to specified path
+ *
+ * @param path model path
+ * @throws IOException
+ */
+ public void saveModel(String path) throws IOException {
+ assert model != null;
+ LOG.info("Saving the model at {}", path);
+ try (OutputStream stream = new FileOutputStream(path)) {
+ model.saveModel(stream);
+ }
+ }
+
+ /**
+ * <pre>
+     *   # Download pre-trained GloVe vectors (this is a large file)
+ * wget http://nlp.stanford.edu/data/glove.6B.zip
+ * unzip glove.6B.zip -d glove.6B
+ *
+ * # Download dataset
+ * wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
+ * tar xzf aclImdb_v1.tar.gz
+ *
+ * mvn compile exec:java
+     *      -Dexec.mainClass=opennlp.tools.dl.NeuralDocCatTrainer
+ * -Dexec.args="-glovesPath $HOME/work/datasets/glove.6B/glove.6B.100d.txt
+ * -labels pos neg -modelPath imdb-sentiment-neural-model.zip
+     *         -trainDir $HOME/work/datasets/aclImdb/train -lr 0.001"
+ *
+ * </pre>
+ */
+ public static void main(String[] argss) throws CmdLineException, IOException {
+ Args args = new Args();
+ CmdLineParser parser = new CmdLineParser(args);
+ try {
+ parser.parseArgument(argss);
+ } catch (CmdLineException e) {
+ System.out.println(e.getMessage());
+ e.getParser().printUsage(System.out);
+ System.exit(1);
+ }
+ NeuralDocCatTrainer classifier = new NeuralDocCatTrainer(args);
+ classifier.train();
+ classifier.saveModel(args.modelPath);
+ }
+
+}
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/UnclosableInputStream.java b/opennlp-dl/src/main/java/opennlp/tools/dl/UnclosableInputStream.java
new file mode 100644
index 0000000..701fc48
--- /dev/null
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/UnclosableInputStream.java
@@ -0,0 +1,56 @@
+package opennlp.tools.dl;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.Reader;
+import java.io.Writer;
+
+/**
+ * This class offers a wrapper for {@link InputStream}.
+ * The sole purpose of this wrapper is to bypass the close calls that are usually
+ * propagated from the readers.
+ * A use case of this wrapper is for reading multiple files from the {@link java.util.zip.ZipInputStream},
+ * especially because the tools like {@link org.apache.commons.io.IOUtils#copy(Reader, Writer)}
+ * and {@link org.nd4j.linalg.factory.Nd4j#read(InputStream)} automatically close the input stream.
+ *
+ * Note:
+ * 1. This tool ignores calls to the {@link #close()} method
+ * 2. Remember to call {@link #forceClose()} when the inner stream needs to be closed
+ * 3. This wrapper doesn't hold any resources. If you close the innerStream, you can safely ignore closing this wrapper
+ *
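+ * A minimal usage sketch (reading two entries from one zip stream; the entry contents are illustrative):
+ * <pre>
+ *   ZipInputStream zipIn = new ZipInputStream(stream);
+ *   zipIn.getNextEntry();
+ *   String json = IOUtils.toString(new UnclosableInputStream(zipIn)); // close() is swallowed
+ *   zipIn.getNextEntry(); // zipIn is still open, so the next entry can be read
+ *   INDArray params = Nd4j.read(new DataInputStream(new UnclosableInputStream(zipIn)));
+ * </pre>
+ *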
+ * @author Thamme Gowda (thammegowda@apache.org)
+ */
+public class UnclosableInputStream extends InputStream {
+
+ private InputStream innerStream;
+
+ public UnclosableInputStream(InputStream stream){
+ this.innerStream = stream;
+ }
+
+ @Override
+ public int read() throws IOException {
+ return innerStream.read();
+ }
+
+ /**
+ * NOP - Does not close the stream - intentional
+ * @throws IOException
+ */
+ @Override
+ public void close() throws IOException {
+ // intentionally ignored;
+ // Use forceClose() when needed to close
+ }
+
+ /**
+ * Closes the stream
+ * @throws IOException
+ */
+ public void forceClose() throws IOException {
+ if (innerStream != null) {
+ innerStream.close();
+ innerStream = null;
+ }
+ }
+}