Merge pull request #3 from thammegowda/glove-rnn-classifier

Text sequence classification using GloVe and RNN/LSTMs
diff --git a/.gitignore b/.gitignore
index fe06e66..126d4a6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,3 +5,6 @@
 nbactions.xml
 nb-configuration.xml
 *.DS_Store
+
+.idea
+*.iml
diff --git a/opennlp-dl/pom.xml b/opennlp-dl/pom.xml
index 3d15d8f..9899efe 100644
--- a/opennlp-dl/pom.xml
+++ b/opennlp-dl/pom.xml
@@ -26,7 +26,7 @@
 
   <properties>
     <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
-    <nd4j.version>0.7.2</nd4j.version>
+    <nd4j.version>0.8.0</nd4j.version>
   </properties>
 
   <dependencies>
@@ -41,8 +41,6 @@
           <artifactId>deeplearning4j-core</artifactId>
           <version>${nd4j.version}</version>
       </dependency>
-
-
       <dependency>
           <groupId>org.deeplearning4j</groupId>
           <artifactId>deeplearning4j-nlp</artifactId>
@@ -64,6 +62,11 @@
       <artifactId>nd4j-native-platform</artifactId>
       <version>${nd4j.version}</version>
     </dependency>
+    <dependency>
+      <groupId>args4j</groupId>
+      <artifactId>args4j</artifactId>
+      <version>2.33</version>
+    </dependency>
   </dependencies>
   <build>
     <plugins>
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/DataReader.java b/opennlp-dl/src/main/java/opennlp/tools/dl/DataReader.java
new file mode 100644
index 0000000..86af123
--- /dev/null
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/DataReader.java
@@ -0,0 +1,359 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package opennlp.tools.dl;
+
+import org.apache.commons.io.FileUtils;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.DataSet;
+import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
+import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.INDArrayIndex;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.NoSuchElementException;
+import java.util.Objects;
+import java.util.Random;
+import java.util.function.Function;
+
+/**
+ * This class provides a reader capable of reading training and test datasets from file system for text classifiers.
+ * In addition to reading the content, it
+ * (1) vectorizes the text using embeddings such as GloVe, and
+ * (2) divides the datasets into mini batches of specified size.
+ *
+ * The data is expected to be organized as per the following convention:
+ * <pre>
+ * data-dir/
+ *     +- label1 /
+ *     |    +- example11.txt
+ *     |    +- example12.txt
+ *     |    +- example13.txt
+ *     |    +- .....
+ *     +- label2 /
+ *     |    +- example21.txt
+ *     |    +- .....
+ *     +- labelN /
+ *          +- exampleN1.txt
+ *          +- .....
+ * </pre>
+ *
+ * In addition, the dataset shall be divided into training and testing as follows:
+ * <pre>
+ * data-dir/
+ *     + train/
+ *     |   +- label1 /
+ *     |   +- labelN /
+ *     + test /
+ *         +- label1 /
+ *         +- labelN /
+ * </pre>
+ *
+ * Each example becomes a sequence of word vectors, one time step per known word (words without an
+ * embedding are skipped), accompanied by a features mask. The one-hot label is placed at the last
+ * time step of the sequence, which is flagged by the labels mask.
+ *
+ * <h2>Usage: </h2>
+ * <pre>
+ *     // label names should match the subdirectory names
+ *     List&lt;String&gt; labels = Arrays.asList("label1", "label2", ..."labelN");
+ *     DataReader train = new DataReader("data-dir/train", labels, embeds, ....);
+ *     DataReader test = new DataReader("data-dir/test", labels, embeds, ....);
+ * </pre>
+ *
+ * @see GlobalVectors
+ * @see NeuralDocCat
+ * @author Thamme Gowda (thammegowda@apache.org)
+ */
+public class DataReader implements DataSetIterator {
+
+    private static final Logger LOG = LoggerFactory.getLogger(DataReader.class);
+
+    private File dataDir;
+    private List<File> records;
+    private List<Integer> labels;
+    private Map<String, Integer> labelToId;
+    // label names in id order: the name at index i is the name of label id i
+    private List<String> labelNames;
+    private String extension = ".txt";
+    private GlobalVectors embedder;
+    private int cursor = 0;
+    private int batchSize;
+    private int vectorLen;
+    private int maxSeqLen;
+    private int numLabels;
+    // default tokenizer: lower case, then split on any run of white space (spaces, tabs, line breaks)
+    private Function<String, String[]> tokenizer = s -> s.toLowerCase().split("\\s+");
+
+    /**
+     * Creates a reader with the specified arguments
+     * @param dataDirPath data directory
+     * @param labelNames list of labels (names should match sub directory names);
+     *                   the position of a name in this list is its label id
+     * @param embedder embeddings to convert words to vectors
+     * @param batchSize mini batch size for DL4j training
+     * @param maxSeqLength truncate sequences that are longer than this. Every mini batch is
+     *                     allocated with this many time steps, so use a realistic upper bound
+     *                     rather than {@code Integer.MAX_VALUE}
+     * @throws IllegalArgumentException if the data directory does not exist, a label name is
+     *                                  repeated, or the batch size or max sequence length is not positive
+     * @throws IllegalStateException if no examples are found for a label
+     */
+    DataReader(String dataDirPath, List<String> labelNames, GlobalVectors embedder,
+               int batchSize, int maxSeqLength){
+        if (batchSize < 1 || maxSeqLength < 1) {
+            throw new IllegalArgumentException("batchSize and maxSeqLength must be positive, but were "
+                    + batchSize + " and " + maxSeqLength);
+        }
+        this.batchSize = batchSize;
+        this.embedder = embedder;
+        this.maxSeqLen = maxSeqLength;
+        this.vectorLen = embedder.getVectorSize();
+        this.numLabels = labelNames.size();
+        this.dataDir = new File(dataDirPath);
+        this.labelNames = Collections.unmodifiableList(new ArrayList<>(labelNames));
+        Map<String, Integer> ids = new HashMap<>();
+        for (int i = 0; i < labelNames.size(); i++) {
+            if (ids.put(labelNames.get(i), i) != null) {
+                // a repeated name would leave one of the label ids without examples
+                throw new IllegalArgumentException("Duplicate label name: " + labelNames.get(i));
+            }
+        }
+        this.labelToId = Collections.unmodifiableMap(ids);
+        this.scanDir();
+        this.reset();
+    }
+
+    /**
+     * Collects the example files of every label sub directory into {@code records}, and the
+     * corresponding label ids into {@code labels} (same positions).
+     */
+    private void scanDir(){
+        if (!dataDir.isDirectory()) {
+            throw new IllegalArgumentException("Data directory not found: " + dataDir);
+        }
+        List<Integer> labels = new ArrayList<>();
+        List<File> files = new ArrayList<>();
+        for (String labelName: this.labelNames) {
+            int labelId = this.labelToId.get(labelName);
+            File labelDir = new File(dataDir, labelName);
+            if (!labelDir.exists()){
+                throw new IllegalStateException("No examples found for "
+                        + labelName + ". Looked at:" + labelDir);
+            }
+            File[] examples = labelDir.listFiles(f ->
+                    f.isFile() && f.getName().endsWith(this.extension));
+            if (examples == null || examples.length == 0){
+                throw new IllegalStateException("No examples found for "
+                        + labelName + ". Looked at:" + labelDir
+                        + " for files having extension: " + extension);
+            }
+            LOG.info("Found {} examples for label {}", examples.length, labelName);
+            for (File example: examples) {
+                files.add(example);
+                labels.add(labelId);
+            }
+        }
+        this.records = files;
+        this.labels = labels;
+    }
+
+    /**
+     * sets tokenizer for converting text to tokens
+     * @param tokenizer tokenizer to use for converting text to tokens
+     * @throws NullPointerException if {@code tokenizer} is null
+     */
+    public void setTokenizer(Function<String, String[]> tokenizer) {
+        this.tokenizer = Objects.requireNonNull(tokenizer, "tokenizer");
+    }
+
+    /**
+     * @return Tokenizer function used for converting text into words
+     */
+    public Function<String, String[]> getTokenizer() {
+        return tokenizer;
+    }
+
+    /**
+     * Reads the next mini batch of at most {@code batchSize} examples.
+     * <p>
+     * Features have shape {@code [batch, vectorLen, maxSeqLen]} and labels
+     * {@code [batch, numLabels, maxSeqLen]}. Because texts have different lengths and there is only
+     * one output at the final time step, mask arrays are returned too: they contain 1 where data is
+     * present for an example and time step, and 0 where it is just padding. A text without any known
+     * word is represented by a single time step without a word vector, so that it keeps its label.
+     *
+     * @param batchSize maximum number of examples to read
+     * @return the mini batch
+     * @throws NoSuchElementException if all examples have been read; see {@link #reset()}
+     * @throws UncheckedIOException if an example file cannot be read
+     */
+    @Override
+    public DataSet next(int batchSize) {
+        if (!hasNext()) {
+            throw new NoSuchElementException("No more examples; call reset() to start over");
+        }
+        batchSize = Math.min(batchSize, records.size() - cursor);
+        INDArray features = Nd4j.create(batchSize, vectorLen, maxSeqLen);
+        INDArray labels = Nd4j.create(batchSize, numLabels, maxSeqLen);
+        INDArray featuresMask = Nd4j.zeros(batchSize, maxSeqLen);
+        INDArray labelsMask = Nd4j.zeros(batchSize, maxSeqLen);
+
+        // Index arrays are reused across examples and time steps to avoid allocations
+        int[] _2dIndex = new int[2];
+        int[] _3dIndex = new int[3];
+        INDArrayIndex[] _3dNdIndex = new INDArrayIndex[]{null, NDArrayIndex.all(), null};
+
+        for (int i = 0; i < batchSize && cursor < records.size(); i++, cursor++) {
+            _2dIndex[0] = i;
+            _3dIndex[0] = i;
+            _3dNdIndex[0] = NDArrayIndex.point(i);
+
+            File file = records.get(cursor);
+            int labelIdx = this.labels.get(cursor);
+            String text;
+            try {
+                text = FileUtils.readFileToString(file, StandardCharsets.UTF_8);
+            } catch (IOException e) {
+                throw new UncheckedIOException("Failed to read " + file, e);
+            }
+            // Tokenize, and keep only the words that have an embedding
+            String[] tokens = Arrays.stream(tokenizer.apply(text))
+                    .filter(embedder::hasWord)
+                    .toArray(String[]::new);
+            int seqLen = Math.min(tokens.length, maxSeqLen);
+            for (int j = 0; j < seqLen; j++) {
+                _3dNdIndex[2] = NDArrayIndex.point(j);
+                features.put(_3dNdIndex, embedder.toVector(tokens[j]));
+                // Word is present (not padding) for this example + time step -> 1.0 in features mask
+                _2dIndex[1] = j;
+                featuresMask.putScalar(_2dIndex, 1.0);
+            }
+            if (seqLen == 0) {
+                // No known word: use one empty time step rather than writing the label at index -1
+                LOG.warn("No known words in {}; using a single empty time step", file);
+                _2dIndex[1] = 0;
+                featuresMask.putScalar(_2dIndex, 1.0);
+            }
+            int lastIdx = Math.max(seqLen, 1) - 1;
+            _2dIndex[1] = lastIdx;
+            _3dIndex[1] = labelIdx;
+            _3dIndex[2] = lastIdx;
+            labels.putScalar(_3dIndex, 1.0);   // Set label: one of k encoding
+            // Specify that an output exists at the final time step for this example
+            labelsMask.putScalar(_2dIndex, 1.0);
+        }
+        return new DataSet(features, labels, featuresMask, labelsMask);
+    }
+
+    @Override
+    public int totalExamples() {
+        return this.records.size();
+    }
+
+    @Override
+    public int inputColumns() {
+        return this.embedder.getVectorSize();
+    }
+
+    @Override
+    public int totalOutcomes() {
+        return this.numLabels;
+    }
+
+    @Override
+    public boolean resetSupported() {
+        return true;
+    }
+
+    @Override
+    public boolean asyncSupported() {
+        return false;
+    }
+
+    /**
+     * Shuffles the examples (keeping each file paired with its label) and rewinds to the first one.
+     */
+    @Override
+    public void reset() {
+        assert this.records.size() == this.labels.size();
+        long seed = System.nanoTime(); // shuffle both the lists in the same order
+        Collections.shuffle(this.records, new Random(seed));
+        Collections.shuffle(this.labels, new Random(seed));
+        this.cursor = 0; // from beginning
+    }
+
+    @Override
+    public int batch() {
+        return this.batchSize;
+    }
+
+    @Override
+    public int cursor() {
+        return this.cursor;
+    }
+
+    @Override
+    public int numExamples() {
+        return totalExamples();
+    }
+
+    @Override
+    public void setPreProcessor(DataSetPreProcessor preProcessor) {
+        throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public DataSetPreProcessor getPreProcessor() {
+        throw new UnsupportedOperationException();
+    }
+
+    /**
+     * @return label names ordered by label id, i.e. in the order given to the constructor
+     */
+    @Override
+    public List<String> getLabels() {
+        return new ArrayList<>(this.labelNames);
+    }
+
+    /**
+     * @return {@code true} while there are examples left to read since the last {@link #reset()}
+     */
+    @Override
+    public boolean hasNext() {
+        return cursor < totalExamples();
+    }
+
+    @Override
+    public DataSet next() {
+        return next(this.batchSize);
+    }
+}
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/GlobalVectors.java b/opennlp-dl/src/main/java/opennlp/tools/dl/GlobalVectors.java
new file mode 100644
index 0000000..fdf3a95
--- /dev/null
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/GlobalVectors.java
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package opennlp.tools.dl;
+
+import org.apache.commons.io.IOUtils;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.DataSet;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.INDArrayIndex;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.*;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * GlobalVectors (Glove) for projecting words to vector space.
+ * This tool utilizes word vectors  pre-trained on large datasets.
+ *
+ * Visit https://nlp.stanford.edu/projects/glove/ for full documentation of Gloves.
+ *
+ * <h2>Usage</h2>
+ * <pre>
+ * path = "work/datasets/glove.6B/glove.6B.100d.txt";
+ * vocabSize = 20000; # max number of words to use
+ * GlobalVectors glove;
+ * try (InputStream stream = new FileInputStream(path)) {
+ *    glove = new GlobalVectors(stream, vocabSize);
+ * }
+ * </pre>
+ *
+ * @author Thamme Gowda (thammegowda@apache.org)
+ *
+ */
+public class GlobalVectors {
+
+    private static final Logger LOG = LoggerFactory.getLogger(GlobalVectors.class);
+
+    private final INDArray embeddings;
+    private final Map<String, Integer> wordToId;
+    private final List<String> idToWord;
+    private final int vectorSize;
+    private final int maxWords;
+
+    /**
+     * Reads Global Vectors from stream
+     * @param stream Glove word vectors stream (plain text)
+     * @throws IOException
+     */
+    public GlobalVectors(InputStream stream) throws IOException {
+        this(stream, Integer.MAX_VALUE);
+    }
+
+    /**
+     *
+     * @param stream vector stream
+     * @param maxWords maximum number of words to use, i.e. vocabulary size
+     * @throws IOException
+     */
+    public GlobalVectors(InputStream stream, int maxWords) throws IOException {
+        List<String> words = new ArrayList<>();
+        List<INDArray> vectors = new ArrayList<>();
+        int vectorSize = -1;
+        try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream))){
+            String line;
+            while ((line = reader.readLine()) != null) {
+                String[] parts = line.split(" ");
+                if (vectorSize == -1) {
+                    vectorSize = parts.length - 1;
+                } else {
+                    assert vectorSize == parts.length - 1;
+                }
+                float[] vector = new float[vectorSize];
+                for (int i = 1; i < parts.length; i++) {
+                    vector[i-1] = Float.parseFloat(parts[i]);
+                }
+                vectors.add(Nd4j.create(vector));
+                words.add(parts[0]);
+                if (words.size() >= maxWords) {
+                    LOG.info("Max words reached at {}, aborting", words.size());
+                    break;
+                }
+            }
+            LOG.info("Found {} words; Vector dimensions={}", words.size(), vectorSize);
+            this.vectorSize = vectorSize;
+            this.maxWords = Math.min(words.size(), maxWords);
+            this.embeddings = Nd4j.create(vectors, new int[]{vectors.size(), vectorSize});
+            this.idToWord = words;
+            this.wordToId = new HashMap<>();
+            for (int i = 0; i < words.size(); i++) {
+                wordToId.put(words.get(i), i);
+            }
+        }
+    }
+
+    /**
+     * @return size or dimensions of vectors
+     */
+    public int getVectorSize() {
+        return vectorSize;
+    }
+
+    public int getMaxWords() {
+        return maxWords;
+    }
+
+    /**
+     *
+     * @param word
+     * @return {@code true} if word is known; false otherwise
+     */
+    public boolean hasWord(String word){
+        return wordToId.containsKey(word);
+    }
+
+    /**
+     * Converts word to vectors
+     * @param word word to be converted to vector
+     * @return Vector if words exists or null otherwise
+     */
+    public INDArray toVector(String word){
+        if (wordToId.containsKey(word)){
+            return embeddings.getRow(wordToId.get(word));
+        }
+        return null;
+    }
+
+    public INDArray embed(String text, int maxLen){
+        return embed(text.toLowerCase().split(" "), maxLen);
+    }
+
+    public INDArray embed(String[] tokens, int maxLen){
+        List<String> tokensFiltered = new ArrayList<>();
+        for(String t: tokens ){
+            if(hasWord(t)){
+                tokensFiltered.add(t);
+            }
+        }
+        int seqLen = Math.min(maxLen, tokensFiltered.size());
+
+        INDArray features = Nd4j.create(1, vectorSize, seqLen);
+
+        for( int j = 0; j < seqLen; j++ ){
+            String token = tokensFiltered.get(j);
+            INDArray vector = toVector(token);
+            features.put(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(j)}, vector);
+        }
+        return features;
+    }
+
+    public void writeOut(OutputStream stream, boolean closeStream) throws IOException {
+        writeOut(stream, "%.5f", closeStream);
+    }
+
+    public void writeOut(OutputStream stream,
+                         String floatPrecisionFormatString, boolean closeStream) throws IOException {
+        if (!Character.isWhitespace(floatPrecisionFormatString.charAt(0))) {
+            floatPrecisionFormatString = " " + floatPrecisionFormatString;
+        }
+        LOG.info("Writing {} vectors out, float precision {}", idToWord.size(), floatPrecisionFormatString);
+
+        PrintWriter out = new PrintWriter(stream);
+        try {
+            for (int i = 0; i < idToWord.size(); i++) {
+                out.printf("%s", idToWord.get(i));
+                INDArray row = embeddings.getRow(i);
+                for (int j = 0; j < vectorSize; j++) {
+                    out.printf(floatPrecisionFormatString, row.getDouble(j));
+                }
+                out.println();
+            }
+        } finally {
+            if (closeStream){
+                IOUtils.closeQuietly(out);
+            } // else dont close because, closing the print writer also closes the inner stream
+        }
+    }
+}
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCat.java b/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCat.java
new file mode 100644
index 0000000..53bf530
--- /dev/null
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCat.java
@@ -0,0 +1,239 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package opennlp.tools.dl;
+
+import opennlp.tools.doccat.DocumentCategorizer;
+import opennlp.tools.tokenize.Tokenizer;
+import opennlp.tools.tokenize.WhitespaceTokenizer;
+import org.apache.commons.io.FileUtils;
+import org.kohsuke.args4j.CmdLineException;
+import org.kohsuke.args4j.CmdLineParser;
+import org.kohsuke.args4j.Option;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.*;
+
+/**
+ * An implementation of {@link DocumentCategorizer} using Neural Networks.
+ * This class provides prediction functionality from the model of {@link NeuralDocCatTrainer}.
+ * <p>
+ * Only the methods that accept tokenized text are supported. Tokens are looked up in the model's
+ * word vectors exactly as given, so they should be normalized (e.g. lower cased) like the training data.
+ */
+public class NeuralDocCat implements DocumentCategorizer {
+
+    private static final Logger LOG = LoggerFactory.getLogger(NeuralDocCat.class);
+
+    private final NeuralDocCatModel model;
+
+    /**
+     * @param model trained model used for the predictions
+     * @throws NullPointerException if {@code model} is null
+     */
+    public NeuralDocCat(NeuralDocCatModel model) {
+        this.model = Objects.requireNonNull(model, "model");
+    }
+
+    @Override
+    public double[] categorize(String[] tokens) {
+        return categorize(tokens, Collections.emptyMap());
+    }
+
+    /**
+     * Scores every category for the given tokens.
+     *
+     * @param text tokens of the document
+     * @param extraInformation ignored
+     * @return the network output at the last time step, one score per category, indexed like
+     *         {@link NeuralDocCatModel#getLabels()}
+     */
+    @Override
+    public double[] categorize(String[] text, Map<String, Object> extraInformation) {
+        INDArray seqFeatures = this.model.getGloves().embed(text, this.model.getMaxSeqLen());
+
+        INDArray networkOutput = this.model.getNetwork().output(seqFeatures);
+        int timeSeriesLength = networkOutput.size(2);
+        INDArray probsAtLastWord = networkOutput.get(NDArrayIndex.point(0),
+                NDArrayIndex.all(), NDArrayIndex.point(timeSeriesLength - 1));
+
+        int nLabels = this.model.getLabels().size();
+        double[] probs = new double[nLabels];
+        for (int i = 0; i < nLabels; i++) {
+            probs[i] = probsAtLastWord.getDouble(i);
+        }
+        return probs;
+    }
+
+    /**
+     * @param outcome scores as returned by {@link #categorize(String[])}
+     * @return the category with the highest score (the first one on ties)
+     */
+    @Override
+    public String getBestCategory(double[] outcome) {
+        int maxIdx = 0;
+        double maxProb = outcome[0];
+        for (int i = 1; i < outcome.length; i++) {
+            if (outcome[i] > maxProb) {
+                maxIdx = i;
+                maxProb = outcome[i];
+            }
+        }
+        return model.getLabels().get(maxIdx);
+    }
+
+    @Override
+    public int getIndex(String category) {
+        return model.getLabels().indexOf(category);
+    }
+
+    @Override
+    public String getCategory(int index) {
+        return model.getLabels().get(index);
+    }
+
+    @Override
+    public int getNumberOfCategories() {
+        return model.getLabels().size();
+    }
+
+    /**
+     * Formats the scores of all categories as {@code category[score]} entries separated by two
+     * spaces, with the scores rounded to four decimal places.
+     *
+     * @param results scores as returned by {@link #categorize(String[])}
+     * @return a human readable summary of all scores
+     * @throws IllegalArgumentException if the number of scores differs from the number of categories
+     */
+    @Override
+    public String getAllResults(double[] results) {
+        List<String> labels = model.getLabels();
+        if (results.length != labels.size()) {
+            throw new IllegalArgumentException("Expected " + labels.size() + " scores, but got "
+                    + results.length);
+        }
+        StringBuilder sb = new StringBuilder();
+        for (int i = 0; i < results.length; i++) {
+            if (i > 0) {
+                sb.append("  ");
+            }
+            sb.append(labels.get(i)).append('[')
+                    .append(String.format(Locale.ROOT, "%.4f", results[i])).append(']');
+        }
+        return sb.toString();
+    }
+
+    /**
+     * @param text tokens of the document
+     * @return map from each category to its score
+     */
+    @Override
+    public Map<String, Double> scoreMap(String[] text) {
+        double[] scores = categorize(text);
+        Map<String, Double> result = new HashMap<>();
+        for (int i = 0; i < scores.length; i++) {
+            result.put(model.getLabels().get(i), scores[i]);
+        }
+        return result;
+    }
+
+    /**
+     * @param text tokens of the document
+     * @return map from score to the categories having that score, in ascending order of score
+     */
+    @Override
+    public SortedMap<Double, Set<String>> sortedScoreMap(String[] text) {
+        double[] scores = categorize(text);
+        SortedMap<Double, Set<String>> result = new TreeMap<>();
+        for (int i = 0; i < scores.length; i++) {
+            result.computeIfAbsent(scores[i], k -> new HashSet<>()).add(model.getLabels().get(i));
+        }
+        return result;
+    }
+
+    @Override
+    @Deprecated
+    public double[] categorize(String documentText) {
+        throw new UnsupportedOperationException("Use the other categorize(..) method that accepts tokenized text");
+    }
+
+    @Override
+    @Deprecated
+    public Map<String, Double> scoreMap(String text) {
+        throw new UnsupportedOperationException("Use the other scoreMap(..) method that accepts tokenized text");
+    }
+
+    @Override
+    @Deprecated
+    public SortedMap<Double, Set<String>> sortedScoreMap(String text) {
+        throw new UnsupportedOperationException("Use the other sortedScoreMap(..) method that accepts tokenized text");
+    }
+
+    @Override
+    @Deprecated
+    public double[] categorize(String documentText, Map<String, Object> extraInformation) {
+        throw new UnsupportedOperationException("Use the other categorize(..) method that accepts tokenized text");
+    }
+
+    /**
+     * Prints the category scores of each given document.
+     * Usage: {@code -model <model zip> -files <document>...}
+     */
+    public static void main(String[] argss) throws CmdLineException, IOException {
+        class Args {
+
+            @Option(name = "-model", required = true, usage = "Path to NeuralDocCatModel stored file")
+            String modelPath;
+
+            @Option(name = "-files", required = true, usage = "One or more document paths whose category is " +
+                    "to be predicted by the model")
+            List<File> files;
+        }
+
+        Args args = new Args();
+        CmdLineParser parser = new CmdLineParser(args);
+        try {
+            parser.parseArgument(argss);
+        } catch (CmdLineException e) {
+            System.out.println(e.getMessage());
+            e.getParser().printUsage(System.out);
+            System.exit(1);
+        }
+
+        NeuralDocCatModel model = NeuralDocCatModel.loadModel(args.modelPath);
+        NeuralDocCat classifier = new NeuralDocCat(model);
+
+        System.out.println("Labels:" + model.getLabels());
+        Tokenizer tokenizer = WhitespaceTokenizer.INSTANCE;
+
+        for (File file: args.files) {
+            String text = FileUtils.readFileToString(file, StandardCharsets.UTF_8);
+            // lower case, like the default tokenizer of DataReader used for training
+            String[] tokens = tokenizer.tokenize(text.toLowerCase());
+            double[] probs = classifier.categorize(tokens);
+            System.out.println(">>" + file);
+            System.out.println("Probabilities:" + Arrays.toString(probs));
+        }
+    }
+}
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCatModel.java b/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCatModel.java
new file mode 100644
index 0000000..f1b6247
--- /dev/null
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCatModel.java
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package opennlp.tools.dl;
+
+import org.apache.commons.io.IOUtils;
+import org.apache.commons.lang3.StringUtils;
+import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.*;
+import java.nio.charset.StandardCharsets;
+import java.time.Instant;
+import java.util.*;
+import java.util.zip.ZipEntry;
+import java.util.zip.ZipInputStream;
+import java.util.zip.ZipOutputStream;
+
+/**
+ * This class is a wrapper for DL4J's {@link MultiLayerNetwork}, and {@link GlobalVectors}
+ * that provides features to serialize and deserialize necessary data to a zip file.
+ *
+ * This can be used by a Neural Trainer tool to serialize the network and a predictor tool to restore the same network
+ * with the weights.
+ *
+ * The zip file holds the entries {@value #MANIFEST} (labels and max sequence length as properties),
+ * {@value #NETWORK} (network configuration), {@value #WEIGHTS} (network parameters) and
+ * {@value #GLOVES} (word vectors).
+ *
+ * @author Thamme Gowda (thammegowda@apache.org)
+ */
+public class NeuralDocCatModel {
+
+    public static final int VERSION = 1;
+    public static final String MODEL_NAME = NeuralDocCatModel.class.getName();
+    public static final String MANIFEST = "model.mf";
+    public static final String NETWORK = "network.json";
+    public static final String WEIGHTS = "weights.bin";
+    public static final String GLOVES = "gloves.tsv";
+    public static final String LABELS = "labels";
+    public static final String MAX_SEQ_LEN = "maxSeqLen";
+
+    private static final Logger LOG = LoggerFactory.getLogger(NeuralDocCatModel.class);
+
+    private final MultiLayerNetwork network;
+    private final GlobalVectors gloves;
+    private final Properties manifest;
+    private final List<String> labels;
+    private final int maxSeqLen;
+
+    /**
+     * Restores a model from a zip stream, as written by {@link #saveModel(OutputStream)}.
+     *
+     * @param stream Input stream of a Zip File
+     * @throws IOException if the stream cannot be read, or a required entry or property is missing
+     */
+    public NeuralDocCatModel(InputStream stream) throws IOException {
+        ZipInputStream zipIn = new ZipInputStream(stream);
+
+        Properties manifest = null;
+        MultiLayerNetwork model = null;
+        INDArray params = null;
+        GlobalVectors gloves = null;
+        ZipEntry entry;
+        while ((entry = zipIn.getNextEntry()) != null) {
+            String name = entry.getName();
+            switch (name) {
+                case MANIFEST:
+                    manifest = new Properties();
+                    manifest.load(zipIn);
+                    break;
+                case NETWORK:
+                    String json = IOUtils.toString(new UnclosableInputStream(zipIn), StandardCharsets.UTF_8);
+                    model = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(json));
+                    break;
+                case WEIGHTS:
+                    params = Nd4j.read(new DataInputStream(new UnclosableInputStream(zipIn)));
+                    break;
+                case GLOVES:
+                    gloves = new GlobalVectors(new UnclosableInputStream(zipIn));
+                    break;
+                default:
+                    LOG.warn("Unexpected entry in the zip : {}", name);
+            }
+        }
+
+        // Fail fast on incomplete archives instead of relying on (usually disabled) assertions;
+        // in particular, init(null, ..) would not restore the trained weights
+        requireEntry(manifest, MANIFEST);
+        requireEntry(model, NETWORK);
+        requireEntry(params, WEIGHTS);
+        requireEntry(gloves, GLOVES);
+        String labelsProperty = manifest.getProperty(LABELS);
+        String maxSeqLenProperty = manifest.getProperty(MAX_SEQ_LEN);
+        if (labelsProperty == null || maxSeqLenProperty == null) {
+            throw new IOException("Invalid model: " + MANIFEST + " must define " + LABELS
+                    + " and " + MAX_SEQ_LEN);
+        }
+
+        model.init(params, false);
+        this.network = model;
+        this.manifest = manifest;
+        this.gloves = gloves;
+        this.labels = Collections.unmodifiableList(Arrays.asList(labelsProperty.split(",")));
+        this.maxSeqLen = Integer.parseInt(maxSeqLenProperty);
+    }
+
+    /**
+     * @throws IOException naming the missing zip entry when {@code value} is null
+     */
+    private static void requireEntry(Object value, String entryName) throws IOException {
+        if (value == null) {
+            throw new IOException("Invalid model: zip entry '" + entryName + "' is missing");
+        }
+    }
+
+    /**
+     * Creates a model from trained components.
+     * @param network any compatible multi layer neural network
+     * @param vectors Global vectors
+     * @param labels list of labels; copied, and must not contain commas (they separate the labels
+     *               in the saved manifest)
+     * @param maxSeqLen max sequence length
+     * @throws IllegalArgumentException if a label contains a comma
+     */
+    public NeuralDocCatModel(MultiLayerNetwork network, GlobalVectors vectors, List<String> labels, int maxSeqLen) {
+        for (String label : labels) {
+            if (label.contains(",")) {
+                throw new IllegalArgumentException("Label must not contain ',': " + label);
+            }
+        }
+        this.network = network;
+        this.gloves = vectors;
+        this.labels = Collections.unmodifiableList(new ArrayList<>(labels));
+        this.maxSeqLen = maxSeqLen;
+        this.manifest = new Properties();
+        this.manifest.setProperty(LABELS, StringUtils.join(this.labels, ","));
+        this.manifest.setProperty(MAX_SEQ_LEN, Integer.toString(maxSeqLen));
+    }
+
+    public MultiLayerNetwork getNetwork() {
+        return network;
+    }
+
+    public GlobalVectors getGloves() {
+        return gloves;
+    }
+
+    /**
+     * @return unmodifiable list of labels; the index of a label is its category index
+     */
+    public List<String> getLabels() {
+        return labels;
+    }
+
+    public int getMaxSeqLen() {
+        return this.maxSeqLen;
+    }
+
+    /**
+     * Zips the current state of the model and writes it to the stream, which is closed afterwards.
+     * @param stream stream to write
+     * @throws IOException if writing fails
+     */
+    public void saveModel(OutputStream stream) throws IOException {
+        try (ZipOutputStream zipOut = new ZipOutputStream(new BufferedOutputStream(stream))) {
+            // Write out manifest
+            zipOut.putNextEntry(new ZipEntry(MANIFEST));
+
+            String comments = "Created-By:" + System.getProperty("user.name") + " at " + Instant.now()
+                    + "\nModel-Version: " + VERSION
+                    + "\nModel-Schema:" + MODEL_NAME;
+
+            manifest.store(zipOut, comments);
+            zipOut.closeEntry();
+
+            // Write out the network
+            zipOut.putNextEntry(new ZipEntry(NETWORK));
+            byte[] jModel = network.getLayerWiseConfigurations().toJson().getBytes(StandardCharsets.UTF_8);
+            zipOut.write(jModel);
+            zipOut.closeEntry();
+
+            //Write out the network coefficients
+            zipOut.putNextEntry(new ZipEntry(WEIGHTS));
+            Nd4j.write(network.params(), new DataOutputStream(zipOut));
+            zipOut.closeEntry();
+
+            // Write out vectors. Let GlobalVectors close (and thereby flush) its writer, through a
+            // wrapper that keeps the zip stream open
+            zipOut.putNextEntry(new ZipEntry(GLOVES));
+            gloves.writeOut(nonClosing(zipOut), true);
+            zipOut.closeEntry();
+
+            zipOut.finish();
+        }
+    }
+
+    /**
+     * Wraps {@code target} so that closing the wrapper only flushes, leaving {@code target} open.
+     */
+    private static OutputStream nonClosing(OutputStream target) {
+        return new FilterOutputStream(target) {
+            @Override
+            public void write(byte[] b, int off, int len) throws IOException {
+                out.write(b, off, len); // bulk write; FilterOutputStream would write byte by byte
+            }
+
+            @Override
+            public void close() throws IOException {
+                flush();
+            }
+        };
+    }
+
+    /**
+     * creates a model from file on the local file system
+     * @param modelPath path to model file
+     * @return an instance of this class
+     * @throws IOException if the file cannot be read or is not a valid model
+     */
+    public static NeuralDocCatModel loadModel(String modelPath) throws IOException {
+        try (InputStream modelStream = new BufferedInputStream(new FileInputStream(modelPath))) {
+            return new NeuralDocCatModel(modelStream);
+        }
+    }
+}
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCatTrainer.java b/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCatTrainer.java
new file mode 100644
index 0000000..4099b65
--- /dev/null
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCatTrainer.java
@@ -0,0 +1,308 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package opennlp.tools.dl;
+
+import org.deeplearning4j.eval.Evaluation;
+import org.deeplearning4j.nn.conf.GradientNormalization;
+import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.Updater;
+import org.deeplearning4j.nn.conf.layers.GravesLSTM;
+import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
+import org.kohsuke.args4j.CmdLineException;
+import org.kohsuke.args4j.CmdLineParser;
+import org.kohsuke.args4j.Option;
+import org.kohsuke.args4j.spi.StringArrayOptionHandler;
+import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.DataSet;
+import org.nd4j.linalg.lossfunctions.LossFunctions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * This class provides functionality to construct and train neural networks that can be used for
+ * {@link opennlp.tools.doccat.DocumentCategorizer}
+ *
+ * @see NeuralDocCat
+ * @see NeuralDocCatModel
+ * @author Thamme Gowda (thammegowda@apache.org)
+ */
+public class NeuralDocCatTrainer {
+
+    /**
+     * Training settings, populated from the command line by args4j (see {@link NeuralDocCatTrainer#main(String[])}).
+     */
+    public static class Args {
+
+        @Option(name = "-batchSize", usage = "Number of examples in minibatch")
+        int batchSize = 128;
+
+        @Option(name = "-nEpochs", usage = "Number of epochs (i.e. full passes over the training data) to train on." +
+                " Applicable for training only.")
+        int nEpochs = 2;
+
+        @Option(name = "-maxSeqLen", usage = "Max Sequence Length. Sequences longer than this will be truncated")
+        int maxSeqLen = 256;    //Truncate text with length (# words) greater than this
+
+        @Option(name = "-vocabSize", usage = "Vocabulary Size.")
+        int vocabSize = 20000;   //vocabulary size
+
+        @Option(name = "-nRNNUnits", usage = "Number of RNN cells to use.")
+        int nRNNUnits = 128;
+
+        @Option(name = "-lr", aliases = "-learnRate", usage = "Learning Rate." +
+                " Adjust it when the scores bounce to NaN or Infinity.")
+        double learningRate = 2e-3;
+
+        @Option(name = "-glovesPath", required = true, usage = "Path to GloVe vectors file." +
+                " Download and unzip from https://nlp.stanford.edu/projects/glove/")
+        String glovesPath = null;
+
+        @Option(name = "-modelPath", required = true, usage = "Path to model file. " +
+                "This will be used for serializing the model after the training phase." )
+        String modelPath = null;
+
+        @Option(name = "-trainDir", required = true, usage = "Path to train data directory." +
+                " Setting this value will take the system to training mode. ")
+        String trainDir = null;
+
+        @Option(name = "-validDir", usage = "Path to validation data directory. Optional.")
+        String validDir = null;
+
+        @Option(name = "-labels", required = true, handler = StringArrayOptionHandler.class,
+                usage = "Names of targets or labels separated by spaces. " +
+                        "The order of labels matters. Make sure to use the same sequence for training and predicting. " +
+                        "Also, these names should match subdirectory names of -trainDir and -validDir when those are " +
+                        "applicable. \n Example -labels pos neg")
+        List<String> labels = null;
+
+        @Override
+        public String toString() {
+            return "Args{" +
+                    "batchSize=" + batchSize +
+                    ", nEpochs=" + nEpochs +
+                    ", maxSeqLen=" + maxSeqLen +
+                    ", vocabSize=" + vocabSize +
+                    ", learningRate=" + learningRate +
+                    ", nRNNUnits=" + nRNNUnits +
+                    ", glovesPath='" + glovesPath + '\'' +
+                    ", modelPath='" + modelPath + '\'' +
+                    ", trainDir='" + trainDir + '\'' +
+                    ", validDir='" + validDir + '\'' +
+                    ", labels=" + labels +
+                    '}';
+        }
+    }
+
+    private static final Logger LOG = LoggerFactory.getLogger(NeuralDocCatTrainer.class);
+
+    private final NeuralDocCatModel model;
+    private final Args args;
+    private final DataReader trainSet;
+    private DataReader validSet;
+
+    /**
+     * Loads the GloVe vectors, prepares the training (and, if configured, validation) data readers
+     * and builds an untrained network wrapped in a {@link NeuralDocCatModel}.
+     *
+     * @param args trainer settings
+     * @throws IOException if reading the GloVe vectors or preparing the data readers fails
+     */
+    public NeuralDocCatTrainer(Args args) throws IOException {
+        this.args = args;
+        GlobalVectors gloves;
+        try (InputStream stream = new FileInputStream(args.glovesPath)) {
+            gloves = new GlobalVectors(stream, args.vocabSize);
+        }
+
+        LOG.info("Training data from {}", args.trainDir);
+        this.trainSet = new DataReader(args.trainDir, args.labels, gloves, args.batchSize, args.maxSeqLen);
+        if (args.validDir != null) {
+            LOG.info("Validation data from {}", args.validDir);
+            this.validSet = new DataReader(args.validDir, args.labels, gloves, args.batchSize, args.maxSeqLen);
+        }
+
+        MultiLayerNetwork network = this.createNetwork(gloves.getVectorSize());
+        this.model = new NeuralDocCatModel(network, gloves, args.labels, args.maxSeqLen);
+    }
+
+    /**
+     * Builds and initializes a network with one GravesLSTM layer followed by a softmax RnnOutputLayer.
+     *
+     * @param vectorSize dimensionality of the input word vectors
+     * @return an initialized network that logs its score after every iteration
+     * @throws IllegalStateException if the training data has fewer than two outcomes
+     */
+    public MultiLayerNetwork createNetwork(int vectorSize) {
+        int totalOutcomes = this.trainSet.totalOutcomes();
+        // Checked explicitly: a plain assert is skipped unless the JVM runs with -ea.
+        if (totalOutcomes < 2) {
+            throw new IllegalStateException("At least 2 outcomes are required, but found " + totalOutcomes);
+        }
+        LOG.info("Number of classes {}", totalOutcomes);
+
+        //TODO: the below network params should be configurable from CLI or settings file
+        //Set up network configuration
+        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                .updater(Updater.RMSPROP) // ADAM .adamMeanDecay(0.9).adamVarDecay(0.999)
+                .rmsDecay(0.9)
+                .regularization(true).l2(1e-5)
+                .weightInit(WeightInit.XAVIER)
+                .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
+                .gradientNormalizationThreshold(1.0)
+                .learningRate(args.learningRate)
+                .list()
+                .layer(0, new GravesLSTM.Builder()
+                        .nIn(vectorSize)
+                        .nOut(args.nRNNUnits)
+                        .activation(Activation.RELU).build())
+                .layer(1, new RnnOutputLayer.Builder()
+                        .nIn(args.nRNNUnits)
+                        .nOut(totalOutcomes)
+                        .activation(Activation.SOFTMAX)
+                        .lossFunction(LossFunctions.LossFunction.MCXENT)
+                        .build())
+                .pretrain(false)
+                .backprop(true)
+                .build();
+
+        MultiLayerNetwork net = new MultiLayerNetwork(conf);
+        net.init();
+        net.setListeners(new ScoreIterationListener(1));
+        return net;
+    }
+
+    /**
+     * Trains the model for the configured number of epochs on the training data, evaluating on
+     * the validation data after every epoch when a validation directory was given.
+     */
+    public void train() {
+        train(args.nEpochs, this.trainSet, this.validSet);
+    }
+
+    /**
+     * Trains model
+     *
+     * @param nEpochs    number of epochs (i.e. iterations over the training dataset)
+     * @param train      training data set; must not be null
+     * @param validation validation data set for evaluation after each epoch.
+     *                   Setting this to null will skip the evaluation
+     * @throws NullPointerException if {@code train} is null
+     */
+    public void train(int nEpochs, DataReader train, DataReader validation) {
+        Objects.requireNonNull(train, "train");
+        LOG.info("Starting training...\nTotal epochs={}, Training Size={}, Validation Size={}", nEpochs,
+                train.totalExamples(), validation == null ? null : validation.totalExamples());
+        for (int i = 0; i < nEpochs; i++) {
+            model.getNetwork().fit(train);
+            train.reset();
+            LOG.info("Epoch {} complete", i);
+
+            if (validation != null) {
+                LOG.info("Starting evaluation");
+                LOG.info(evaluate(validation).stats());
+            }
+        }
+    }
+
+    /**
+     * Runs the current network over every batch of {@code data} and collects time-series
+     * classification statistics, then resets {@code data} so it can be iterated again.
+     * This makes a full pass over the data set, so it can take a while on large sets.
+     */
+    private Evaluation evaluate(DataReader data) {
+        Evaluation evaluation = new Evaluation();
+        while (data.hasNext()) {
+            DataSet batch = data.next();
+            INDArray features = batch.getFeatureMatrix();
+            INDArray labels = batch.getLabels();
+            INDArray inMask = batch.getFeaturesMaskArray();
+            INDArray outMask = batch.getLabelsMaskArray();
+            INDArray predicted = this.model.getNetwork().output(features, false, inMask, outMask);
+            evaluation.evalTimeSeries(labels, predicted, outMask);
+        }
+        data.reset();
+        return evaluation;
+    }
+
+    /**
+     * Saves the model to specified path
+     *
+     * @param path model path
+     * @throws IOException if the file cannot be created or written
+     */
+    public void saveModel(String path) throws IOException {
+        LOG.info("Saving the model at {}", path);
+        try (OutputStream stream = new FileOutputStream(path)) {
+            model.saveModel(stream);
+        }
+    }
+
+    /**
+     * Command line entry point: parses {@link Args}, trains a model and saves it to {@code -modelPath}.
+     * Example usage:
+     * <pre>
+     *   # Download pre-trained GloVe vectors (this is a large file)
+     *   wget http://nlp.stanford.edu/data/glove.6B.zip
+     *   unzip glove.6B.zip -d glove.6B
+     *
+     *   # Download dataset
+     *   wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
+     *   tar xzf aclImdb_v1.tar.gz
+     *
+     *  mvn compile exec:java
+     *    -Dexec.mainClass=opennlp.tools.dl.NeuralDocCatTrainer
+     *    -Dexec.args="-glovesPath $HOME/work/datasets/glove.6B/glove.6B.100d.txt
+     *    -labels pos neg -modelPath imdb-sentiment-neural-model.zip
+     *    -trainDir $HOME/work/datasets/aclImdb/train -lr 0.001"
+     *
+     * </pre>
+     *
+     * @param argss command line arguments
+     * @throws CmdLineException declared for compatibility; parse errors are reported and exit the JVM
+     * @throws IOException if loading the data or saving the model fails
+     */
+    public static void main(String[] argss) throws CmdLineException, IOException {
+        Args args = new Args();
+        CmdLineParser parser = new CmdLineParser(args);
+        try {
+            parser.parseArgument(argss);
+        } catch (CmdLineException e) {
+            System.err.println(e.getMessage());
+            e.getParser().printUsage(System.err);
+            System.exit(1);
+        }
+        NeuralDocCatTrainer classifier = new NeuralDocCatTrainer(args);
+        classifier.train();
+        classifier.saveModel(args.modelPath);
+    }
+
+}
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/UnclosableInputStream.java b/opennlp-dl/src/main/java/opennlp/tools/dl/UnclosableInputStream.java
new file mode 100644
index 0000000..701fc48
--- /dev/null
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/UnclosableInputStream.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package opennlp.tools.dl;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.Reader;
+import java.io.Writer;
+
+/**
+ * Wrapper for an {@link InputStream} whose sole purpose is to ignore the {@link #close()} calls
+ * that are usually propagated from readers and helpers consuming the stream.
+ * A use case of this wrapper is reading multiple entries from a {@link java.util.zip.ZipInputStream},
+ * where closing whatever consumed one entry must not close the whole archive. For example, a reader
+ * used with {@link org.apache.commons.io.IOUtils#copy(Reader, Writer)}, or
+ * {@link org.nd4j.linalg.factory.Nd4j#read(InputStream)}, may close the stream it is given.
+ *
+ * Note:
+ *  1. this wrapper ignores calls to its {@link #close()} method
+ *  2. call {@link #forceClose()} when the inner stream really needs to be closed
+ *  3. this wrapper holds no resources of its own; if you close the inner stream yourself,
+ *     you can safely skip closing this wrapper
+ *  4. reads, skips and {@link #available()} are delegated to the inner stream; once
+ *     {@link #forceClose()} has succeeded they fail with an {@link IOException}
+ *
+ * @author Thamme Gowda (thammegowda@apache.org)
+ */
+public class UnclosableInputStream extends InputStream {
+
+    private InputStream innerStream;
+
+    /**
+     * @param stream the stream to read from; it is not closed by {@link #close()}
+     */
+    public UnclosableInputStream(InputStream stream) {
+        this.innerStream = stream;
+    }
+
+    @Override
+    public int read() throws IOException {
+        return inner().read();
+    }
+
+    /**
+     * Delegates bulk reads to the inner stream; without this override {@link InputStream}
+     * would fall back to pulling one byte at a time through {@link #read()}.
+     */
+    @Override
+    public int read(byte[] b, int off, int len) throws IOException {
+        return inner().read(b, off, len);
+    }
+
+    @Override
+    public long skip(long n) throws IOException {
+        return inner().skip(n);
+    }
+
+    @Override
+    public int available() throws IOException {
+        return inner().available();
+    }
+
+    /**
+     * NOP - Does not close the stream - intentional
+     * @throws IOException never; declared to match {@link InputStream#close()}
+     */
+    @Override
+    public void close() throws IOException {
+        // intentionally ignored;
+        // Use forceClose() when needed to close
+    }
+
+    /**
+     * Closes the inner stream. After a successful close, further calls do nothing.
+     * @throws IOException if closing the inner stream fails
+     */
+    public void forceClose() throws IOException {
+        if (innerStream != null) {
+            innerStream.close();
+            innerStream = null;
+        }
+    }
+
+    /**
+     * Returns the inner stream, or fails like a closed stream would once it has been force-closed.
+     */
+    private InputStream inner() throws IOException {
+        if (innerStream == null) {
+            throw new IOException("Stream closed");
+        }
+        return innerStream;
+    }
+}