Refactored and implemented DocCat API
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/DataReader.java b/opennlp-dl/src/main/java/opennlp/tools/dl/DataReader.java
index 825cf53..86af123 100644
--- a/opennlp-dl/src/main/java/opennlp/tools/dl/DataReader.java
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/DataReader.java
@@ -82,9 +82,9 @@
  * </code>
  *
  * @see GlobalVectors
- * @see GloveRNNTextClassifier
+ * @see NeuralDocCat
  * <br/>
- * Contributors: Thamme Gowda (thammegowda@apache.org)
+ * @author Thamme Gowda (thammegowda@apache.org)
  *
  */
 public class DataReader implements DataSetIterator {
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/GlobalVectors.java b/opennlp-dl/src/main/java/opennlp/tools/dl/GlobalVectors.java
index bda4531..fdf3a95 100644
--- a/opennlp-dl/src/main/java/opennlp/tools/dl/GlobalVectors.java
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/GlobalVectors.java
@@ -18,6 +18,7 @@
  */
 package opennlp.tools.dl;
 
+import org.apache.commons.io.IOUtils;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.factory.Nd4j;
@@ -26,11 +27,7 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.io.BufferedReader;
-import java.io.FileInputStream;
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.InputStreamReader;
+import java.io.*;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
@@ -53,6 +50,8 @@
  * }
  * </pre>
  *
+ * @author Thamme Gowda (thammegowda@apache.org)
+ *
  */
 public class GlobalVectors {
 
@@ -169,4 +168,32 @@
         }
         return features;
     }
+
+    public void writeOut(OutputStream stream, boolean closeStream) throws IOException {
+        writeOut(stream, "%.5f", closeStream);
+    }
+
+    public void writeOut(OutputStream stream,
+                         String floatPrecisionFormatString, boolean closeStream) throws IOException {
+        if (!Character.isWhitespace(floatPrecisionFormatString.charAt(0))) {
+            floatPrecisionFormatString = " " + floatPrecisionFormatString;
+        }
+        LOG.info("Writing {} vectors out, float precision {}", idToWord.size(), floatPrecisionFormatString);
+
+        PrintWriter out = new PrintWriter(stream);
+        try {
+            for (int i = 0; i < idToWord.size(); i++) {
+                out.printf("%s", idToWord.get(i));
+                INDArray row = embeddings.getRow(i);
+                for (int j = 0; j < vectorSize; j++) {
+                    out.printf(floatPrecisionFormatString, row.getDouble(j));
+                }
+                out.println();
+            }
+        } finally {
+            if (closeStream){
+                IOUtils.closeQuietly(out);
+            } // else dont close because, closing the print writer also closes the inner stream
+        }
+    }
 }
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/GloveRNNTextClassifier.java b/opennlp-dl/src/main/java/opennlp/tools/dl/GloveRNNTextClassifier.java
deleted file mode 100644
index dd7bcc5..0000000
--- a/opennlp-dl/src/main/java/opennlp/tools/dl/GloveRNNTextClassifier.java
+++ /dev/null
@@ -1,334 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *  http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package opennlp.tools.dl;
-
-import org.apache.commons.io.FileUtils;
-import org.deeplearning4j.eval.Evaluation;
-import org.deeplearning4j.nn.conf.GradientNormalization;
-import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
-import org.deeplearning4j.nn.conf.Updater;
-import org.deeplearning4j.nn.conf.layers.GravesLSTM;
-import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
-import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
-import org.deeplearning4j.nn.weights.WeightInit;
-import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
-import org.deeplearning4j.util.ModelSerializer;
-import org.kohsuke.args4j.CmdLineException;
-import org.kohsuke.args4j.CmdLineParser;
-import org.kohsuke.args4j.Option;
-import org.kohsuke.args4j.spi.StringArrayOptionHandler;
-import org.nd4j.linalg.activations.Activation;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.dataset.DataSet;
-import org.nd4j.linalg.indexing.NDArrayIndex;
-import org.nd4j.linalg.lossfunctions.LossFunctions;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.IOException;
-import java.io.InputStream;
-import java.util.Arrays;
-import java.util.List;
-
-/**
- * This is a Multi Class Text Classifier that uses Glove embeddings to vectorize the text
- * and a LSTM RNN to classify the sequence of vectors.
- *
- * This class aimed to make a general purpose text classifier tool.
- * A common use case would be to tune it for the text Sentiment classification task.
- *
- * The Glove Vectors can be downloaded from https://nlp.stanford.edu/projects/glove/
- *
- * <br/>
- */
-public class GloveRNNTextClassifier {
-    private static final Logger LOG = LoggerFactory.getLogger(GloveRNNTextClassifier.class);
-
-    private MultiLayerNetwork model;
-    private GlobalVectors gloves;
-    private DataReader trainSet;
-    private DataReader validSet;
-
-    private Args args;
-
-    public GloveRNNTextClassifier(Args args) throws IOException {
-        this.init(args);
-    }
-
-    public static class Args {
-
-        @Option(name="-batchSize", depends = {"-trainDir"},
-                usage = "Number of examples in minibatch. Applicable for training only.")
-        int batchSize = 128;
-
-        @Option(name="-nEpochs", depends = {"-trainDir"},
-                usage = "Number of epochs (i.e. full passes over the training data) to train on." +
-                " Applicable for training only.")
-        int nEpochs = 2;
-
-        @Option(name="-maxSeqLen", usage = "Max Sequence Length. Sequences longer than this will be truncated")
-        int maxSeqLen = 256;    //Truncate text with length (# words) greater than this
-
-        @Option(name="-vocabSize", usage = "Vocabulary Size.")
-        int vocabSize = 20000;   //vocabulary size
-
-        @Option(name="-nRNNUnits", depends = {"-trainDir"},
-                usage = "Number of RNN cells to use. Applicable for training only.")
-        int nRNNUnits = 128;
-
-        @Option(name="-lr", aliases = "-learnRate", usage = "Learning Rate." +
-                " Adjust it when the scores bounce to NaN or Infinity.")
-        double learningRate = 2e-3;
-
-        @Option(name="-glovesPath", required = true, usage = "Path to GloVe vectors file." +
-                " Download and unzip from https://nlp.stanford.edu/projects/glove/")
-        String glovesPath = null;
-
-        @Option(name="-modelPath", required = true, usage = "Path to model file. " +
-                "This will be used for serializing the model after the training phase." +
-                "and also the model will be restored from here for prediction")
-        String modelPath = null;
-
-        @Option(name="-trainDir", usage = "Path to train data directory. Optional." +
-                " Setting this value will take the system to training mode. ")
-        String trainDir = null;
-
-        @Option(name="-validDir", depends = {"-trainDir"}, usage = "Path to validation data directory. Optional." +
-                " Applicable only when -trainDir is set.")
-        String validDir = null;
-
-        @Option(name="-labels", required = true, handler = StringArrayOptionHandler.class,
-                usage = "Names of targets or labels separated by spaces. " +
-                "The order of labels matters. Make sure to use the same sequence for training and predicting. " +
-                "Also, these names should match subdirectory names of -trainDir and -validDir when those are " +
-                        "applicable. \n Example -labels pos neg")
-        List<String> labels = null;
-
-        @Option(name="-files", handler = StringArrayOptionHandler.class,
-                usage = "File paths (separated by space) to predict using the model.")
-        List<String> filePaths = null;
-
-        @Override
-        public String toString() {
-            return "Args{" +
-                    "batchSize=" + batchSize +
-                    ", nEpochs=" + nEpochs +
-                    ", maxSeqLen=" + maxSeqLen +
-                    ", vocabSize=" + vocabSize +
-                    ", learningRate=" + learningRate +
-                    ", nRNNUnits=" + nRNNUnits +
-                    ", glovesPath='" + glovesPath + '\'' +
-                    ", modelPath='" + modelPath + '\'' +
-                    ", trainDir='" + trainDir + '\'' +
-                    ", validDir='" + validDir + '\'' +
-                    ", labels=" + labels +
-                    '}';
-        }
-    }
-
-
-    public void init(Args args) throws IOException {
-        this.args = args;
-        try (InputStream stream = new FileInputStream(args.glovesPath)) {
-            this.gloves = new GlobalVectors(stream, args.vocabSize);
-        }
-
-        if (args.trainDir != null) {
-            LOG.info("Training data from {}", args.trainDir);
-            this.trainSet = new DataReader(args.trainDir, args.labels, this.gloves, args.batchSize, args.maxSeqLen);
-            if (args.validDir != null) {
-                LOG.info("Validation data from {}", args.validDir);
-                this.validSet = new DataReader(args.validDir, args.labels, this.gloves, args.batchSize, args.maxSeqLen);
-            }
-            //create model
-            this.model = this.createModel();
-            // ready for training
-        } else {
-            //restore model
-            LOG.info("Training data not set => Going to restore model from {}", args.modelPath);
-            this.model = ModelSerializer.restoreMultiLayerNetwork(args.modelPath);
-            //ready for prediction
-        }
-    }
-
-    private MultiLayerNetwork createModel(){
-        int totalOutcomes = this.trainSet.totalOutcomes();
-        assert totalOutcomes >= 2;
-        LOG.info("Number of classes " + totalOutcomes);
-
-        //TODO: the below network params should be configurable from CLI or settings file
-        //Set up network configuration
-        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
-                .updater(Updater.RMSPROP) // ADAM .adamMeanDecay(0.9).adamVarDecay(0.999)
-                .rmsDecay(0.9)
-                .regularization(true).l2(1e-5)
-                .weightInit(WeightInit.XAVIER)
-                .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
-                .gradientNormalizationThreshold(1.0)
-                .learningRate(args.learningRate)
-                .list()
-                .layer(0, new GravesLSTM.Builder()
-                        .nIn(gloves.getVectorSize())
-                        .nOut(args.nRNNUnits)
-                        .activation(Activation.RELU).build())
-                .layer(1, new RnnOutputLayer.Builder()
-                        .nIn(args.nRNNUnits)
-                        .nOut(totalOutcomes)
-                        .activation(Activation.SOFTMAX)
-                        .lossFunction(LossFunctions.LossFunction.MCXENT)
-                        .build())
-                .pretrain(false)
-                .backprop(true)
-                .build();
-
-        MultiLayerNetwork net = new MultiLayerNetwork(conf);
-        net.init();
-        net.setListeners(new ScoreIterationListener(1));
-        return net;
-    }
-
-    public void train(){
-        train(args.nEpochs, this.trainSet, this.validSet);
-    }
-
-    /**
-     * Trains model
-     * @param nEpochs number of epochs (i.e. iterations over the training dataset)
-     * @param train training data set
-     * @param validation validation data set for evaluation after each epoch.
-     *                  Setting this to null will skip the evaluation
-     */
-    public void train(int nEpochs, DataReader train, DataReader validation){
-        assert model != null;
-        assert train != null;
-        LOG.info("Starting training...\nTotal epochs={}, Training Size={}, Validation Size={}", nEpochs,
-                train.totalExamples(), validation == null ? null : validation.totalExamples());
-        for (int i = 0; i < nEpochs; i++) {
-            model.fit(train);
-            train.reset();
-            LOG.info("Epoch {} complete", i);
-
-            if (validation != null) {
-                LOG.info("Starting evaluation");
-                //Run evaluation. This is on 25k reviews, so can take some time
-                Evaluation evaluation = new Evaluation();
-                while (validation.hasNext()) {
-                    DataSet t = validation.next();
-                    INDArray features = t.getFeatureMatrix();
-                    INDArray labels = t.getLabels();
-                    INDArray inMask = t.getFeaturesMaskArray();
-                    INDArray outMask = t.getLabelsMaskArray();
-                    INDArray predicted = model.output(features, false, inMask, outMask);
-                    evaluation.evalTimeSeries(labels, predicted, outMask);
-                }
-                validation.reset();
-                LOG.info(evaluation.stats());
-            }
-        }
-    }
-
-    /**
-     * Saves the model to specified path
-     * @param path model path
-     * @throws IOException
-     */
-    public void saveModel(String path) throws IOException {
-        assert model != null;
-        LOG.info("Saving the model at {}", path);
-        ModelSerializer.writeModel(model, path, true);
-    }
-
-    /**
-     * Predicts class probability of the text based on the state of model
-     * @param text text to be classified
-     * @return array of doubles, indices associated with indices of <code>this.args.labels</code>
-     */
-    public double[] predict(String text){
-
-        INDArray seqFeatures = this.gloves.embed(text, this.args.maxSeqLen);
-        INDArray networkOutput = this.model.output(seqFeatures);
-        int timeSeriesLength = networkOutput.size(2);
-        INDArray probsAtLastWord = networkOutput.get(NDArrayIndex.point(0),
-                NDArrayIndex.all(), NDArrayIndex.point(timeSeriesLength - 1));
-
-        double[] probs = new double[args.labels.size()];
-        for (int i = 0; i < args.labels.size(); i++) {
-            probs[i] = probsAtLastWord.getDouble(i);
-        }
-        return probs;
-    }
-
-    /**
-     * <pre>
-     *   # Download pre trained Glo-ves (this is a large file)
-     *   wget http://nlp.stanford.edu/data/glove.6B.zip
-     *   unzip glove.6B.zip -d glove.6B
-     *
-     *   # Download dataset
-     *   wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
-     *   tar xzf aclImdb_v1.tar.gz
-     *
-     *  mvn compile exec:java
-     *    -Dexec.mainClass=edu.usc.irds.sentiment.analysis.dl.GloveRNNTextClassifier
-     *    -Dexec.args="-glovesPath $HOME/work/datasets/glove.6B/glove.6B.100d.txt
-     *    -labels pos neg -modelPath imdb-sentimodel.dat
-     *    -trainDir=$HOME/work/datasets/aclImdb/train -lr 0.001"
-     *
-     * </pre>
-     *
-     */
-    public static void main(String[] argss) throws CmdLineException, IOException {
-        Args args = new Args();
-        CmdLineParser parser = new CmdLineParser(args);
-        try {
-            parser.parseArgument(argss);
-        } catch (CmdLineException e) {
-            System.out.println(e.getMessage());
-            e.getParser().printUsage(System.out);
-            System.exit(1);
-        }
-        GloveRNNTextClassifier classifier = new GloveRNNTextClassifier(args);
-        byte numOps = 0;
-        if (classifier.trainSet != null) {
-            numOps++;
-            classifier.train();
-            classifier.saveModel(args.modelPath);
-        }
-
-        if (args.filePaths != null && !args.filePaths.isEmpty()) {
-            numOps++;
-            System.out.println("Labels:" + args.labels);
-            for (String filePath: args.filePaths) {
-                File file = new File(filePath);
-                String text = FileUtils.readFileToString(file);
-                double[] probs = classifier.predict(text);
-                System.out.println(">>" + filePath);
-                System.out.println("Probabilities:" + Arrays.toString(probs));
-            }
-        }
-
-        if (numOps == 0) {
-            System.out.println("Provide -trainDir to train a model, -files to classify files");
-            System.exit(2);
-        }
-    }
-}
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCat.java b/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCat.java
new file mode 100644
index 0000000..53bf530
--- /dev/null
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCat.java
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package opennlp.tools.dl;
+
+import opennlp.tools.doccat.DocumentCategorizer;
+import opennlp.tools.tokenize.Tokenizer;
+import opennlp.tools.tokenize.WhitespaceTokenizer;
+import org.apache.commons.io.FileUtils;
+import org.apache.commons.lang3.NotImplementedException;
+import org.kohsuke.args4j.CmdLineException;
+import org.kohsuke.args4j.CmdLineParser;
+import org.kohsuke.args4j.Option;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.*;
+
+/**
+ * An implementation of {@link DocumentCategorizer} using Neural Networks.
+ * This class provides prediction functionality from the model of {@link NeuralDocCatTrainer}.
+ *
+ */
+public class NeuralDocCat implements DocumentCategorizer {
+
+    private static final Logger LOG = LoggerFactory.getLogger(NeuralDocCat.class);
+
+    private NeuralDocCatModel model;
+
+    public NeuralDocCat(NeuralDocCatModel model) {
+        this.model = model;
+    }
+
+    @Override
+    public double[] categorize(String[] tokens) {
+        return categorize(tokens, Collections.emptyMap());
+    }
+
+    @Override
+    public double[] categorize(String[] text, Map<String, Object> extraInformation) {
+        INDArray seqFeatures = this.model.getGloves().embed(text, this.model.getMaxSeqLen());
+
+        INDArray networkOutput = this.model.getNetwork().output(seqFeatures);
+        int timeSeriesLength = networkOutput.size(2);
+        INDArray probsAtLastWord = networkOutput.get(NDArrayIndex.point(0),
+                NDArrayIndex.all(), NDArrayIndex.point(timeSeriesLength - 1));
+
+        int nLabels = this.model.getLabels().size();
+        double[] probs = new double[nLabels];
+        for (int i = 0; i < nLabels; i++) {
+            probs[i] = probsAtLastWord.getDouble(i);
+        }
+        return probs;
+    }
+
+    @Override
+    public String getBestCategory(double[] outcome) {
+        int maxIdx = 0;
+        double maxProb = outcome[0];
+        for (int i = 1; i < outcome.length; i++) {
+            if (outcome[i] > maxProb) {
+                maxIdx = i;
+                maxProb = outcome[i];
+            }
+        }
+        return model.getLabels().get(maxIdx);
+    }
+
+    @Override
+    public int getIndex(String category) {
+        return model.getLabels().indexOf(category);
+    }
+
+    @Override
+    public String getCategory(int index) {
+        return model.getLabels().get(index);
+    }
+
+    @Override
+    public int getNumberOfCategories() {
+        return model.getLabels().size();
+    }
+
+
+    @Override
+    public String getAllResults(double[] results) {
+        throw new NotImplementedException("Not implemented");
+    }
+
+    @Override
+    public Map<String, Double> scoreMap(String[] text) {
+        double[] scores = categorize(text);
+        Map<String, Double> result = new HashMap<>();
+        for (int i = 0; i < scores.length; i++) {
+            result.put(model.getLabels().get(i), scores[i]);
+
+        }
+        return result;
+    }
+
+    @Override
+    public SortedMap<Double, Set<String>> sortedScoreMap(String[] text) {
+        throw new NotImplementedException("Not implemented");
+    }
+
+    @Override
+    @Deprecated
+    public double[] categorize(String documentText) {
+        throw new UnsupportedOperationException("Use the other categorize(..) method that accepts tokenized text");
+    }
+
+    @Override
+    @Deprecated
+    public Map<String, Double> scoreMap(String text) {
+        throw new UnsupportedOperationException("Use the other scoreMap(..) method that accepts tokenized text");
+    }
+
+    @Override
+    @Deprecated
+    public SortedMap<Double, Set<String>> sortedScoreMap(String text) {
+        throw new UnsupportedOperationException("Use the other sortedScoreMap(..) method that accepts tokenized text");
+    }
+    @Override
+    @Deprecated
+    public double[] categorize(String documentText, Map<String, Object> extraInformation) {
+        throw new UnsupportedOperationException("Use the other categorize(..) method that accepts tokenized text");
+    }
+
+
+    public static void main(String[] argss) throws CmdLineException, IOException {
+        class Args {
+
+            @Option(name = "-model", required = true, usage = "Path to NeuralDocCatModel stored file")
+            String modelPath;
+
+            @Option(name = "-files", required = true, usage = "One or more document paths whose category is " +
+                    "to be predicted by the model")
+            List<File> files;
+        }
+
+        Args args = new Args();
+        CmdLineParser parser = new CmdLineParser(args);
+        try {
+            parser.parseArgument(argss);
+        } catch (CmdLineException e) {
+            System.out.println(e.getMessage());
+            e.getParser().printUsage(System.out);
+            System.exit(1);
+        }
+
+        NeuralDocCatModel model = NeuralDocCatModel.loadModel(args.modelPath);
+        NeuralDocCat classifier = new NeuralDocCat(model);
+
+        System.out.println("Labels:" + model.getLabels());
+        Tokenizer tokenizer = WhitespaceTokenizer.INSTANCE;
+
+        for (File file: args.files) {
+            String text = FileUtils.readFileToString(file);
+            String[] tokens = tokenizer.tokenize(text.toLowerCase());
+            double[] probs = classifier.categorize(tokens);
+            System.out.println(">>" + file);
+            System.out.println("Probabilities:" + Arrays.toString(probs));
+        }
+
+    }
+}
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCatModel.java b/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCatModel.java
new file mode 100644
index 0000000..f1b6247
--- /dev/null
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCatModel.java
@@ -0,0 +1,179 @@
+package opennlp.tools.dl;
+
+import org.apache.commons.io.IOUtils;
+import org.apache.commons.lang3.StringUtils;
+import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.*;
+import java.util.*;
+import java.util.zip.ZipEntry;
+import java.util.zip.ZipInputStream;
+import java.util.zip.ZipOutputStream;
+
+/**
+ * This class is a wrapper for DL4J's {@link MultiLayerNetwork}, and {@link GlobalVectors}
+ * that provides features to serialize and deserialize necessary data to a zip file.
+ *
+ * This cane be used by a Neural Trainer tool to serialize the network and a predictor tool to restore the same network
+ * with the weights.
+ *
+ * <br/>
+ ** @author Thamme Gowda (thammegowda@apache.org)
+ */
+public class NeuralDocCatModel {
+
+    public static final int VERSION = 1;
+    public static final String MODEL_NAME = NeuralDocCatModel.class.getName();
+    public static final String MANIFEST = "model.mf";
+    public static final String NETWORK = "network.json";
+    public static final String WEIGHTS = "weights.bin";
+    public static final String GLOVES = "gloves.tsv";
+    public static final String LABELS = "labels";
+    public static final String MAX_SEQ_LEN = "maxSeqLen";
+
+    private static final Logger LOG = LoggerFactory.getLogger(NeuralDocCatModel.class);
+
+    private final MultiLayerNetwork network;
+    private final GlobalVectors gloves;
+    private final Properties manifest;
+    private final List<String> labels;
+    private final int maxSeqLen;
+
+    /**
+     *
+     * @param stream Input stream of a Zip File
+     * @throws IOException
+     */
+    public NeuralDocCatModel(InputStream stream) throws IOException {
+        ZipInputStream zipIn = new ZipInputStream(stream);
+
+        Properties manifest = null;
+        MultiLayerNetwork model = null;
+        INDArray params = null;
+        GlobalVectors gloves = null;
+        ZipEntry entry;
+        while ((entry = zipIn.getNextEntry()) != null) {
+            String name = entry.getName();
+            switch (name) {
+                case MANIFEST:
+                    manifest = new Properties();
+                    manifest.load(zipIn);
+                    break;
+                case NETWORK:
+                    String json = IOUtils.toString(new UnclosableInputStream(zipIn));
+                    model = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(json));
+                    break;
+                case WEIGHTS:
+                    params = Nd4j.read(new DataInputStream(new UnclosableInputStream(zipIn)));
+                    break;
+                case GLOVES:
+                    gloves = new GlobalVectors(new UnclosableInputStream(zipIn));
+                    break;
+                default:
+                    LOG.warn("Unexpected entry in the zip : {}", name);
+            }
+        }
+
+        assert model != null;
+        assert manifest != null;
+        model.init(params, false);
+        this.network = model;
+        this.manifest = manifest;
+        this.gloves = gloves;
+
+        assert manifest.containsKey(LABELS);
+        String[] labels = manifest.getProperty(LABELS).split(",");
+        this.labels = Collections.unmodifiableList(Arrays.asList(labels));
+
+        assert manifest.containsKey(MAX_SEQ_LEN);
+        this.maxSeqLen = Integer.parseInt(manifest.getProperty(MAX_SEQ_LEN));
+
+    }
+
+    /**
+     *
+     * @param network any compatible multi layer neural network
+     * @param vectors Global vectors
+     * @param labels list of labels
+     * @param maxSeqLen max sequence length
+     */
+    public NeuralDocCatModel(MultiLayerNetwork network, GlobalVectors vectors, List<String> labels, int maxSeqLen) {
+        this.network = network;
+        this.gloves = vectors;
+        this.manifest = new Properties();
+        this.manifest.setProperty(LABELS, StringUtils.join(labels, ","));
+        this.manifest.setProperty(MAX_SEQ_LEN, maxSeqLen + "");
+        this.labels = Collections.unmodifiableList(labels);
+        this.maxSeqLen = maxSeqLen;
+    }
+
+    public MultiLayerNetwork getNetwork() {
+        return network;
+    }
+
+    public GlobalVectors getGloves() {
+        return gloves;
+    }
+
+    public List<String> getLabels() {
+        return labels;
+    }
+
+    public int getMaxSeqLen() {
+        return this.maxSeqLen;
+    }
+
+    /**
+     * Zips the current state of the model and writes it stream
+     * @param stream stream to write
+     * @throws IOException
+     */
+    public void saveModel(OutputStream stream) throws IOException {
+        try (ZipOutputStream zipOut = new ZipOutputStream(new BufferedOutputStream(stream))) {
+            // Write out manifest
+            zipOut.putNextEntry(new ZipEntry(MANIFEST));
+
+            String comments = "Created-By:" + System.getenv("USER") + " at " + new Date().toString()
+                    + "\nModel-Version: " + VERSION
+                    + "\nModel-Schema:" + MODEL_NAME;
+
+            manifest.store(zipOut, comments);
+            zipOut.closeEntry();
+
+            // Write out the network
+            zipOut.putNextEntry(new ZipEntry(NETWORK));
+            byte[] jModel = network.getLayerWiseConfigurations().toJson().getBytes();
+            zipOut.write(jModel);
+            zipOut.closeEntry();
+
+            //Write out the network coefficients
+            zipOut.putNextEntry(new ZipEntry(WEIGHTS));
+            Nd4j.write(network.params(), new DataOutputStream(zipOut));
+            zipOut.closeEntry();
+
+            // Write out vectors
+            zipOut.putNextEntry(new ZipEntry(GLOVES));
+            gloves.writeOut(zipOut, false);
+            zipOut.closeEntry();
+
+            zipOut.finish();
+        }
+    }
+
+    /**
+     * creates a model from file on the local file system
+     * @param modelPath path to model file
+     * @return an instance of this class
+     * @throws IOException
+     */
+    public static NeuralDocCatModel loadModel(String modelPath) throws IOException {
+        try (InputStream modelStream = new FileInputStream(modelPath)) {
+            return new NeuralDocCatModel(modelStream);
+        }
+    }
+}
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCatTrainer.java b/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCatTrainer.java
new file mode 100644
index 0000000..7fe11ff
--- /dev/null
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCatTrainer.java
@@ -0,0 +1,258 @@
+package opennlp.tools.dl;
+
+import org.deeplearning4j.eval.Evaluation;
+import org.deeplearning4j.nn.conf.GradientNormalization;
+import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.Updater;
+import org.deeplearning4j.nn.conf.layers.GravesLSTM;
+import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
+import org.kohsuke.args4j.CmdLineException;
+import org.kohsuke.args4j.CmdLineParser;
+import org.kohsuke.args4j.Option;
+import org.kohsuke.args4j.spi.StringArrayOptionHandler;
+import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.DataSet;
+import org.nd4j.linalg.lossfunctions.LossFunctions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.*;
+import java.util.List;
+
+
+/**
+ * This class provides functionality to construct and train neural networks that can be used for
+ * {@link opennlp.tools.doccat.DocumentCategorizer}
+ *
+ * @see NeuralDocCat
+ * @see NeuralDocCatModel
+ * @author Thamme Gowda (thammegowda@apache.org)
+ */
+public class NeuralDocCatTrainer {
+
+    public static class Args {
+
+        @Option(name = "-batchSize", usage = "Number of examples in minibatch")
+        int batchSize = 128;
+
+        @Option(name = "-nEpochs", usage = "Number of epochs (i.e. full passes over the training data) to train on." +
+                " Applicable for training only.")
+        int nEpochs = 2;
+
+        @Option(name = "-maxSeqLen", usage = "Max Sequence Length. Sequences longer than this will be truncated")
+        int maxSeqLen = 256;    //Truncate text with length (# words) greater than this
+
+        @Option(name = "-vocabSize", usage = "Vocabulary Size.")
+        int vocabSize = 20000;   //vocabulary size
+
+        @Option(name = "-nRNNUnits", usage = "Number of RNN cells to use.")
+        int nRNNUnits = 128;
+
+        @Option(name = "-lr", aliases = "-learnRate", usage = "Learning Rate." +
+                " Adjust it when the scores bounce to NaN or Infinity.")
+        double learningRate = 2e-3;
+
+        @Option(name = "-glovesPath", required = true, usage = "Path to GloVe vectors file." +
+                " Download and unzip from https://nlp.stanford.edu/projects/glove/")
+        String glovesPath = null;
+
+        @Option(name = "-modelPath", required = true, usage = "Path to model file. " +
+                "This will be used for serializing the model after the training phase." )
+        String modelPath = null;
+
+        @Option(name = "-trainDir", required = true, usage = "Path to train data directory." +
+                " Setting this value will take the system to training mode. ")
+        String trainDir = null;
+
+        @Option(name = "-validDir", usage = "Path to validation data directory. Optional.")
+        String validDir = null;
+
+        @Option(name = "-labels", required = true, handler = StringArrayOptionHandler.class,
+                usage = "Names of targets or labels separated by spaces. " +
+                        "The order of labels matters. Make sure to use the same sequence for training and predicting. " +
+                        "Also, these names should match subdirectory names of -trainDir and -validDir when those are " +
+                        "applicable. \n Example -labels pos neg")
+        List<String> labels = null;
+
+        @Override
+        public String toString() {
+            return "Args{" +
+                    "batchSize=" + batchSize +
+                    ", nEpochs=" + nEpochs +
+                    ", maxSeqLen=" + maxSeqLen +
+                    ", vocabSize=" + vocabSize +
+                    ", learningRate=" + learningRate +
+                    ", nRNNUnits=" + nRNNUnits +
+                    ", glovesPath='" + glovesPath + '\'' +
+                    ", modelPath='" + modelPath + '\'' +
+                    ", trainDir='" + trainDir + '\'' +
+                    ", validDir='" + validDir + '\'' +
+                    ", labels=" + labels +
+                    '}';
+        }
+    }
+
+    private static final Logger LOG = LoggerFactory.getLogger(NeuralDocCatTrainer.class);
+
+    private NeuralDocCatModel model;
+    private Args args;
+    private DataReader trainSet;
+    private DataReader validSet;
+
+
+    public NeuralDocCatTrainer(Args args) throws IOException {
+        this.args = args;
+        GlobalVectors gloves;
+        MultiLayerNetwork network;
+
+        try (InputStream stream = new FileInputStream(args.glovesPath)) {
+            gloves = new GlobalVectors(stream, args.vocabSize);
+        }
+
+        LOG.info("Training data from {}", args.trainDir);
+        this.trainSet = new DataReader(args.trainDir, args.labels, gloves, args.batchSize, args.maxSeqLen);
+        if (args.validDir != null) {
+            LOG.info("Validation data from {}", args.validDir);
+            this.validSet = new DataReader(args.validDir, args.labels, gloves, args.batchSize, args.maxSeqLen);
+        }
+
+        //create network
+        network = this.createNetwork(gloves.getVectorSize());
+        this.model = new NeuralDocCatModel(network, gloves, args.labels, args.maxSeqLen);
+    }
+
+    public MultiLayerNetwork createNetwork(int vectorSize) {
+        int totalOutcomes = this.trainSet.totalOutcomes();
+        assert totalOutcomes >= 2;
+        LOG.info("Number of classes " + totalOutcomes);
+
+        //TODO: the below network params should be configurable from CLI or settings file
+        //Set up network configuration
+        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                .updater(Updater.RMSPROP) // ADAM .adamMeanDecay(0.9).adamVarDecay(0.999)
+                .rmsDecay(0.9)
+                .regularization(true).l2(1e-5)
+                .weightInit(WeightInit.XAVIER)
+                .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
+                .gradientNormalizationThreshold(1.0)
+                .learningRate(args.learningRate)
+                .list()
+                .layer(0, new GravesLSTM.Builder()
+                        .nIn(vectorSize)
+                        .nOut(args.nRNNUnits)
+                        .activation(Activation.RELU).build())
+                .layer(1, new RnnOutputLayer.Builder()
+                        .nIn(args.nRNNUnits)
+                        .nOut(totalOutcomes)
+                        .activation(Activation.SOFTMAX)
+                        .lossFunction(LossFunctions.LossFunction.MCXENT)
+                        .build())
+                .pretrain(false)
+                .backprop(true)
+                .build();
+
+        MultiLayerNetwork net = new MultiLayerNetwork(conf);
+        net.init();
+        net.setListeners(new ScoreIterationListener(1));
+        return net;
+    }
+
+    public void train() {
+        train(args.nEpochs, this.trainSet, this.validSet);
+    }
+
+    /**
+     * Trains model
+     *
+     * @param nEpochs    number of epochs (i.e. iterations over the training dataset)
+     * @param train      training data set
+     * @param validation validation data set for evaluation after each epoch.
+     *                   Setting this to null will skip the evaluation
+     */
+    public void train(int nEpochs, DataReader train, DataReader validation) {
+        assert model != null;
+        assert train != null;
+        LOG.info("Starting training...\nTotal epochs={}, Training Size={}, Validation Size={}", nEpochs,
+                train.totalExamples(), validation == null ? null : validation.totalExamples());
+        for (int i = 0; i < nEpochs; i++) {
+            model.getNetwork().fit(train);
+            train.reset();
+            LOG.info("Epoch {} complete", i);
+
+            if (validation != null) {
+                LOG.info("Starting evaluation");
+                //Run evaluation. This is on 25k reviews, so can take some time
+                Evaluation evaluation = new Evaluation();
+                while (validation.hasNext()) {
+                    DataSet t = validation.next();
+                    INDArray features = t.getFeatureMatrix();
+                    INDArray labels = t.getLabels();
+                    INDArray inMask = t.getFeaturesMaskArray();
+                    INDArray outMask = t.getLabelsMaskArray();
+                    INDArray predicted = this.model.getNetwork().output(features, false, inMask, outMask);
+                    evaluation.evalTimeSeries(labels, predicted, outMask);
+                }
+                validation.reset();
+                LOG.info(evaluation.stats());
+            }
+        }
+    }
+
+    /**
+     * Saves the model to specified path
+     *
+     * @param path model path
+     * @throws IOException
+     */
+    public void saveModel(String path) throws IOException {
+        assert model != null;
+        LOG.info("Saving the model at {}", path);
+        try (OutputStream stream = new FileOutputStream(path)) {
+            model.saveModel(stream);
+        }
+    }
+
+    /**
+     * <pre>
+     *   # Download pre trained Glo-ves (this is a large file)
+     *   wget http://nlp.stanford.edu/data/glove.6B.zip
+     *   unzip glove.6B.zip -d glove.6B
+     *
+     *   # Download dataset
+     *   wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
+     *   tar xzf aclImdb_v1.tar.gz
+     *
+     *  mvn compile exec:java
+     *    -Dexec.mainClass=edu.usc.irds.sentiment.analysis.dl.NeuralDocCat
+     *    -Dexec.args="-glovesPath $HOME/work/datasets/glove.6B/glove.6B.100d.txt
+     *    -labels pos neg -modelPath imdb-sentiment-neural-model.zip
+     *    -trainDir=$HOME/work/datasets/aclImdb/train -lr 0.001"
+     *
+     * </pre>
+     */
+    public static void main(String[] argss) throws CmdLineException, IOException {
+        argss = ("-trainDir /Users/tg/work/datasets/aclImdb-tiny/train " +
+                "-glovesPath /Users/tg/work/datasets/glove.6B/glove.6B.50d.txt " +
+                "-labels pos neg " +
+                "-lr 0.1 -nEpochs 1 " +
+                "-modelPath neural-doc-cat.zip").split(" ");
+        Args args = new Args();
+        CmdLineParser parser = new CmdLineParser(args);
+        try {
+            parser.parseArgument(argss);
+        } catch (CmdLineException e) {
+            System.out.println(e.getMessage());
+            e.getParser().printUsage(System.out);
+            System.exit(1);
+        }
+        NeuralDocCatTrainer classifier = new NeuralDocCatTrainer(args);
+        classifier.train();
+        classifier.saveModel(args.modelPath);
+    }
+
+}
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/UnclosableInputStream.java b/opennlp-dl/src/main/java/opennlp/tools/dl/UnclosableInputStream.java
new file mode 100644
index 0000000..701fc48
--- /dev/null
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/UnclosableInputStream.java
@@ -0,0 +1,56 @@
+package opennlp.tools.dl;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.Reader;
+import java.io.Writer;
+
+/**
+ * This class offers a wrapper for {@link InputStream};
+ * The only sole purpose of this wrapper is to bypass the close calls that are usually
+ * propagated from the readers.
+ * A use case of this wrapper is for reading multiple files from the {@link java.util.zip.ZipInputStream},
+ * especially because the tools like {@link org.apache.commons.io.IOUtils#copy(Reader, Writer)}
+ * and {@link org.nd4j.linalg.factory.Nd4j#read(InputStream)} automatically close the input stream.
+ *
+ * Note:
+ *  1. this tool ignores the call to {@link #close()} method
+ *  2. Remember to call {@link #forceClose()} when the stream when the inner stream needs to be closed
+ *  3. This wrapper doesn't hold any resources. If you close the innerStream, you can safely ignore closing this wrapper
+ *
+ * @author Thamme Gowda (thammegowda@apache.org)
+ */
+public class UnclosableInputStream extends InputStream {
+
+    private InputStream innerStream;
+
+    public UnclosableInputStream(InputStream stream){
+        this.innerStream = stream;
+    }
+
+    @Override
+    public int read() throws IOException {
+        return innerStream.read();
+    }
+
+    /**
+     * NOP - Does not close the stream - intentional
+     * @throws IOException
+     */
+    @Override
+    public void close() throws IOException {
+        // intentionally ignored;
+        // Use forceClose() when needed to close
+    }
+
+    /**
+     * Closes the stream
+     * @throws IOException
+     */
+    public void forceClose() throws IOException {
+        if (innerStream != null) {
+            innerStream.close();
+            innerStream = null;
+        }
+    }
+}