Refactored and implemented DocCat API
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/DataReader.java b/opennlp-dl/src/main/java/opennlp/tools/dl/DataReader.java
index 825cf53..86af123 100644
--- a/opennlp-dl/src/main/java/opennlp/tools/dl/DataReader.java
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/DataReader.java
@@ -82,9 +82,9 @@
* </code>
*
* @see GlobalVectors
- * @see GloveRNNTextClassifier
+ * @see NeuralDocCat
* <br/>
- * Contributors: Thamme Gowda (thammegowda@apache.org)
+ * @author Thamme Gowda (thammegowda@apache.org)
*
*/
public class DataReader implements DataSetIterator {
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/GlobalVectors.java b/opennlp-dl/src/main/java/opennlp/tools/dl/GlobalVectors.java
index bda4531..fdf3a95 100644
--- a/opennlp-dl/src/main/java/opennlp/tools/dl/GlobalVectors.java
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/GlobalVectors.java
@@ -18,6 +18,7 @@
*/
package opennlp.tools.dl;
+import org.apache.commons.io.IOUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
@@ -26,11 +27,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.io.BufferedReader;
-import java.io.FileInputStream;
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.InputStreamReader;
+import java.io.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
@@ -53,6 +50,8 @@
* }
* </pre>
*
+ * @author Thamme Gowda (thammegowda@apache.org)
+ *
*/
public class GlobalVectors {
@@ -169,4 +168,32 @@
}
return features;
}
+
+ public void writeOut(OutputStream stream, boolean closeStream) throws IOException {
+ writeOut(stream, "%.5f", closeStream);
+ }
+
+ /** Writes each word and its vector as one text line; the stream is closed only when closeStream is true. */
+ public void writeOut(OutputStream stream, String floatPrecisionFormatString, boolean closeStream) throws IOException {
+ if (!Character.isWhitespace(floatPrecisionFormatString.charAt(0))) {
+ floatPrecisionFormatString = " " + floatPrecisionFormatString;
+ }
+ LOG.info("Writing {} vectors out, float precision {}", idToWord.size(), floatPrecisionFormatString);
+ PrintWriter out = new PrintWriter(stream);
+ try {
+ for (int i = 0; i < idToWord.size(); i++) {
+ out.printf("%s", idToWord.get(i));
+ INDArray row = embeddings.getRow(i);
+ for (int j = 0; j < vectorSize; j++) {
+ out.printf(floatPrecisionFormatString, row.getDouble(j));
+ }
+ out.println();
+ }
+ } finally {
+ out.flush(); // PrintWriter buffers its output; without this the tail is lost when the stream stays open
+ if (closeStream) {
+ IOUtils.closeQuietly(out); // closing the print writer also closes the inner stream
+ }
+ }
+ }
}
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/GloveRNNTextClassifier.java b/opennlp-dl/src/main/java/opennlp/tools/dl/GloveRNNTextClassifier.java
deleted file mode 100644
index dd7bcc5..0000000
--- a/opennlp-dl/src/main/java/opennlp/tools/dl/GloveRNNTextClassifier.java
+++ /dev/null
@@ -1,334 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package opennlp.tools.dl;
-
-import org.apache.commons.io.FileUtils;
-import org.deeplearning4j.eval.Evaluation;
-import org.deeplearning4j.nn.conf.GradientNormalization;
-import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
-import org.deeplearning4j.nn.conf.Updater;
-import org.deeplearning4j.nn.conf.layers.GravesLSTM;
-import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
-import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
-import org.deeplearning4j.nn.weights.WeightInit;
-import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
-import org.deeplearning4j.util.ModelSerializer;
-import org.kohsuke.args4j.CmdLineException;
-import org.kohsuke.args4j.CmdLineParser;
-import org.kohsuke.args4j.Option;
-import org.kohsuke.args4j.spi.StringArrayOptionHandler;
-import org.nd4j.linalg.activations.Activation;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.dataset.DataSet;
-import org.nd4j.linalg.indexing.NDArrayIndex;
-import org.nd4j.linalg.lossfunctions.LossFunctions;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.IOException;
-import java.io.InputStream;
-import java.util.Arrays;
-import java.util.List;
-
-/**
- * This is a Multi Class Text Classifier that uses Glove embeddings to vectorize the text
- * and a LSTM RNN to classify the sequence of vectors.
- *
- * This class aimed to make a general purpose text classifier tool.
- * A common use case would be to tune it for the text Sentiment classification task.
- *
- * The Glove Vectors can be downloaded from https://nlp.stanford.edu/projects/glove/
- *
- * <br/>
- */
-public class GloveRNNTextClassifier {
- private static final Logger LOG = LoggerFactory.getLogger(GloveRNNTextClassifier.class);
-
- private MultiLayerNetwork model;
- private GlobalVectors gloves;
- private DataReader trainSet;
- private DataReader validSet;
-
- private Args args;
-
- public GloveRNNTextClassifier(Args args) throws IOException {
- this.init(args);
- }
-
- public static class Args {
-
- @Option(name="-batchSize", depends = {"-trainDir"},
- usage = "Number of examples in minibatch. Applicable for training only.")
- int batchSize = 128;
-
- @Option(name="-nEpochs", depends = {"-trainDir"},
- usage = "Number of epochs (i.e. full passes over the training data) to train on." +
- " Applicable for training only.")
- int nEpochs = 2;
-
- @Option(name="-maxSeqLen", usage = "Max Sequence Length. Sequences longer than this will be truncated")
- int maxSeqLen = 256; //Truncate text with length (# words) greater than this
-
- @Option(name="-vocabSize", usage = "Vocabulary Size.")
- int vocabSize = 20000; //vocabulary size
-
- @Option(name="-nRNNUnits", depends = {"-trainDir"},
- usage = "Number of RNN cells to use. Applicable for training only.")
- int nRNNUnits = 128;
-
- @Option(name="-lr", aliases = "-learnRate", usage = "Learning Rate." +
- " Adjust it when the scores bounce to NaN or Infinity.")
- double learningRate = 2e-3;
-
- @Option(name="-glovesPath", required = true, usage = "Path to GloVe vectors file." +
- " Download and unzip from https://nlp.stanford.edu/projects/glove/")
- String glovesPath = null;
-
- @Option(name="-modelPath", required = true, usage = "Path to model file. " +
- "This will be used for serializing the model after the training phase." +
- "and also the model will be restored from here for prediction")
- String modelPath = null;
-
- @Option(name="-trainDir", usage = "Path to train data directory. Optional." +
- " Setting this value will take the system to training mode. ")
- String trainDir = null;
-
- @Option(name="-validDir", depends = {"-trainDir"}, usage = "Path to validation data directory. Optional." +
- " Applicable only when -trainDir is set.")
- String validDir = null;
-
- @Option(name="-labels", required = true, handler = StringArrayOptionHandler.class,
- usage = "Names of targets or labels separated by spaces. " +
- "The order of labels matters. Make sure to use the same sequence for training and predicting. " +
- "Also, these names should match subdirectory names of -trainDir and -validDir when those are " +
- "applicable. \n Example -labels pos neg")
- List<String> labels = null;
-
- @Option(name="-files", handler = StringArrayOptionHandler.class,
- usage = "File paths (separated by space) to predict using the model.")
- List<String> filePaths = null;
-
- @Override
- public String toString() {
- return "Args{" +
- "batchSize=" + batchSize +
- ", nEpochs=" + nEpochs +
- ", maxSeqLen=" + maxSeqLen +
- ", vocabSize=" + vocabSize +
- ", learningRate=" + learningRate +
- ", nRNNUnits=" + nRNNUnits +
- ", glovesPath='" + glovesPath + '\'' +
- ", modelPath='" + modelPath + '\'' +
- ", trainDir='" + trainDir + '\'' +
- ", validDir='" + validDir + '\'' +
- ", labels=" + labels +
- '}';
- }
- }
-
-
- public void init(Args args) throws IOException {
- this.args = args;
- try (InputStream stream = new FileInputStream(args.glovesPath)) {
- this.gloves = new GlobalVectors(stream, args.vocabSize);
- }
-
- if (args.trainDir != null) {
- LOG.info("Training data from {}", args.trainDir);
- this.trainSet = new DataReader(args.trainDir, args.labels, this.gloves, args.batchSize, args.maxSeqLen);
- if (args.validDir != null) {
- LOG.info("Validation data from {}", args.validDir);
- this.validSet = new DataReader(args.validDir, args.labels, this.gloves, args.batchSize, args.maxSeqLen);
- }
- //create model
- this.model = this.createModel();
- // ready for training
- } else {
- //restore model
- LOG.info("Training data not set => Going to restore model from {}", args.modelPath);
- this.model = ModelSerializer.restoreMultiLayerNetwork(args.modelPath);
- //ready for prediction
- }
- }
-
- private MultiLayerNetwork createModel(){
- int totalOutcomes = this.trainSet.totalOutcomes();
- assert totalOutcomes >= 2;
- LOG.info("Number of classes " + totalOutcomes);
-
- //TODO: the below network params should be configurable from CLI or settings file
- //Set up network configuration
- MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
- .updater(Updater.RMSPROP) // ADAM .adamMeanDecay(0.9).adamVarDecay(0.999)
- .rmsDecay(0.9)
- .regularization(true).l2(1e-5)
- .weightInit(WeightInit.XAVIER)
- .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
- .gradientNormalizationThreshold(1.0)
- .learningRate(args.learningRate)
- .list()
- .layer(0, new GravesLSTM.Builder()
- .nIn(gloves.getVectorSize())
- .nOut(args.nRNNUnits)
- .activation(Activation.RELU).build())
- .layer(1, new RnnOutputLayer.Builder()
- .nIn(args.nRNNUnits)
- .nOut(totalOutcomes)
- .activation(Activation.SOFTMAX)
- .lossFunction(LossFunctions.LossFunction.MCXENT)
- .build())
- .pretrain(false)
- .backprop(true)
- .build();
-
- MultiLayerNetwork net = new MultiLayerNetwork(conf);
- net.init();
- net.setListeners(new ScoreIterationListener(1));
- return net;
- }
-
- public void train(){
- train(args.nEpochs, this.trainSet, this.validSet);
- }
-
- /**
- * Trains model
- * @param nEpochs number of epochs (i.e. iterations over the training dataset)
- * @param train training data set
- * @param validation validation data set for evaluation after each epoch.
- * Setting this to null will skip the evaluation
- */
- public void train(int nEpochs, DataReader train, DataReader validation){
- assert model != null;
- assert train != null;
- LOG.info("Starting training...\nTotal epochs={}, Training Size={}, Validation Size={}", nEpochs,
- train.totalExamples(), validation == null ? null : validation.totalExamples());
- for (int i = 0; i < nEpochs; i++) {
- model.fit(train);
- train.reset();
- LOG.info("Epoch {} complete", i);
-
- if (validation != null) {
- LOG.info("Starting evaluation");
- //Run evaluation. This is on 25k reviews, so can take some time
- Evaluation evaluation = new Evaluation();
- while (validation.hasNext()) {
- DataSet t = validation.next();
- INDArray features = t.getFeatureMatrix();
- INDArray labels = t.getLabels();
- INDArray inMask = t.getFeaturesMaskArray();
- INDArray outMask = t.getLabelsMaskArray();
- INDArray predicted = model.output(features, false, inMask, outMask);
- evaluation.evalTimeSeries(labels, predicted, outMask);
- }
- validation.reset();
- LOG.info(evaluation.stats());
- }
- }
- }
-
- /**
- * Saves the model to specified path
- * @param path model path
- * @throws IOException
- */
- public void saveModel(String path) throws IOException {
- assert model != null;
- LOG.info("Saving the model at {}", path);
- ModelSerializer.writeModel(model, path, true);
- }
-
- /**
- * Predicts class probability of the text based on the state of model
- * @param text text to be classified
- * @return array of doubles, indices associated with indices of <code>this.args.labels</code>
- */
- public double[] predict(String text){
-
- INDArray seqFeatures = this.gloves.embed(text, this.args.maxSeqLen);
- INDArray networkOutput = this.model.output(seqFeatures);
- int timeSeriesLength = networkOutput.size(2);
- INDArray probsAtLastWord = networkOutput.get(NDArrayIndex.point(0),
- NDArrayIndex.all(), NDArrayIndex.point(timeSeriesLength - 1));
-
- double[] probs = new double[args.labels.size()];
- for (int i = 0; i < args.labels.size(); i++) {
- probs[i] = probsAtLastWord.getDouble(i);
- }
- return probs;
- }
-
- /**
- * <pre>
- * # Download pre trained Glo-ves (this is a large file)
- * wget http://nlp.stanford.edu/data/glove.6B.zip
- * unzip glove.6B.zip -d glove.6B
- *
- * # Download dataset
- * wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
- * tar xzf aclImdb_v1.tar.gz
- *
- * mvn compile exec:java
- * -Dexec.mainClass=edu.usc.irds.sentiment.analysis.dl.GloveRNNTextClassifier
- * -Dexec.args="-glovesPath $HOME/work/datasets/glove.6B/glove.6B.100d.txt
- * -labels pos neg -modelPath imdb-sentimodel.dat
- * -trainDir=$HOME/work/datasets/aclImdb/train -lr 0.001"
- *
- * </pre>
- *
- */
- public static void main(String[] argss) throws CmdLineException, IOException {
- Args args = new Args();
- CmdLineParser parser = new CmdLineParser(args);
- try {
- parser.parseArgument(argss);
- } catch (CmdLineException e) {
- System.out.println(e.getMessage());
- e.getParser().printUsage(System.out);
- System.exit(1);
- }
- GloveRNNTextClassifier classifier = new GloveRNNTextClassifier(args);
- byte numOps = 0;
- if (classifier.trainSet != null) {
- numOps++;
- classifier.train();
- classifier.saveModel(args.modelPath);
- }
-
- if (args.filePaths != null && !args.filePaths.isEmpty()) {
- numOps++;
- System.out.println("Labels:" + args.labels);
- for (String filePath: args.filePaths) {
- File file = new File(filePath);
- String text = FileUtils.readFileToString(file);
- double[] probs = classifier.predict(text);
- System.out.println(">>" + filePath);
- System.out.println("Probabilities:" + Arrays.toString(probs));
- }
- }
-
- if (numOps == 0) {
- System.out.println("Provide -trainDir to train a model, -files to classify files");
- System.exit(2);
- }
- }
-}
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCat.java b/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCat.java
new file mode 100644
index 0000000..53bf530
--- /dev/null
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCat.java
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package opennlp.tools.dl;
+
+import opennlp.tools.doccat.DocumentCategorizer;
+import opennlp.tools.tokenize.Tokenizer;
+import opennlp.tools.tokenize.WhitespaceTokenizer;
+import org.apache.commons.io.FileUtils;
+import org.apache.commons.lang3.NotImplementedException;
+import org.kohsuke.args4j.CmdLineException;
+import org.kohsuke.args4j.CmdLineParser;
+import org.kohsuke.args4j.Option;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.*;
+
+/**
+ * An implementation of {@link DocumentCategorizer} using Neural Networks.
+ * This class provides prediction functionality from the model of {@link NeuralDocCatTrainer}.
+ *
+ */
+public class NeuralDocCat implements DocumentCategorizer {
+
+ private static final Logger LOG = LoggerFactory.getLogger(NeuralDocCat.class);
+
+ private NeuralDocCatModel model;
+
+ public NeuralDocCat(NeuralDocCatModel model) {
+ this.model = model;
+ }
+
+ @Override
+ public double[] categorize(String[] tokens) {
+ return categorize(tokens, Collections.emptyMap());
+ }
+
+ @Override
+ public double[] categorize(String[] text, Map<String, Object> extraInformation) {
+ INDArray seqFeatures = this.model.getGloves().embed(text, this.model.getMaxSeqLen());
+
+ INDArray networkOutput = this.model.getNetwork().output(seqFeatures);
+ int timeSeriesLength = networkOutput.size(2);
+ INDArray probsAtLastWord = networkOutput.get(NDArrayIndex.point(0),
+ NDArrayIndex.all(), NDArrayIndex.point(timeSeriesLength - 1));
+
+ int nLabels = this.model.getLabels().size();
+ double[] probs = new double[nLabels];
+ for (int i = 0; i < nLabels; i++) {
+ probs[i] = probsAtLastWord.getDouble(i);
+ }
+ return probs;
+ }
+
+ @Override
+ public String getBestCategory(double[] outcome) {
+ int maxIdx = 0;
+ double maxProb = outcome[0];
+ for (int i = 1; i < outcome.length; i++) {
+ if (outcome[i] > maxProb) {
+ maxIdx = i;
+ maxProb = outcome[i];
+ }
+ }
+ return model.getLabels().get(maxIdx);
+ }
+
+ @Override
+ public int getIndex(String category) {
+ return model.getLabels().indexOf(category);
+ }
+
+ @Override
+ public String getCategory(int index) {
+ return model.getLabels().get(index);
+ }
+
+ @Override
+ public int getNumberOfCategories() {
+ return model.getLabels().size();
+ }
+
+
+ @Override
+ public String getAllResults(double[] results) {
+ throw new NotImplementedException("Not implemented");
+ }
+
+ @Override
+ public Map<String, Double> scoreMap(String[] text) {
+ double[] scores = categorize(text);
+ Map<String, Double> result = new HashMap<>();
+ for (int i = 0; i < scores.length; i++) {
+ result.put(model.getLabels().get(i), scores[i]);
+
+ }
+ return result;
+ }
+
+ @Override
+ public SortedMap<Double, Set<String>> sortedScoreMap(String[] text) {
+ throw new NotImplementedException("Not implemented");
+ }
+
+ @Override
+ @Deprecated
+ public double[] categorize(String documentText) {
+ throw new UnsupportedOperationException("Use the other categorize(..) method that accepts tokenized text");
+ }
+
+ @Override
+ @Deprecated
+ public Map<String, Double> scoreMap(String text) {
+ throw new UnsupportedOperationException("Use the other scoreMap(..) method that accepts tokenized text");
+ }
+
+ @Override
+ @Deprecated
+ public SortedMap<Double, Set<String>> sortedScoreMap(String text) {
+ throw new UnsupportedOperationException("Use the other sortedScoreMap(..) method that accepts tokenized text");
+ }
+ @Override
+ @Deprecated
+ public double[] categorize(String documentText, Map<String, Object> extraInformation) {
+ throw new UnsupportedOperationException("Use the other categorize(..) method that accepts tokenized text");
+ }
+
+
+ public static void main(String[] argss) throws CmdLineException, IOException {
+ class Args {
+
+ @Option(name = "-model", required = true, usage = "Path to NeuralDocCatModel stored file")
+ String modelPath;
+
+ @Option(name = "-files", required = true, usage = "One or more document paths whose category is " +
+ "to be predicted by the model")
+ List<File> files;
+ }
+
+ Args args = new Args();
+ CmdLineParser parser = new CmdLineParser(args);
+ try {
+ parser.parseArgument(argss);
+ } catch (CmdLineException e) {
+ System.out.println(e.getMessage());
+ e.getParser().printUsage(System.out);
+ System.exit(1);
+ }
+
+ NeuralDocCatModel model = NeuralDocCatModel.loadModel(args.modelPath);
+ NeuralDocCat classifier = new NeuralDocCat(model);
+
+ System.out.println("Labels:" + model.getLabels());
+ Tokenizer tokenizer = WhitespaceTokenizer.INSTANCE;
+
+ for (File file: args.files) {
+ String text = FileUtils.readFileToString(file);
+ String[] tokens = tokenizer.tokenize(text.toLowerCase());
+ double[] probs = classifier.categorize(tokens);
+ System.out.println(">>" + file);
+ System.out.println("Probabilities:" + Arrays.toString(probs));
+ }
+
+ }
+}
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCatModel.java b/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCatModel.java
new file mode 100644
index 0000000..f1b6247
--- /dev/null
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCatModel.java
@@ -0,0 +1,179 @@
+package opennlp.tools.dl;
+
+import org.apache.commons.io.IOUtils;
+import org.apache.commons.lang3.StringUtils;
+import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.*;
+import java.util.*;
+import java.util.zip.ZipEntry;
+import java.util.zip.ZipInputStream;
+import java.util.zip.ZipOutputStream;
+
+/**
+ * This class is a wrapper for DL4J's {@link MultiLayerNetwork}, and {@link GlobalVectors}
+ * that provides features to serialize and deserialize necessary data to a zip file.
+ *
+ * This can be used by a Neural Trainer tool to serialize the network and a predictor tool to restore the same network
+ * with the weights.
+ *
+ * <br/>
+ * @author Thamme Gowda (thammegowda@apache.org)
+ */
+public class NeuralDocCatModel {
+
+ public static final int VERSION = 1;
+ public static final String MODEL_NAME = NeuralDocCatModel.class.getName();
+ public static final String MANIFEST = "model.mf";
+ public static final String NETWORK = "network.json";
+ public static final String WEIGHTS = "weights.bin";
+ public static final String GLOVES = "gloves.tsv";
+ public static final String LABELS = "labels";
+ public static final String MAX_SEQ_LEN = "maxSeqLen";
+
+ private static final Logger LOG = LoggerFactory.getLogger(NeuralDocCatModel.class);
+
+ private final MultiLayerNetwork network;
+ private final GlobalVectors gloves;
+ private final Properties manifest;
+ private final List<String> labels;
+ private final int maxSeqLen;
+
+ /**
+ * Restores a model from a zip archive such as the one written by {@link #saveModel(OutputStream)}.
+ * @param stream Input stream of a Zip File; all entries are read, the stream is not closed here
+ * @throws IOException if reading fails or a required entry or manifest property is missing
+ */
+ public NeuralDocCatModel(InputStream stream) throws IOException {
+ ZipInputStream zipIn = new ZipInputStream(stream);
+
+ Properties manifest = null;
+ MultiLayerNetwork model = null;
+ INDArray params = null;
+ GlobalVectors gloves = null;
+ ZipEntry entry;
+ while ((entry = zipIn.getNextEntry()) != null) {
+ String name = entry.getName();
+ switch (name) {
+ case MANIFEST:
+ manifest = new Properties();
+ manifest.load(zipIn);
+ break;
+ case NETWORK:
+ String json = IOUtils.toString(new UnclosableInputStream(zipIn));
+ model = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(json));
+ break;
+ case WEIGHTS:
+ params = Nd4j.read(new DataInputStream(new UnclosableInputStream(zipIn)));
+ break;
+ case GLOVES:
+ gloves = new GlobalVectors(new UnclosableInputStream(zipIn));
+ break;
+ default:
+ LOG.warn("Unexpected entry in the zip : {}", name);
+ }
+ }
+ if (model == null) { throw new IOException("Missing zip entry: " + NETWORK); } // assert is a no-op unless the JVM runs with -ea
+ if (manifest == null) { throw new IOException("Missing zip entry: " + MANIFEST); }
+ if (params == null) { throw new IOException("Missing zip entry: " + WEIGHTS); }
+ if (gloves == null) { throw new IOException("Missing zip entry: " + GLOVES); }
+ model.init(params, false);
+ this.network = model;
+ this.manifest = manifest;
+ this.gloves = gloves;
+
+ String labelsCsv = manifest.getProperty(LABELS);
+ if (labelsCsv == null) { throw new IOException("Manifest has no '" + LABELS + "' property"); }
+ this.labels = Collections.unmodifiableList(Arrays.asList(labelsCsv.split(",")));
+ String maxSeqLenValue = manifest.getProperty(MAX_SEQ_LEN);
+ if (maxSeqLenValue == null) { throw new IOException("Manifest has no '" + MAX_SEQ_LEN + "' property"); }
+ this.maxSeqLen = Integer.parseInt(maxSeqLenValue);
+ }
+
+ /**
+ *
+ * @param network any compatible multi layer neural network
+ * @param vectors Global vectors
+ * @param labels list of labels
+ * @param maxSeqLen max sequence length
+ */
+ public NeuralDocCatModel(MultiLayerNetwork network, GlobalVectors vectors, List<String> labels, int maxSeqLen) {
+ this.network = network;
+ this.gloves = vectors;
+ this.manifest = new Properties();
+ this.manifest.setProperty(LABELS, StringUtils.join(labels, ","));
+ this.manifest.setProperty(MAX_SEQ_LEN, maxSeqLen + "");
+ this.labels = Collections.unmodifiableList(labels);
+ this.maxSeqLen = maxSeqLen;
+ }
+
+ public MultiLayerNetwork getNetwork() {
+ return network;
+ }
+
+ public GlobalVectors getGloves() {
+ return gloves;
+ }
+
+ public List<String> getLabels() {
+ return labels;
+ }
+
+ public int getMaxSeqLen() {
+ return this.maxSeqLen;
+ }
+
+ /**
+ * Zips the current state of the model and writes it to the given stream.
+ * @param stream stream to write to; it is closed when this method returns
+ * @throws IOException
+ */
+ public void saveModel(OutputStream stream) throws IOException {
+ try (ZipOutputStream zipOut = new ZipOutputStream(new BufferedOutputStream(stream))) {
+ // Write out manifest
+ zipOut.putNextEntry(new ZipEntry(MANIFEST));
+
+ String comments = "Created-By:" + System.getenv("USER") + " at " + new Date().toString()
+ + "\nModel-Version: " + VERSION
+ + "\nModel-Schema:" + MODEL_NAME;
+
+ manifest.store(zipOut, comments);
+ zipOut.closeEntry();
+
+ // Write out the network
+ zipOut.putNextEntry(new ZipEntry(NETWORK));
+ byte[] jModel = network.getLayerWiseConfigurations().toJson().getBytes();
+ zipOut.write(jModel);
+ zipOut.closeEntry();
+
+ //Write out the network coefficients
+ zipOut.putNextEntry(new ZipEntry(WEIGHTS));
+ Nd4j.write(network.params(), new DataOutputStream(zipOut));
+ zipOut.closeEntry();
+
+ // Write out vectors
+ zipOut.putNextEntry(new ZipEntry(GLOVES));
+ gloves.writeOut(zipOut, false);
+ zipOut.closeEntry();
+
+ zipOut.finish();
+ }
+ }
+
+ /**
+ * creates a model from file on the local file system
+ * @param modelPath path to model file
+ * @return an instance of this class
+ * @throws IOException
+ */
+ public static NeuralDocCatModel loadModel(String modelPath) throws IOException {
+ try (InputStream modelStream = new FileInputStream(modelPath)) {
+ return new NeuralDocCatModel(modelStream);
+ }
+ }
+}
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCatTrainer.java b/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCatTrainer.java
new file mode 100644
index 0000000..7fe11ff
--- /dev/null
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/NeuralDocCatTrainer.java
@@ -0,0 +1,258 @@
+package opennlp.tools.dl;
+
+import org.deeplearning4j.eval.Evaluation;
+import org.deeplearning4j.nn.conf.GradientNormalization;
+import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.Updater;
+import org.deeplearning4j.nn.conf.layers.GravesLSTM;
+import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
+import org.kohsuke.args4j.CmdLineException;
+import org.kohsuke.args4j.CmdLineParser;
+import org.kohsuke.args4j.Option;
+import org.kohsuke.args4j.spi.StringArrayOptionHandler;
+import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.DataSet;
+import org.nd4j.linalg.lossfunctions.LossFunctions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.*;
+import java.util.List;
+
+
+/**
+ * This class provides functionality to construct and train neural networks that can be used for
+ * {@link opennlp.tools.doccat.DocumentCategorizer}
+ *
+ * @see NeuralDocCat
+ * @see NeuralDocCatModel
+ * @author Thamme Gowda (thammegowda@apache.org)
+ */
+public class NeuralDocCatTrainer {
+
+ public static class Args {
+
+ @Option(name = "-batchSize", usage = "Number of examples in minibatch")
+ int batchSize = 128;
+
+ @Option(name = "-nEpochs", usage = "Number of epochs (i.e. full passes over the training data) to train on." +
+ " Applicable for training only.")
+ int nEpochs = 2;
+
+ @Option(name = "-maxSeqLen", usage = "Max Sequence Length. Sequences longer than this will be truncated")
+ int maxSeqLen = 256; //Truncate text with length (# words) greater than this
+
+ @Option(name = "-vocabSize", usage = "Vocabulary Size.")
+ int vocabSize = 20000; //vocabulary size
+
+ @Option(name = "-nRNNUnits", usage = "Number of RNN cells to use.")
+ int nRNNUnits = 128;
+
+ @Option(name = "-lr", aliases = "-learnRate", usage = "Learning Rate." +
+ " Adjust it when the scores bounce to NaN or Infinity.")
+ double learningRate = 2e-3;
+
+ @Option(name = "-glovesPath", required = true, usage = "Path to GloVe vectors file." +
+ " Download and unzip from https://nlp.stanford.edu/projects/glove/")
+ String glovesPath = null;
+
+ @Option(name = "-modelPath", required = true, usage = "Path to model file. " +
+ "This will be used for serializing the model after the training phase." )
+ String modelPath = null;
+
+ @Option(name = "-trainDir", required = true, usage = "Path to train data directory." +
+ " Setting this value will take the system to training mode. ")
+ String trainDir = null;
+
+ @Option(name = "-validDir", usage = "Path to validation data directory. Optional.")
+ String validDir = null;
+
+ @Option(name = "-labels", required = true, handler = StringArrayOptionHandler.class,
+ usage = "Names of targets or labels separated by spaces. " +
+ "The order of labels matters. Make sure to use the same sequence for training and predicting. " +
+ "Also, these names should match subdirectory names of -trainDir and -validDir when those are " +
+ "applicable. \n Example -labels pos neg")
+ List<String> labels = null;
+
+ @Override
+ public String toString() {
+ return "Args{" +
+ "batchSize=" + batchSize +
+ ", nEpochs=" + nEpochs +
+ ", maxSeqLen=" + maxSeqLen +
+ ", vocabSize=" + vocabSize +
+ ", learningRate=" + learningRate +
+ ", nRNNUnits=" + nRNNUnits +
+ ", glovesPath='" + glovesPath + '\'' +
+ ", modelPath='" + modelPath + '\'' +
+ ", trainDir='" + trainDir + '\'' +
+ ", validDir='" + validDir + '\'' +
+ ", labels=" + labels +
+ '}';
+ }
+ }
+
+ private static final Logger LOG = LoggerFactory.getLogger(NeuralDocCatTrainer.class);
+
+ private NeuralDocCatModel model;
+ private Args args;
+ private DataReader trainSet;
+ private DataReader validSet;
+
+
+ public NeuralDocCatTrainer(Args args) throws IOException {
+ this.args = args;
+ GlobalVectors gloves;
+ MultiLayerNetwork network;
+
+ try (InputStream stream = new FileInputStream(args.glovesPath)) {
+ gloves = new GlobalVectors(stream, args.vocabSize);
+ }
+
+ LOG.info("Training data from {}", args.trainDir);
+ this.trainSet = new DataReader(args.trainDir, args.labels, gloves, args.batchSize, args.maxSeqLen);
+ if (args.validDir != null) {
+ LOG.info("Validation data from {}", args.validDir);
+ this.validSet = new DataReader(args.validDir, args.labels, gloves, args.batchSize, args.maxSeqLen);
+ }
+
+ //create network
+ network = this.createNetwork(gloves.getVectorSize());
+ this.model = new NeuralDocCatModel(network, gloves, args.labels, args.maxSeqLen);
+ }
+
+ /**
+  * Builds and initialises the two-layer recurrent network: a {@link GravesLSTM} layer followed by a
+  * softmax {@link RnnOutputLayer} with one output per class reported by the training set.
+  *
+  * @param vectorSize input size of the LSTM layer (the constructor passes the word vector size)
+  * @return the initialised network with a {@link ScoreIterationListener} attached
+  */
+ public MultiLayerNetwork createNetwork(int vectorSize) {
+     int totalOutcomes = this.trainSet.totalOutcomes();
+     assert totalOutcomes >= 2; // only checked when the JVM runs with assertions enabled
+     LOG.info("Number of classes {}", totalOutcomes);
+     //TODO: the below network params should be configurable from CLI or settings file
+     MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+             .updater(Updater.RMSPROP).rmsDecay(0.9) // alternative: ADAM .adamMeanDecay(0.9).adamVarDecay(0.999)
+             .regularization(true).l2(1e-5)
+             .weightInit(WeightInit.XAVIER)
+             .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
+             .gradientNormalizationThreshold(1.0)
+             .learningRate(args.learningRate)
+             .list()
+             .layer(0, new GravesLSTM.Builder()
+                     .nIn(vectorSize).nOut(args.nRNNUnits)
+                     .activation(Activation.RELU).build())
+             .layer(1, new RnnOutputLayer.Builder()
+                     .nIn(args.nRNNUnits).nOut(totalOutcomes)
+                     .activation(Activation.SOFTMAX)
+                     .lossFunction(LossFunctions.LossFunction.MCXENT).build())
+             .pretrain(false).backprop(true)
+             .build();
+
+     MultiLayerNetwork net = new MultiLayerNetwork(conf);
+     net.init();
+     net.setListeners(new ScoreIterationListener(1));
+     return net;
+ }
+
+ public void train() { // runs args.nEpochs epochs on the data sets created by the constructor
+     this.train(this.args.nEpochs, trainSet, validSet);
+ }
+
+ /**
+  * Fits the network on the training data once per epoch, optionally evaluating after each pass.
+  *
+  * @param nEpochs number of epochs (i.e. full passes over the training data set)
+  * @param train training data set; it is reset after every epoch
+  * @param validation data set evaluated after every epoch and then reset;
+  *                   passing null skips the evaluation
+  */
+ public void train(int nEpochs, DataReader train, DataReader validation) {
+     assert model != null;
+     assert train != null;
+     LOG.info("Starting training...\nTotal epochs={}, Training Size={}, Validation Size={}", nEpochs,
+             train.totalExamples(), validation != null ? validation.totalExamples() : null);
+     for (int epoch = 0; epoch < nEpochs; epoch++) {
+         model.getNetwork().fit(train);
+         train.reset();
+         LOG.info("Epoch {} complete", epoch);
+         if (validation == null) {
+             continue; // nothing to evaluate against
+         }
+         LOG.info("Starting evaluation");
+         // One pass over the whole validation set; this can take some time on large sets
+         Evaluation eval = new Evaluation();
+         while (validation.hasNext()) {
+             DataSet batch = validation.next();
+             INDArray batchFeatures = batch.getFeatureMatrix();
+             INDArray batchLabels = batch.getLabels();
+             INDArray featureMask = batch.getFeaturesMaskArray();
+             INDArray labelMask = batch.getLabelsMaskArray();
+             INDArray predicted = this.model.getNetwork().output(batchFeatures, false, featureMask, labelMask);
+             eval.evalTimeSeries(batchLabels, predicted, labelMask);
+         }
+         validation.reset();
+         LOG.info(eval.stats());
+     }
+ }
+
+ /**
+  * Writes the current model to a file, creating or overwriting it.
+  *
+  * @param path location of the file the model is written to
+  * @throws IOException if the file cannot be opened or written
+  */
+ public void saveModel(String path) throws IOException {
+     assert model != null;
+     LOG.info("Saving the model at {}", path);
+     try (OutputStream out = new FileOutputStream(path)) {
+         this.model.saveModel(out);
+     }
+ }
+
+ /**
+  * Command line entry point: parses the arguments, trains the model and saves it to {@code -modelPath}.
+  * On invalid arguments the error and the usage are printed to stderr and the JVM exits with status 1.
+  * <pre>
+  * # Download pre trained Glo-ves (this is a large file)
+  * wget http://nlp.stanford.edu/data/glove.6B.zip
+  * unzip glove.6B.zip -d glove.6B
+  *
+  * # Download dataset
+  * wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
+  * tar xzf aclImdb_v1.tar.gz
+  *
+  * mvn compile exec:java
+  * -Dexec.mainClass=opennlp.tools.dl.NeuralDocCatTrainer
+  * -Dexec.args="-glovesPath $HOME/work/datasets/glove.6B/glove.6B.100d.txt
+  * -labels pos neg -modelPath imdb-sentiment-neural-model.zip
+  * -trainDir=$HOME/work/datasets/aclImdb/train -lr 0.001"
+  * </pre>
+  *
+  * @param argss command line arguments, bound to the options of {@link Args}
+  * @throws CmdLineException kept in the signature; parse errors are reported and handled in this method
+  * @throws IOException if reading the inputs or writing the model fails
+  */
+ public static void main(String[] argss) throws CmdLineException, IOException {
+     Args args = new Args();
+     CmdLineParser parser = new CmdLineParser(args);
+     try {
+         parser.parseArgument(argss);
+     } catch (CmdLineException e) {
+         System.err.println(e.getMessage());
+         e.getParser().printUsage(System.err);
+         System.exit(1);
+     }
+     NeuralDocCatTrainer classifier = new NeuralDocCatTrainer(args);
+     classifier.train();
+     classifier.saveModel(args.modelPath);
+ }
+
+}
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/UnclosableInputStream.java b/opennlp-dl/src/main/java/opennlp/tools/dl/UnclosableInputStream.java
new file mode 100644
index 0000000..701fc48
--- /dev/null
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/UnclosableInputStream.java
@@ -0,0 +1,56 @@
+package opennlp.tools.dl;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.Reader;
+import java.io.Writer;
+
+/**
+ * A wrapper for {@link InputStream} whose sole purpose is to swallow the {@link #close()} calls that
+ * readers and copy utilities usually propagate to the stream they consume.
+ * A use case is reading multiple entries from a {@link java.util.zip.ZipInputStream}, because tools like
+ * {@link org.apache.commons.io.IOUtils#copy(Reader, Writer)} and
+ * {@link org.nd4j.linalg.factory.Nd4j#read(InputStream)} automatically close the input stream.
+ * <p>Note: {@link #close()} is ignored; call {@link #forceClose()} when the inner stream really has to be
+ * closed, after which reads fail with an {@link IOException}. The wrapper holds no resources of its own,
+ * so if you close the inner stream yourself you can safely skip closing the wrapper.
+ * @author Thamme Gowda (thammegowda@apache.org)
+ */
+public class UnclosableInputStream extends InputStream {
+    private InputStream innerStream;
+
+    public UnclosableInputStream(InputStream stream) {
+        this.innerStream = stream;
+    }
+
+    @Override
+    public int read() throws IOException {
+        return openStream().read();
+    }
+
+    /** Delegates bulk reads; the inherited implementation would go through {@link #read()} byte by byte. */
+    @Override
+    public int read(byte[] b, int off, int len) throws IOException {
+        return openStream().read(b, off, len);
+    }
+
+    /** NOP - intentionally does not close the stream; use {@link #forceClose()} for that. */
+    @Override
+    public void close() throws IOException {
+    }
+
+    /** Closes the inner stream and drops the reference to it; further calls do nothing. */
+    public void forceClose() throws IOException {
+        if (innerStream != null) {
+            innerStream.close();
+            innerStream = null;
+        }
+    }
+
+    private InputStream openStream() throws IOException { // fails once force-closed (or if none was given)
+        if (innerStream == null) {
+            throw new IOException("Stream closed");
+        }
+        return innerStream;
+    }
+}