text sequence classification using Glove and RNN/LSTMs
diff --git a/.gitignore b/.gitignore
index fe06e66..126d4a6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,3 +5,6 @@
nbactions.xml
nb-configuration.xml
*.DS_Store
+
+.idea
+*.iml
diff --git a/opennlp-dl/pom.xml b/opennlp-dl/pom.xml
index 3d15d8f..9899efe 100644
--- a/opennlp-dl/pom.xml
+++ b/opennlp-dl/pom.xml
@@ -26,7 +26,7 @@
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
- <nd4j.version>0.7.2</nd4j.version>
+ <nd4j.version>0.8.0</nd4j.version>
</properties>
<dependencies>
@@ -41,8 +41,6 @@
<artifactId>deeplearning4j-core</artifactId>
<version>${nd4j.version}</version>
</dependency>
-
-
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-nlp</artifactId>
@@ -64,6 +62,11 @@
<artifactId>nd4j-native-platform</artifactId>
<version>${nd4j.version}</version>
</dependency>
+ <dependency>
+ <groupId>args4j</groupId>
+ <artifactId>args4j</artifactId>
+ <version>2.33</version>
+ </dependency>
</dependencies>
<build>
<plugins>
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/DataReader.java b/opennlp-dl/src/main/java/opennlp/tools/dl/DataReader.java
new file mode 100644
index 0000000..825cf53
--- /dev/null
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/DataReader.java
@@ -0,0 +1,308 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package opennlp.tools.dl;
+
+import org.apache.commons.io.FileUtils;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.DataSet;
+import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
+import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.INDArrayIndex;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Function;
+
+/**
+ * This class provides a reader capable of reading training and test datasets from file system for text classifiers.
+ * In addition to reading the content, it
+ * (1) vectorizes the text using embeddings such as Glove, and
+ * (2) divides the datasets into mini batches of specified size.
+ *
+ * The data is expected to be organized as per the following convention:
+ * <pre>
+ * data-dir/
+ * +- label1 /
+ * | +- example11.txt
+ * | +- example12.txt
+ * | +- example13.txt
+ * | +- .....
+ * +- label2 /
+ * | +- example21.txt
+ * | +- .....
+ * +- labelN /
+ * +- exampleN1.txt
+ * +- .....
+ * </pre>
+ *
+ * In addition, the dataset shall be divided into training and testing as follows:
+ * <pre>
+ * data-dir/
+ * + train/
+ * | +- label1 /
+ * | +- labelN /
+ * + test /
+ * +- label1 /
+ * +- labelN /
+ * </pre>
+ *
+ * <h2>Usage: </h2>
+ * <code>
+ * // label names should match the subdirectory names
+ * labels = Arrays.asList("label1", "label2", ..."labelN");
+ * train = DataReader('data-dir/train', labels, embeds, ....);
+ * test = DataReader('data-dir/test', labels, embeds, ....)
+ * </code>
+ *
+ * @see GlobalVectors
+ * @see GloveRNNTextClassifier
+ * <br/>
+ * Contributors: Thamme Gowda (thammegowda@apache.org)
+ *
+ */
+public class DataReader implements DataSetIterator {
+
+ private static final Logger LOG = LoggerFactory.getLogger(DataReader.class);
+
+ private File dataDir;
+ private List<File> records;
+ private List<Integer> labels;
+ private Map<String, Integer> labelToId;
+ private String extension = ".txt";
+ private GlobalVectors embedder;
+ private int cursor = 0;
+ private int batchSize;
+ private int vectorLen;
+ private int maxSeqLen;
+ private int numLabels;
+ // default tokenizer
+ private Function<String, String[]> tokenizer = s -> s.toLowerCase().split(" ");
+
+
+ /**
+ * Creates a reader with the specified arguments
+ * @param dataDirPath data directory
+ * @param labelNames list of labels (names should match sub directory names)
+ * @param embedder embeddings to convert words to vectors
+ * @param batchSize mini batch size for DL4j training
+ * @param maxSeqLength truncate sequences that are longer than this.
+ * If truncation is not desired, set {@code Integer.MAX_VAL}
+ */
+ DataReader(String dataDirPath, List<String> labelNames, GlobalVectors embedder,
+ int batchSize, int maxSeqLength){
+ this.batchSize = batchSize;
+ this.embedder = embedder;
+ this.maxSeqLen = maxSeqLength;
+ this.vectorLen = embedder.getVectorSize();
+ this.numLabels = labelNames.size();
+ this.dataDir = new File(dataDirPath);
+ this.labelToId = new HashMap<>();
+ for (int i = 0; i < labelNames.size(); i++) {
+ labelToId.put(labelNames.get(i), i);
+ }
+ this.labelToId = Collections.unmodifiableMap(this.labelToId);
+ this.scanDir();
+ this.reset();
+ }
+
+ private void scanDir(){
+ assert dataDir.exists();
+ List<Integer> labels = new ArrayList<>();
+ List<File> files = new ArrayList<>();
+ for (String labelName: this.labelToId.keySet()) {
+ Integer labelId = this.labelToId.get(labelName);
+ assert labelId != null;
+ File labelDir = new File(dataDir, labelName);
+ if (!labelDir.exists()){
+ throw new IllegalStateException("No examples found for "
+ + labelName + ". Looked at:" + labelDir);
+ }
+ File[] examples = labelDir.listFiles(f ->
+ f.isFile() && f.getName().endsWith(this.extension));
+ if (examples == null || examples.length == 0){
+ throw new IllegalStateException("No examples found for "
+ + labelName + ". Looked at:" + labelDir
+ + " for files having extension: \" + extension");
+ }
+ LOG.info("Found {} examples for label {}", examples.length, labelName);
+ for (File example: examples) {
+ files.add(example);
+ labels.add(labelId);
+ }
+ }
+ this.records = files;
+ this.labels = labels;
+ }
+
+ /**
+ * sets tokenizer for converting text to tokens
+ * @param tokenizer tokenizer to use for converting text to tokens
+ */
+ public void setTokenizer(Function<String, String[]> tokenizer) {
+ this.tokenizer = tokenizer;
+ }
+
+ /**
+ * @return Tokenizer function used for converting text into words
+ */
+ public Function<String, String[]> getTokenizer() {
+ return tokenizer;
+ }
+
+ @Override
+ public DataSet next(int batchSize) {
+ batchSize = Math.min(batchSize, records.size() - cursor);
+ INDArray features = Nd4j.create(batchSize, vectorLen, maxSeqLen);
+ INDArray labels = Nd4j.create(batchSize, numLabels, maxSeqLen);
+
+ //Because we are dealing with text of different lengths and only one output at the final time step: use padding arrays
+ //Mask arrays contain 1 if data is present at that time step for that example, or 0 if data is just padding
+ INDArray featuresMask = Nd4j.zeros(batchSize, maxSeqLen);
+ INDArray labelsMask = Nd4j.zeros(batchSize, maxSeqLen);
+
+ // Optimizations to speed up this code block by reusing memory
+ int _2dIndex[] = new int[2];
+ int _3dIndex[] = new int[3];
+ INDArrayIndex _3dNdIndex[] = new INDArrayIndex[]{null, NDArrayIndex.all(), null};
+
+ for (int i = 0; i < batchSize && cursor < records.size(); i++, cursor++) {
+ _2dIndex[0] = i;
+ _3dIndex[0] = i;
+ _3dNdIndex[0] = NDArrayIndex.point(i);
+
+ try {
+ // Read
+ File file = records.get(cursor);
+ int labelIdx = this.labels.get(cursor);
+ String text = FileUtils.readFileToString(file);
+ // Tokenize and Filter
+ String[] tokens = tokenizer.apply(text);
+ tokens = Arrays.stream(tokens).filter(embedder::hasWord).toArray(String[]::new);
+ //Get word vectors for each word in review, and put them in the training data
+ int j;
+ for(j = 0; j < tokens.length && j < maxSeqLen; j++ ){
+ String token = tokens[j];
+ INDArray vector = embedder.toVector(token);
+ _3dNdIndex[2] = NDArrayIndex.point(j);
+ features.put(_3dNdIndex, vector);
+ //Word is present (not padding) for this example + time step -> 1.0 in features mask
+ _2dIndex[1] = j;
+ featuresMask.putScalar(_2dIndex, 1.0);
+ }
+ int lastIdx = j - 1;
+ _2dIndex[1] = lastIdx;
+ _3dIndex[1] = labelIdx;
+ _3dIndex[2] = lastIdx;
+
+ labels.putScalar(_3dIndex,1.0); //Set label: one of k encoding
+ // Specify that an output exists at the final time step for this example
+ labelsMask.putScalar(_2dIndex,1.0);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ //LOG.info("Cursor = {} || Init Time = {}, Read time = {}, preprocess Time = {}, Mask Time={}", cursor, initTime, readTime, preProcTime, maskTime);
+ return new DataSet(features, labels, featuresMask, labelsMask);
+ }
+
+ @Override
+ public int totalExamples() {
+ return this.records.size();
+ }
+
+ @Override
+ public int inputColumns() {
+ return this.embedder.getVectorSize();
+ }
+
+ @Override
+ public int totalOutcomes() {
+ return this.numLabels;
+ }
+
+ @Override
+ public boolean resetSupported() {
+ return true;
+ }
+
+ @Override
+ public boolean asyncSupported() {
+ return false;
+ }
+
+ @Override
+ public void reset() {
+ assert this.records.size() == this.labels.size();
+ long seed = System.nanoTime(); // shuffle both the lists in the same order
+ Collections.shuffle(this.records, new Random(seed));
+ Collections.shuffle(this.labels, new Random(seed));
+ this.cursor = 0; // from beginning
+ }
+
+ @Override
+ public int batch() {
+ return this.batchSize;
+ }
+
+ @Override
+ public int cursor() {
+ return this.cursor;
+ }
+
+ @Override
+ public int numExamples() {
+ return totalExamples();
+ }
+
+ @Override
+ public void setPreProcessor(DataSetPreProcessor preProcessor) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public DataSetPreProcessor getPreProcessor() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public List<String> getLabels() {
+ return new ArrayList<>(this.labelToId.keySet());
+ }
+
+ @Override
+ public boolean hasNext() {
+ return cursor < totalExamples() - 1;
+ }
+
+ @Override
+ public DataSet next() {
+ return next(this.batchSize);
+ }
+}
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/GlobalVectors.java b/opennlp-dl/src/main/java/opennlp/tools/dl/GlobalVectors.java
new file mode 100644
index 0000000..bda4531
--- /dev/null
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/GlobalVectors.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package opennlp.tools.dl;
+
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.DataSet;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.INDArrayIndex;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.BufferedReader;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * GlobalVectors (Glove) for projecting words to vector space.
+ * This tool utilizes word vectors pre-trained on large datasets.
+ *
+ * Visit https://nlp.stanford.edu/projects/glove/ for full documentation of Gloves.
+ *
+ * <h2>Usage</h2>
+ * <pre>
+ * path = "work/datasets/glove.6B/glove.6B.100d.txt";
+ * vocabSize = 20000; # max number of words to use
+ * GlobalVectors glove;
+ * try (InputStream stream = new FileInputStream(path)) {
+ * glove = new GlobalVectors(stream, vocabSize);
+ * }
+ * </pre>
+ *
+ */
+public class GlobalVectors {
+
+ private static final Logger LOG = LoggerFactory.getLogger(GlobalVectors.class);
+
+ private final INDArray embeddings;
+ private final Map<String, Integer> wordToId;
+ private final List<String> idToWord;
+ private final int vectorSize;
+ private final int maxWords;
+
+ /**
+ * Reads Global Vectors from stream
+ * @param stream Glove word vectors stream (plain text)
+ * @throws IOException
+ */
+ public GlobalVectors(InputStream stream) throws IOException {
+ this(stream, Integer.MAX_VALUE);
+ }
+
+ /**
+ *
+ * @param stream vector stream
+ * @param maxWords maximum number of words to use, i.e. vocabulary size
+ * @throws IOException
+ */
+ public GlobalVectors(InputStream stream, int maxWords) throws IOException {
+ List<String> words = new ArrayList<>();
+ List<INDArray> vectors = new ArrayList<>();
+ int vectorSize = -1;
+ try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream))){
+ String line;
+ while ((line = reader.readLine()) != null) {
+ String[] parts = line.split(" ");
+ if (vectorSize == -1) {
+ vectorSize = parts.length - 1;
+ } else {
+ assert vectorSize == parts.length - 1;
+ }
+ float[] vector = new float[vectorSize];
+ for (int i = 1; i < parts.length; i++) {
+ vector[i-1] = Float.parseFloat(parts[i]);
+ }
+ vectors.add(Nd4j.create(vector));
+ words.add(parts[0]);
+ if (words.size() >= maxWords) {
+ LOG.info("Max words reached at {}, aborting", words.size());
+ break;
+ }
+ }
+ LOG.info("Found {} words; Vector dimensions={}", words.size(), vectorSize);
+ this.vectorSize = vectorSize;
+ this.maxWords = Math.min(words.size(), maxWords);
+ this.embeddings = Nd4j.create(vectors, new int[]{vectors.size(), vectorSize});
+ this.idToWord = words;
+ this.wordToId = new HashMap<>();
+ for (int i = 0; i < words.size(); i++) {
+ wordToId.put(words.get(i), i);
+ }
+ }
+ }
+
+ /**
+ * @return size or dimensions of vectors
+ */
+ public int getVectorSize() {
+ return vectorSize;
+ }
+
+ public int getMaxWords() {
+ return maxWords;
+ }
+
+ /**
+ *
+ * @param word
+ * @return {@code true} if word is known; false otherwise
+ */
+ public boolean hasWord(String word){
+ return wordToId.containsKey(word);
+ }
+
+ /**
+ * Converts word to vectors
+ * @param word word to be converted to vector
+ * @return Vector if words exists or null otherwise
+ */
+ public INDArray toVector(String word){
+ if (wordToId.containsKey(word)){
+ return embeddings.getRow(wordToId.get(word));
+ }
+ return null;
+ }
+
+ public INDArray embed(String text, int maxLen){
+ return embed(text.toLowerCase().split(" "), maxLen);
+ }
+
+ public INDArray embed(String[] tokens, int maxLen){
+ List<String> tokensFiltered = new ArrayList<>();
+ for(String t: tokens ){
+ if(hasWord(t)){
+ tokensFiltered.add(t);
+ }
+ }
+ int seqLen = Math.min(maxLen, tokensFiltered.size());
+
+ INDArray features = Nd4j.create(1, vectorSize, seqLen);
+
+ for( int j = 0; j < seqLen; j++ ){
+ String token = tokensFiltered.get(j);
+ INDArray vector = toVector(token);
+ features.put(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(j)}, vector);
+ }
+ return features;
+ }
+}
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/GloveRNNTextClassifier.java b/opennlp-dl/src/main/java/opennlp/tools/dl/GloveRNNTextClassifier.java
new file mode 100644
index 0000000..dd7bcc5
--- /dev/null
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/GloveRNNTextClassifier.java
@@ -0,0 +1,334 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package opennlp.tools.dl;
+
+import org.apache.commons.io.FileUtils;
+import org.deeplearning4j.eval.Evaluation;
+import org.deeplearning4j.nn.conf.GradientNormalization;
+import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.Updater;
+import org.deeplearning4j.nn.conf.layers.GravesLSTM;
+import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
+import org.deeplearning4j.util.ModelSerializer;
+import org.kohsuke.args4j.CmdLineException;
+import org.kohsuke.args4j.CmdLineParser;
+import org.kohsuke.args4j.Option;
+import org.kohsuke.args4j.spi.StringArrayOptionHandler;
+import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.DataSet;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+import org.nd4j.linalg.lossfunctions.LossFunctions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * This is a Multi Class Text Classifier that uses Glove embeddings to vectorize the text
+ * and a LSTM RNN to classify the sequence of vectors.
+ *
+ * This class aimed to make a general purpose text classifier tool.
+ * A common use case would be to tune it for the text Sentiment classification task.
+ *
+ * The Glove Vectors can be downloaded from https://nlp.stanford.edu/projects/glove/
+ *
+ * <br/>
+ */
+public class GloveRNNTextClassifier {
+ private static final Logger LOG = LoggerFactory.getLogger(GloveRNNTextClassifier.class);
+
+ private MultiLayerNetwork model;
+ private GlobalVectors gloves;
+ private DataReader trainSet;
+ private DataReader validSet;
+
+ private Args args;
+
+ public GloveRNNTextClassifier(Args args) throws IOException {
+ this.init(args);
+ }
+
+ public static class Args {
+
+ @Option(name="-batchSize", depends = {"-trainDir"},
+ usage = "Number of examples in minibatch. Applicable for training only.")
+ int batchSize = 128;
+
+ @Option(name="-nEpochs", depends = {"-trainDir"},
+ usage = "Number of epochs (i.e. full passes over the training data) to train on." +
+ " Applicable for training only.")
+ int nEpochs = 2;
+
+ @Option(name="-maxSeqLen", usage = "Max Sequence Length. Sequences longer than this will be truncated")
+ int maxSeqLen = 256; //Truncate text with length (# words) greater than this
+
+ @Option(name="-vocabSize", usage = "Vocabulary Size.")
+ int vocabSize = 20000; //vocabulary size
+
+ @Option(name="-nRNNUnits", depends = {"-trainDir"},
+ usage = "Number of RNN cells to use. Applicable for training only.")
+ int nRNNUnits = 128;
+
+ @Option(name="-lr", aliases = "-learnRate", usage = "Learning Rate." +
+ " Adjust it when the scores bounce to NaN or Infinity.")
+ double learningRate = 2e-3;
+
+ @Option(name="-glovesPath", required = true, usage = "Path to GloVe vectors file." +
+ " Download and unzip from https://nlp.stanford.edu/projects/glove/")
+ String glovesPath = null;
+
+ @Option(name="-modelPath", required = true, usage = "Path to model file. " +
+ "This will be used for serializing the model after the training phase." +
+ "and also the model will be restored from here for prediction")
+ String modelPath = null;
+
+ @Option(name="-trainDir", usage = "Path to train data directory. Optional." +
+ " Setting this value will take the system to training mode. ")
+ String trainDir = null;
+
+ @Option(name="-validDir", depends = {"-trainDir"}, usage = "Path to validation data directory. Optional." +
+ " Applicable only when -trainDir is set.")
+ String validDir = null;
+
+ @Option(name="-labels", required = true, handler = StringArrayOptionHandler.class,
+ usage = "Names of targets or labels separated by spaces. " +
+ "The order of labels matters. Make sure to use the same sequence for training and predicting. " +
+ "Also, these names should match subdirectory names of -trainDir and -validDir when those are " +
+ "applicable. \n Example -labels pos neg")
+ List<String> labels = null;
+
+ @Option(name="-files", handler = StringArrayOptionHandler.class,
+ usage = "File paths (separated by space) to predict using the model.")
+ List<String> filePaths = null;
+
+ @Override
+ public String toString() {
+ return "Args{" +
+ "batchSize=" + batchSize +
+ ", nEpochs=" + nEpochs +
+ ", maxSeqLen=" + maxSeqLen +
+ ", vocabSize=" + vocabSize +
+ ", learningRate=" + learningRate +
+ ", nRNNUnits=" + nRNNUnits +
+ ", glovesPath='" + glovesPath + '\'' +
+ ", modelPath='" + modelPath + '\'' +
+ ", trainDir='" + trainDir + '\'' +
+ ", validDir='" + validDir + '\'' +
+ ", labels=" + labels +
+ '}';
+ }
+ }
+
+
+ public void init(Args args) throws IOException {
+ this.args = args;
+ try (InputStream stream = new FileInputStream(args.glovesPath)) {
+ this.gloves = new GlobalVectors(stream, args.vocabSize);
+ }
+
+ if (args.trainDir != null) {
+ LOG.info("Training data from {}", args.trainDir);
+ this.trainSet = new DataReader(args.trainDir, args.labels, this.gloves, args.batchSize, args.maxSeqLen);
+ if (args.validDir != null) {
+ LOG.info("Validation data from {}", args.validDir);
+ this.validSet = new DataReader(args.validDir, args.labels, this.gloves, args.batchSize, args.maxSeqLen);
+ }
+ //create model
+ this.model = this.createModel();
+ // ready for training
+ } else {
+ //restore model
+ LOG.info("Training data not set => Going to restore model from {}", args.modelPath);
+ this.model = ModelSerializer.restoreMultiLayerNetwork(args.modelPath);
+ //ready for prediction
+ }
+ }
+
+ private MultiLayerNetwork createModel(){
+ int totalOutcomes = this.trainSet.totalOutcomes();
+ assert totalOutcomes >= 2;
+ LOG.info("Number of classes " + totalOutcomes);
+
+ //TODO: the below network params should be configurable from CLI or settings file
+ //Set up network configuration
+ MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+ .updater(Updater.RMSPROP) // ADAM .adamMeanDecay(0.9).adamVarDecay(0.999)
+ .rmsDecay(0.9)
+ .regularization(true).l2(1e-5)
+ .weightInit(WeightInit.XAVIER)
+ .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
+ .gradientNormalizationThreshold(1.0)
+ .learningRate(args.learningRate)
+ .list()
+ .layer(0, new GravesLSTM.Builder()
+ .nIn(gloves.getVectorSize())
+ .nOut(args.nRNNUnits)
+ .activation(Activation.RELU).build())
+ .layer(1, new RnnOutputLayer.Builder()
+ .nIn(args.nRNNUnits)
+ .nOut(totalOutcomes)
+ .activation(Activation.SOFTMAX)
+ .lossFunction(LossFunctions.LossFunction.MCXENT)
+ .build())
+ .pretrain(false)
+ .backprop(true)
+ .build();
+
+ MultiLayerNetwork net = new MultiLayerNetwork(conf);
+ net.init();
+ net.setListeners(new ScoreIterationListener(1));
+ return net;
+ }
+
+ public void train(){
+ train(args.nEpochs, this.trainSet, this.validSet);
+ }
+
+ /**
+ * Trains model
+ * @param nEpochs number of epochs (i.e. iterations over the training dataset)
+ * @param train training data set
+ * @param validation validation data set for evaluation after each epoch.
+ * Setting this to null will skip the evaluation
+ */
+ public void train(int nEpochs, DataReader train, DataReader validation){
+ assert model != null;
+ assert train != null;
+ LOG.info("Starting training...\nTotal epochs={}, Training Size={}, Validation Size={}", nEpochs,
+ train.totalExamples(), validation == null ? null : validation.totalExamples());
+ for (int i = 0; i < nEpochs; i++) {
+ model.fit(train);
+ train.reset();
+ LOG.info("Epoch {} complete", i);
+
+ if (validation != null) {
+ LOG.info("Starting evaluation");
+ //Run evaluation. This is on 25k reviews, so can take some time
+ Evaluation evaluation = new Evaluation();
+ while (validation.hasNext()) {
+ DataSet t = validation.next();
+ INDArray features = t.getFeatureMatrix();
+ INDArray labels = t.getLabels();
+ INDArray inMask = t.getFeaturesMaskArray();
+ INDArray outMask = t.getLabelsMaskArray();
+ INDArray predicted = model.output(features, false, inMask, outMask);
+ evaluation.evalTimeSeries(labels, predicted, outMask);
+ }
+ validation.reset();
+ LOG.info(evaluation.stats());
+ }
+ }
+ }
+
+ /**
+ * Saves the model to specified path
+ * @param path model path
+ * @throws IOException
+ */
+ public void saveModel(String path) throws IOException {
+ assert model != null;
+ LOG.info("Saving the model at {}", path);
+ ModelSerializer.writeModel(model, path, true);
+ }
+
+ /**
+ * Predicts class probability of the text based on the state of model
+ * @param text text to be classified
+ * @return array of doubles, indices associated with indices of <code>this.args.labels</code>
+ */
+ public double[] predict(String text){
+
+ INDArray seqFeatures = this.gloves.embed(text, this.args.maxSeqLen);
+ INDArray networkOutput = this.model.output(seqFeatures);
+ int timeSeriesLength = networkOutput.size(2);
+ INDArray probsAtLastWord = networkOutput.get(NDArrayIndex.point(0),
+ NDArrayIndex.all(), NDArrayIndex.point(timeSeriesLength - 1));
+
+ double[] probs = new double[args.labels.size()];
+ for (int i = 0; i < args.labels.size(); i++) {
+ probs[i] = probsAtLastWord.getDouble(i);
+ }
+ return probs;
+ }
+
+ /**
+ * <pre>
+ * # Download pre trained Glo-ves (this is a large file)
+ * wget http://nlp.stanford.edu/data/glove.6B.zip
+ * unzip glove.6B.zip -d glove.6B
+ *
+ * # Download dataset
+ * wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
+ * tar xzf aclImdb_v1.tar.gz
+ *
+ * mvn compile exec:java
+ * -Dexec.mainClass=edu.usc.irds.sentiment.analysis.dl.GloveRNNTextClassifier
+ * -Dexec.args="-glovesPath $HOME/work/datasets/glove.6B/glove.6B.100d.txt
+ * -labels pos neg -modelPath imdb-sentimodel.dat
+ * -trainDir=$HOME/work/datasets/aclImdb/train -lr 0.001"
+ *
+ * </pre>
+ *
+ */
+ public static void main(String[] argss) throws CmdLineException, IOException {
+ Args args = new Args();
+ CmdLineParser parser = new CmdLineParser(args);
+ try {
+ parser.parseArgument(argss);
+ } catch (CmdLineException e) {
+ System.out.println(e.getMessage());
+ e.getParser().printUsage(System.out);
+ System.exit(1);
+ }
+ GloveRNNTextClassifier classifier = new GloveRNNTextClassifier(args);
+ byte numOps = 0;
+ if (classifier.trainSet != null) {
+ numOps++;
+ classifier.train();
+ classifier.saveModel(args.modelPath);
+ }
+
+ if (args.filePaths != null && !args.filePaths.isEmpty()) {
+ numOps++;
+ System.out.println("Labels:" + args.labels);
+ for (String filePath: args.filePaths) {
+ File file = new File(filePath);
+ String text = FileUtils.readFileToString(file);
+ double[] probs = classifier.predict(text);
+ System.out.println(">>" + filePath);
+ System.out.println("Probabilities:" + Arrays.toString(probs));
+ }
+ }
+
+ if (numOps == 0) {
+ System.out.println("Provide -trainDir to train a model, -files to classify files");
+ System.exit(2);
+ }
+ }
+}