Text sequence classification using GloVe and RNN/LSTMs
diff --git a/.gitignore b/.gitignore
index fe06e66..126d4a6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,3 +5,6 @@
 nbactions.xml
 nb-configuration.xml
 *.DS_Store
+
+.idea
+*.iml
diff --git a/opennlp-dl/pom.xml b/opennlp-dl/pom.xml
index 3d15d8f..9899efe 100644
--- a/opennlp-dl/pom.xml
+++ b/opennlp-dl/pom.xml
@@ -26,7 +26,7 @@
 
   <properties>
     <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
-    <nd4j.version>0.7.2</nd4j.version>
+    <nd4j.version>0.8.0</nd4j.version>
   </properties>
 
   <dependencies>
@@ -41,8 +41,6 @@
           <artifactId>deeplearning4j-core</artifactId>
           <version>${nd4j.version}</version>
       </dependency>
-
-
       <dependency>
           <groupId>org.deeplearning4j</groupId>
           <artifactId>deeplearning4j-nlp</artifactId>
@@ -64,6 +62,11 @@
       <artifactId>nd4j-native-platform</artifactId>
       <version>${nd4j.version}</version>
     </dependency>
+    <dependency>
+      <groupId>args4j</groupId>
+      <artifactId>args4j</artifactId>
+      <version>2.33</version>
+    </dependency>
   </dependencies>
   <build>
     <plugins>
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/DataReader.java b/opennlp-dl/src/main/java/opennlp/tools/dl/DataReader.java
new file mode 100644
index 0000000..825cf53
--- /dev/null
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/DataReader.java
@@ -0,0 +1,308 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package opennlp.tools.dl;
+
+import org.apache.commons.io.FileUtils;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.DataSet;
+import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
+import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.INDArrayIndex;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Function;
+
+/**
+ * This class provides a reader for loading training and test datasets from the file system for text classifiers.
+ * In addition to reading the content, it
+ * (1) vectorizes the text using embeddings such as GloVe, and
+ * (2) divides the datasets into mini-batches of the specified size.
+ *
+ * The data is expected to be organized as per the following convention:
+ * <pre>
+ * data-dir/
+ *     +- label1 /
+ *     |    +- example11.txt
+ *     |    +- example12.txt
+ *     |    +- example13.txt
+ *     |    +- .....
+ *     +- label2 /
+ *     |    +- example21.txt
+ *     |    +- .....
+ *     +- labelN /
+ *          +- exampleN1.txt
+ *          +- .....
+ * </pre>
+ *
+ * In addition, the dataset is expected to be split into training and test sets as follows:
+ * <pre>
+ * data-dir/
+ *     + train/
+ *     |   +- label1 /
+ *     |   +- labelN /
+ *     + test /
+ *         +- label1 /
+ *         +- labelN /
+ * </pre>
+ *
+ * <h2>Usage</h2>
+ * <pre>
+ *     // label names should match the subdirectory names
+ *     List<String> labels = Arrays.asList("label1", "label2", ..., "labelN");
+ *     DataReader train = new DataReader("data-dir/train", labels, embedder, batchSize, maxSeqLen);
+ *     DataReader test = new DataReader("data-dir/test", labels, embedder, batchSize, maxSeqLen);
+ * </pre>
+ *
+ * @see GlobalVectors
+ * @see GloveRNNTextClassifier
+ * <br/>
+ * Contributors: Thamme Gowda (thammegowda@apache.org)
+ *
+ */
+public class DataReader implements DataSetIterator {
+
+    private static final Logger LOG = LoggerFactory.getLogger(DataReader.class);
+
+    private File dataDir;
+    private List<File> records;
+    private List<Integer> labels;
+    private Map<String, Integer> labelToId;
+    private String extension = ".txt";
+    private GlobalVectors embedder;
+    private int cursor = 0;
+    private int batchSize;
+    private int vectorLen;
+    private int maxSeqLen;
+    private int numLabels;
+    // default tokenizer: lowercases and splits on single spaces; override via setTokenizer(...)
+    private Function<String, String[]> tokenizer = s -> s.toLowerCase().split(" ");
+
+
+    /**
+     * Creates a reader with the specified arguments
+     * @param dataDirPath data directory
+     * @param labelNames list of labels (names should match subdirectory names)
+     * @param embedder embeddings to convert words to vectors
+     * @param batchSize mini-batch size for DL4J training
+     * @param maxSeqLength truncate sequences that are longer than this.
+     *                    If truncation is not desired, set {@code Integer.MAX_VALUE}
+     */
+    DataReader(String dataDirPath, List<String> labelNames, GlobalVectors embedder,
+               int batchSize, int maxSeqLength){
+        this.batchSize = batchSize;
+        this.embedder = embedder;
+        this.maxSeqLen = maxSeqLength;
+        this.vectorLen = embedder.getVectorSize();
+        this.numLabels = labelNames.size();
+        this.dataDir = new File(dataDirPath);
+        this.labelToId = new HashMap<>();
+        for (int i = 0; i < labelNames.size(); i++) {
+            labelToId.put(labelNames.get(i), i);
+        }
+        this.labelToId = Collections.unmodifiableMap(this.labelToId);
+        this.scanDir();
+        this.reset();
+    }
+
+    private void scanDir(){
+        assert dataDir.exists();
+        List<Integer> labels = new ArrayList<>();
+        List<File> files = new ArrayList<>();
+        for (String labelName: this.labelToId.keySet()) {
+            Integer labelId = this.labelToId.get(labelName);
+            assert labelId != null;
+            File labelDir = new File(dataDir, labelName);
+            if (!labelDir.exists()){
+                throw new IllegalStateException("Label directory not found for "
+                        + labelName + ". Looked at: " + labelDir);
+            }
+            File[] examples = labelDir.listFiles(f ->
+                    f.isFile() && f.getName().endsWith(this.extension));
+            if (examples == null || examples.length == 0){
+                throw new IllegalStateException("No examples found for "
+                        + labelName + ". Looked at: " + labelDir
+                        + " for files having extension: " + extension);
+            }
+            LOG.info("Found {} examples for label {}", examples.length, labelName);
+            for (File example: examples) {
+                files.add(example);
+                labels.add(labelId);
+            }
+        }
+        this.records = files;
+        this.labels = labels;
+    }
+
+    /**
+     * Sets the tokenizer for converting text into tokens
+     * @param tokenizer tokenizer to use for converting text to tokens
+     */
+    public void setTokenizer(Function<String, String[]> tokenizer) {
+        this.tokenizer = tokenizer;
+    }
+
+    /**
+     * @return Tokenizer function used for converting text into words
+     */
+    public Function<String, String[]> getTokenizer() {
+        return tokenizer;
+    }
+
+    @Override
+    public DataSet next(int batchSize) {
+        batchSize = Math.min(batchSize, records.size() - cursor);
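+        // shapes: features [batchSize, vectorLen, maxSeqLen], labels [batchSize, numLabels, maxSeqLen],
+        // the 3d layout DL4J recurrent layers expect ([miniBatch, inputSize, timeSeriesLength])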
+        INDArray features = Nd4j.create(batchSize, vectorLen, maxSeqLen);
+        INDArray labels = Nd4j.create(batchSize, numLabels, maxSeqLen);
+
+        //Because we are dealing with text of different lengths and only one output at the final time step: use padding and mask arrays
+        //Mask arrays contain 1 if data is present at that time step for that example, or 0 if data is just padding
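+        // e.g. with maxSeqLen=4, a 2-token example gets featuresMask row [1, 1, 0, 0]
+        // and labelsMask row [0, 1, 0, 0]: the label is emitted only at the last real time step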
+        INDArray featuresMask = Nd4j.zeros(batchSize, maxSeqLen);
+        INDArray labelsMask = Nd4j.zeros(batchSize, maxSeqLen);
+
+        // Optimizations to speed up this code block by reusing memory
+        int _2dIndex[] = new int[2];
+        int _3dIndex[] = new int[3];
+        INDArrayIndex _3dNdIndex[] = new INDArrayIndex[]{null, NDArrayIndex.all(), null};
+
+        for (int i = 0; i < batchSize && cursor < records.size(); i++, cursor++) {
+            _2dIndex[0] = i;
+            _3dIndex[0] = i;
+            _3dNdIndex[0] = NDArrayIndex.point(i);
+
+            try {
+                // Read
+                File file = records.get(cursor);
+                int labelIdx = this.labels.get(cursor);
+                String text = FileUtils.readFileToString(file);
+                // Tokenize and Filter
+                String[] tokens = tokenizer.apply(text);
+                tokens = Arrays.stream(tokens).filter(embedder::hasWord).toArray(String[]::new);
+                //Get word vectors for each word in review, and put them in the training data
+                int j;
+                for(j = 0; j < tokens.length && j < maxSeqLen; j++ ){
+                    String token = tokens[j];
+                    INDArray vector = embedder.toVector(token);
+                    _3dNdIndex[2] = NDArrayIndex.point(j);
+                    features.put(_3dNdIndex, vector);
+                    //Word is present (not padding) for this example + time step -> 1.0 in features mask
+                    _2dIndex[1] = j;
+                    featuresMask.putScalar(_2dIndex, 1.0);
+                }
+                if (j == 0) {
+                    // no in-vocabulary tokens in this example; leave the row masked out
+                    continue;
+                }
+                int lastIdx = j - 1;
+                _2dIndex[1] = lastIdx;
+                _3dIndex[1] = labelIdx;
+                _3dIndex[2] = lastIdx;
+
+                labels.putScalar(_3dIndex,1.0);   //Set label: one of k encoding
+                // Specify that an output exists at the final time step for this example
+                labelsMask.putScalar(_2dIndex,1.0);
+            } catch (IOException e) {
+                throw new RuntimeException(e);
+            }
+        }
+        //LOG.info("Cursor = {} || Init Time = {}, Read time = {}, preprocess Time = {}, Mask Time={}", cursor, initTime, readTime, preProcTime, maskTime);
+        return new DataSet(features, labels, featuresMask, labelsMask);
+    }
+
+    @Override
+    public int totalExamples() {
+        return this.records.size();
+    }
+
+    @Override
+    public int inputColumns() {
+        return this.embedder.getVectorSize();
+    }
+
+    @Override
+    public int totalOutcomes() {
+        return this.numLabels;
+    }
+
+    @Override
+    public boolean resetSupported() {
+        return true;
+    }
+
+    @Override
+    public boolean asyncSupported() {
+        return false;
+    }
+
+    @Override
+    public void reset() {
+        assert this.records.size() == this.labels.size();
+        long seed = System.nanoTime(); // shuffle both the lists in the same order
+        Collections.shuffle(this.records, new Random(seed));
+        Collections.shuffle(this.labels, new Random(seed));
+        this.cursor = 0; // from beginning
+    }
+
+    @Override
+    public int batch() {
+        return this.batchSize;
+    }
+
+    @Override
+    public int cursor() {
+        return this.cursor;
+    }
+
+    @Override
+    public int numExamples() {
+        return totalExamples();
+    }
+
+    @Override
+    public void setPreProcessor(DataSetPreProcessor preProcessor) {
+        throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public DataSetPreProcessor getPreProcessor() {
+        throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public List<String> getLabels() {
+        return new ArrayList<>(this.labelToId.keySet());
+    }
+
+    @Override
+    public boolean hasNext() {
+        return cursor < totalExamples();
+    }
+
+    @Override
+    public DataSet next() {
+        return next(this.batchSize);
+    }
+}
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/GlobalVectors.java b/opennlp-dl/src/main/java/opennlp/tools/dl/GlobalVectors.java
new file mode 100644
index 0000000..bda4531
--- /dev/null
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/GlobalVectors.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package opennlp.tools.dl;
+
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.DataSet;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.INDArrayIndex;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.BufferedReader;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * GlobalVectors (GloVe) for projecting words into vector space.
+ * This tool utilizes word vectors pre-trained on large datasets.
+ *
+ * Visit https://nlp.stanford.edu/projects/glove/ for the full documentation of GloVe.
+ *
+ * <h2>Usage</h2>
+ * <pre>
+ * String path = "work/datasets/glove.6B/glove.6B.100d.txt";
+ * int vocabSize = 20000; // max number of words to use
+ * GlobalVectors glove;
+ * try (InputStream stream = new FileInputStream(path)) {
+ *    glove = new GlobalVectors(stream, vocabSize);
+ * }
+ * </pre>
+ *
+ */
+public class GlobalVectors {
+
+    private static final Logger LOG = LoggerFactory.getLogger(GlobalVectors.class);
+
+    private final INDArray embeddings;
+    private final Map<String, Integer> wordToId;
+    private final List<String> idToWord;
+    private final int vectorSize;
+    private final int maxWords;
+
+    /**
+     * Reads Global Vectors from stream
+     * @param stream Glove word vectors stream (plain text)
+     * @throws IOException
+     */
+    public GlobalVectors(InputStream stream) throws IOException {
+        this(stream, Integer.MAX_VALUE);
+    }
+
+    /**
+     *
+     * @param stream vector stream
+     * @param maxWords maximum number of words to use, i.e. vocabulary size
+     * @throws IOException
+     */
+    public GlobalVectors(InputStream stream, int maxWords) throws IOException {
+        List<String> words = new ArrayList<>();
+        List<INDArray> vectors = new ArrayList<>();
+        int vectorSize = -1;
+        try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream))){
+            String line;
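+            // each line is expected to be: word v1 v2 ... vn (space separated)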
+            while ((line = reader.readLine()) != null) {
+                String[] parts = line.split(" ");
+                if (vectorSize == -1) {
+                    vectorSize = parts.length - 1;
+                } else {
+                    assert vectorSize == parts.length - 1;
+                }
+                float[] vector = new float[vectorSize];
+                for (int i = 1; i < parts.length; i++) {
+                    vector[i-1] = Float.parseFloat(parts[i]);
+                }
+                vectors.add(Nd4j.create(vector));
+                words.add(parts[0]);
+                if (words.size() >= maxWords) {
+                    LOG.info("Max words reached at {}, aborting", words.size());
+                    break;
+                }
+            }
+            LOG.info("Found {} words; Vector dimensions={}", words.size(), vectorSize);
+            this.vectorSize = vectorSize;
+            this.maxWords = Math.min(words.size(), maxWords);
+            this.embeddings = Nd4j.create(vectors, new int[]{vectors.size(), vectorSize});
+            this.idToWord = words;
+            this.wordToId = new HashMap<>();
+            for (int i = 0; i < words.size(); i++) {
+                wordToId.put(words.get(i), i);
+            }
+        }
+    }
+
+    /**
+     * @return dimensionality of the word vectors
+     */
+    public int getVectorSize() {
+        return vectorSize;
+    }
+
+    public int getMaxWords() {
+        return maxWords;
+    }
+
+    /**
+     * @param word word to look up
+     * @return {@code true} if the word is known; {@code false} otherwise
+     */
+    public boolean hasWord(String word){
+        return wordToId.containsKey(word);
+    }
+
+    /**
+     * Converts a word to its vector
+     * @param word word to be converted to vector
+     * @return the vector if the word exists, or null otherwise
+     */
+    public INDArray toVector(String word){
+        if (wordToId.containsKey(word)){
+            return embeddings.getRow(wordToId.get(word));
+        }
+        return null;
+    }
+
+    public INDArray embed(String text, int maxLen){
+        return embed(text.toLowerCase().split(" "), maxLen);
+    }
+
+    public INDArray embed(String[] tokens, int maxLen){
+        List<String> tokensFiltered = new ArrayList<>();
+        for(String t: tokens ){
+            if(hasWord(t)){
+                tokensFiltered.add(t);
+            }
+        }
+        int seqLen = Math.min(maxLen, tokensFiltered.size());
+
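+        // a single-example time series of shape [1, vectorSize, seqLen];
+        // assumes at least one token is in the GloVe vocabulary (seqLen > 0)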
+        INDArray features = Nd4j.create(1, vectorSize, seqLen);
+
+        for( int j = 0; j < seqLen; j++ ){
+            String token = tokensFiltered.get(j);
+            INDArray vector = toVector(token);
+            features.put(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(j)}, vector);
+        }
+        return features;
+    }
+}
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/GloveRNNTextClassifier.java b/opennlp-dl/src/main/java/opennlp/tools/dl/GloveRNNTextClassifier.java
new file mode 100644
index 0000000..dd7bcc5
--- /dev/null
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/GloveRNNTextClassifier.java
@@ -0,0 +1,334 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package opennlp.tools.dl;
+
+import org.apache.commons.io.FileUtils;
+import org.deeplearning4j.eval.Evaluation;
+import org.deeplearning4j.nn.conf.GradientNormalization;
+import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.Updater;
+import org.deeplearning4j.nn.conf.layers.GravesLSTM;
+import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
+import org.deeplearning4j.util.ModelSerializer;
+import org.kohsuke.args4j.CmdLineException;
+import org.kohsuke.args4j.CmdLineParser;
+import org.kohsuke.args4j.Option;
+import org.kohsuke.args4j.spi.StringArrayOptionHandler;
+import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.DataSet;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+import org.nd4j.linalg.lossfunctions.LossFunctions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * This is a multi-class text classifier that uses GloVe embeddings to vectorize the text
+ * and an LSTM RNN to classify the sequence of vectors.
+ *
+ * This class aims to be a general-purpose text classification tool.
+ * A common use case would be tuning it for the text sentiment classification task.
+ *
+ * The GloVe vectors can be downloaded from https://nlp.stanford.edu/projects/glove/
+ *
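+ * <h2>Usage</h2>
+ * A minimal programmatic sketch; the values below are placeholders, and in practice
+ * the fields of {@link Args} are filled in by the CLI parser in {@link #main(String[])}:
+ * <pre>
+ *   Args args = new Args();
+ *   args.glovesPath = "glove.6B/glove.6B.100d.txt";
+ *   args.modelPath = "model.dat";
+ *   args.labels = Arrays.asList("pos", "neg");
+ *   args.trainDir = "aclImdb/train";  // omit to restore a saved model from modelPath
+ *   GloveRNNTextClassifier classifier = new GloveRNNTextClassifier(args);
+ *   classifier.train();
+ *   classifier.saveModel(args.modelPath);
+ *   double[] probs = classifier.predict("some text to classify");
+ * </pre>
+ *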
+ * <br/>
+ */
+public class GloveRNNTextClassifier {
+    private static final Logger LOG = LoggerFactory.getLogger(GloveRNNTextClassifier.class);
+
+    private MultiLayerNetwork model;
+    private GlobalVectors gloves;
+    private DataReader trainSet;
+    private DataReader validSet;
+
+    private Args args;
+
+    public GloveRNNTextClassifier(Args args) throws IOException {
+        this.init(args);
+    }
+
+    public static class Args {
+
+        @Option(name="-batchSize", depends = {"-trainDir"},
+                usage = "Number of examples in minibatch. Applicable for training only.")
+        int batchSize = 128;
+
+        @Option(name="-nEpochs", depends = {"-trainDir"},
+                usage = "Number of epochs (i.e. full passes over the training data) to train on." +
+                " Applicable for training only.")
+        int nEpochs = 2;
+
+        @Option(name="-maxSeqLen", usage = "Max Sequence Length. Sequences longer than this will be truncated")
+        int maxSeqLen = 256;    //Truncate text with length (# words) greater than this
+
+        @Option(name="-vocabSize", usage = "Vocabulary Size.")
+        int vocabSize = 20000;   //vocabulary size
+
+        @Option(name="-nRNNUnits", depends = {"-trainDir"},
+                usage = "Number of RNN cells to use. Applicable for training only.")
+        int nRNNUnits = 128;
+
+        @Option(name="-lr", aliases = "-learnRate", usage = "Learning Rate." +
+                " Adjust it when the scores bounce to NaN or Infinity.")
+        double learningRate = 2e-3;
+
+        @Option(name="-glovesPath", required = true, usage = "Path to GloVe vectors file." +
+                " Download and unzip from https://nlp.stanford.edu/projects/glove/")
+        String glovesPath = null;
+
+        @Option(name="-modelPath", required = true, usage = "Path to model file. " +
+                "This will be used for serializing the model after the training phase." +
+                "and also the model will be restored from here for prediction")
+        String modelPath = null;
+
+        @Option(name="-trainDir", usage = "Path to train data directory. Optional." +
+                " Setting this value will take the system to training mode. ")
+        String trainDir = null;
+
+        @Option(name="-validDir", depends = {"-trainDir"}, usage = "Path to validation data directory. Optional." +
+                " Applicable only when -trainDir is set.")
+        String validDir = null;
+
+        @Option(name="-labels", required = true, handler = StringArrayOptionHandler.class,
+                usage = "Names of targets or labels separated by spaces. " +
+                "The order of labels matters. Make sure to use the same sequence for training and predicting. " +
+                "Also, these names should match subdirectory names of -trainDir and -validDir when those are " +
+                        "applicable. \n Example -labels pos neg")
+        List<String> labels = null;
+
+        @Option(name="-files", handler = StringArrayOptionHandler.class,
+                usage = "File paths (separated by space) to predict using the model.")
+        List<String> filePaths = null;
+
+        @Override
+        public String toString() {
+            return "Args{" +
+                    "batchSize=" + batchSize +
+                    ", nEpochs=" + nEpochs +
+                    ", maxSeqLen=" + maxSeqLen +
+                    ", vocabSize=" + vocabSize +
+                    ", learningRate=" + learningRate +
+                    ", nRNNUnits=" + nRNNUnits +
+                    ", glovesPath='" + glovesPath + '\'' +
+                    ", modelPath='" + modelPath + '\'' +
+                    ", trainDir='" + trainDir + '\'' +
+                    ", validDir='" + validDir + '\'' +
+                    ", labels=" + labels +
+                    '}';
+        }
+    }
+
+
+    public void init(Args args) throws IOException {
+        this.args = args;
+        try (InputStream stream = new FileInputStream(args.glovesPath)) {
+            this.gloves = new GlobalVectors(stream, args.vocabSize);
+        }
+
+        if (args.trainDir != null) {
+            LOG.info("Training data from {}", args.trainDir);
+            this.trainSet = new DataReader(args.trainDir, args.labels, this.gloves, args.batchSize, args.maxSeqLen);
+            if (args.validDir != null) {
+                LOG.info("Validation data from {}", args.validDir);
+                this.validSet = new DataReader(args.validDir, args.labels, this.gloves, args.batchSize, args.maxSeqLen);
+            }
+            //create model
+            this.model = this.createModel();
+            // ready for training
+        } else {
+            //restore model
+            LOG.info("Training data not set => Going to restore model from {}", args.modelPath);
+            this.model = ModelSerializer.restoreMultiLayerNetwork(args.modelPath);
+            //ready for prediction
+        }
+    }
+
+    private MultiLayerNetwork createModel(){
+        int totalOutcomes = this.trainSet.totalOutcomes();
+        assert totalOutcomes >= 2;
+        LOG.info("Number of classes " + totalOutcomes);
+
+        //TODO: the below network params should be configurable from CLI or settings file
+        //Set up network configuration
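+        // GloVe vector (size nIn) -> GravesLSTM (nRNNUnits) -> softmax over labels at every time step;
+        // the label masks from DataReader restrict the loss to the final time step of each sequence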
+        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                .updater(Updater.RMSPROP) // ADAM .adamMeanDecay(0.9).adamVarDecay(0.999)
+                .rmsDecay(0.9)
+                .regularization(true).l2(1e-5)
+                .weightInit(WeightInit.XAVIER)
+                .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
+                .gradientNormalizationThreshold(1.0)
+                .learningRate(args.learningRate)
+                .list()
+                .layer(0, new GravesLSTM.Builder()
+                        .nIn(gloves.getVectorSize())
+                        .nOut(args.nRNNUnits)
+                        .activation(Activation.RELU).build())
+                .layer(1, new RnnOutputLayer.Builder()
+                        .nIn(args.nRNNUnits)
+                        .nOut(totalOutcomes)
+                        .activation(Activation.SOFTMAX)
+                        .lossFunction(LossFunctions.LossFunction.MCXENT)
+                        .build())
+                .pretrain(false)
+                .backprop(true)
+                .build();
+
+        MultiLayerNetwork net = new MultiLayerNetwork(conf);
+        net.init();
+        net.setListeners(new ScoreIterationListener(1));
+        return net;
+    }
+
+    public void train(){
+        train(args.nEpochs, this.trainSet, this.validSet);
+    }
+
+    /**
+     * Trains model
+     * @param nEpochs number of epochs (i.e. iterations over the training dataset)
+     * @param train training data set
+     * @param validation validation data set for evaluation after each epoch.
+     *                  Setting this to null will skip the evaluation
+     */
+    public void train(int nEpochs, DataReader train, DataReader validation){
+        assert model != null;
+        assert train != null;
+        LOG.info("Starting training...\nTotal epochs={}, Training Size={}, Validation Size={}", nEpochs,
+                train.totalExamples(), validation == null ? null : validation.totalExamples());
+        for (int i = 0; i < nEpochs; i++) {
+            model.fit(train);
+            train.reset();
+            LOG.info("Epoch {} complete", i);
+
+            if (validation != null) {
+                LOG.info("Starting evaluation");
+                //Run evaluation on the validation set; this can take some time
+                Evaluation evaluation = new Evaluation();
+                while (validation.hasNext()) {
+                    DataSet t = validation.next();
+                    INDArray features = t.getFeatureMatrix();
+                    INDArray labels = t.getLabels();
+                    INDArray inMask = t.getFeaturesMaskArray();
+                    INDArray outMask = t.getLabelsMaskArray();
+                    INDArray predicted = model.output(features, false, inMask, outMask);
+                    evaluation.evalTimeSeries(labels, predicted, outMask);
+                }
+                validation.reset();
+                LOG.info(evaluation.stats());
+            }
+        }
+    }
+
+    /**
+     * Saves the model to specified path
+     * @param path model path
+     * @throws IOException
+     */
+    public void saveModel(String path) throws IOException {
+        assert model != null;
+        LOG.info("Saving the model at {}", path);
+        ModelSerializer.writeModel(model, path, true);
+    }
+
+    /**
+     * Predicts class probabilities for the text based on the current state of the model
+     * @param text text to be classified
+     * @return array of probabilities whose indices correspond to the indices of <code>this.args.labels</code>
+     */
+    public double[] predict(String text){
+
+        INDArray seqFeatures = this.gloves.embed(text, this.args.maxSeqLen);
+        INDArray networkOutput = this.model.output(seqFeatures);
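+        // networkOutput has shape [1, numLabels, seqLen]; read the probabilities at the final time step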
+        int timeSeriesLength = networkOutput.size(2);
+        INDArray probsAtLastWord = networkOutput.get(NDArrayIndex.point(0),
+                NDArrayIndex.all(), NDArrayIndex.point(timeSeriesLength - 1));
+
+        double[] probs = new double[args.labels.size()];
+        for (int i = 0; i < args.labels.size(); i++) {
+            probs[i] = probsAtLastWord.getDouble(i);
+        }
+        return probs;
+    }
+
+    /**
+     * <pre>
+     *   # Download pre-trained GloVe vectors (this is a large file)
+     *   wget http://nlp.stanford.edu/data/glove.6B.zip
+     *   unzip glove.6B.zip -d glove.6B
+     *
+     *   # Download dataset
+     *   wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
+     *   tar xzf aclImdb_v1.tar.gz
+     *
+     *  mvn compile exec:java
+     *    -Dexec.mainClass=opennlp.tools.dl.GloveRNNTextClassifier
+     *    -Dexec.args="-glovesPath $HOME/work/datasets/glove.6B/glove.6B.100d.txt
+     *    -labels pos neg -modelPath imdb-sentimodel.dat
+     *    -trainDir $HOME/work/datasets/aclImdb/train -lr 0.001"
+     *
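+     *  # classify new files with the trained model (file names below are placeholders)
+     *  mvn compile exec:java
+     *    -Dexec.mainClass=opennlp.tools.dl.GloveRNNTextClassifier
+     *    -Dexec.args="-glovesPath $HOME/work/datasets/glove.6B/glove.6B.100d.txt
+     *    -labels pos neg -modelPath imdb-sentimodel.dat
+     *    -files review1.txt review2.txt"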
+     * </pre>
+     *
+     */
+    public static void main(String[] argss) throws CmdLineException, IOException {
+        Args args = new Args();
+        CmdLineParser parser = new CmdLineParser(args);
+        try {
+            parser.parseArgument(argss);
+        } catch (CmdLineException e) {
+            System.out.println(e.getMessage());
+            e.getParser().printUsage(System.out);
+            System.exit(1);
+        }
+        GloveRNNTextClassifier classifier = new GloveRNNTextClassifier(args);
+        byte numOps = 0;
+        if (classifier.trainSet != null) {
+            numOps++;
+            classifier.train();
+            classifier.saveModel(args.modelPath);
+        }
+
+        if (args.filePaths != null && !args.filePaths.isEmpty()) {
+            numOps++;
+            System.out.println("Labels:" + args.labels);
+            for (String filePath: args.filePaths) {
+                File file = new File(filePath);
+                String text = FileUtils.readFileToString(file);
+                double[] probs = classifier.predict(text);
+                System.out.println(">>" + filePath);
+                System.out.println("Probabilities:" + Arrays.toString(probs));
+            }
+        }
+
+        if (numOps == 0) {
+            System.out.println("Provide -trainDir to train a model, -files to classify files");
+            System.exit(2);
+        }
+    }
+}