Merge pull request #10 from thygesen/tfnerpoc
added tensorflow NER prediction PoC
diff --git a/tf-ner-poc/pom.xml b/tf-ner-poc/pom.xml
new file mode 100644
index 0000000..71e9620
--- /dev/null
+++ b/tf-ner-poc/pom.xml
@@ -0,0 +1,49 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+
+ <groupId>org.apache.opennlp</groupId>
+ <artifactId>tf-ner</artifactId>
+ <version>1.0-SNAPSHOT</version>
+
+ <properties>
+ <tensorflow.version>1.7.0</tensorflow.version>
+ </properties>
+
+ <dependencies>
+ <dependency>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>tensorflow</artifactId>
+ <version>${tensorflow.version}</version>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.opennlp</groupId>
+ <artifactId>opennlp-tools</artifactId>
+ <version>1.8.4</version>
+ </dependency>
+
+ <dependency>
+ <groupId>junit</groupId>
+ <artifactId>junit</artifactId>
+ <version>4.12</version>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-compiler-plugin</artifactId>
+ <configuration>
+ <source>1.8</source>
+ <target>1.8</target>
+ </configuration>
+ </plugin>
+ </plugins>
+ </build>
+
+</project>
\ No newline at end of file
diff --git a/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/FeedDictionary.java b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/FeedDictionary.java
new file mode 100644
index 0000000..39265b7
--- /dev/null
+++ b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/FeedDictionary.java
@@ -0,0 +1,144 @@
+package org.apache.opennlp.tf.guillaumegenthial;
+
+import org.tensorflow.Tensor;
+
+import java.util.Arrays;
+
+/**
+ * Holds the padded, rectangular input arrays fed into the TensorFlow
+ * sequence-tagging graph, together with the dropout value and shape metadata.
+ * Instances are created via {@link #create(TokenIds)}.
+ */
+public class FeedDictionary {
+
+    /** Id used to right-pad word and char sequences; must match the model's pad id. */
+    static final int PAD_VALUE = 0;
+
+    private final float dropout;
+    private final int[][][] charIds;
+    private final int[][] wordLengths;
+    private final int[][] wordIds;
+    private final int[] sentenceLengths;
+    private final int maxSentenceLength;
+    private final int maxCharLength;
+    private final int numberOfSentences;
+
+    public float getDropout() {
+        return dropout;
+    }
+
+    public int[][][] getCharIds() {
+        return charIds;
+    }
+
+    public int[][] getWordLengths() {
+        return wordLengths;
+    }
+
+    public int[][] getWordIds() {
+        return wordIds;
+    }
+
+    public int[] getSentenceLengths() {
+        return sentenceLengths;
+    }
+
+    public int getMaxSentenceLength() {
+        return maxSentenceLength;
+    }
+
+    public int getMaxCharLength() {
+        return maxCharLength;
+    }
+
+    public int getNumberOfSentences() {
+        return numberOfSentences;
+    }
+
+    /** @return tensor of shape [numberOfSentences] with the unpadded sentence lengths */
+    public Tensor<Integer> getSentenceLengthsTensor() {
+        return Tensor.create(sentenceLengths, Integer.class);
+    }
+
+    /** @return scalar dropout tensor (fixed to 1.0 at prediction time by {@link #create}) */
+    public Tensor<Float> getDropoutTensor() {
+        return Tensor.create(dropout, Float.class);
+    }
+
+    /** @return tensor of shape [numberOfSentences, maxSentenceLength, maxCharLength] */
+    public Tensor<Integer> getCharIdsTensor() {
+        return Tensor.create(charIds, Integer.class);
+    }
+
+    /** @return tensor of shape [numberOfSentences, maxSentenceLength] with unpadded word lengths */
+    public Tensor<Integer> getWordLengthsTensor() {
+        return Tensor.create(wordLengths, Integer.class);
+    }
+
+    /** @return tensor of shape [numberOfSentences, maxSentenceLength] */
+    public Tensor<Integer> getWordIdsTensor() {
+        return Tensor.create(wordIds, Integer.class);
+    }
+
+    private FeedDictionary(final float dropout,
+                           final int[][][] charIds,
+                           final int[][] wordLengths,
+                           final int[][] wordIds,
+                           final int[] sentenceLengths,
+                           final int maxSentenceLength,
+                           final int maxCharLength,
+                           final int numberOfSentences) {
+
+        this.dropout = dropout;
+        this.charIds = charIds;
+        this.wordLengths = wordLengths;
+        this.wordIds = wordIds;
+        this.sentenceLengths = sentenceLengths;
+        this.maxSentenceLength = maxSentenceLength;
+        this.maxCharLength = maxCharLength;
+        this.numberOfSentences = numberOfSentences;
+    }
+
+    /**
+     * Builds a feed dictionary for a batch of sentences: pads word id and char id
+     * sequences to rectangular arrays and records the original (unpadded) lengths.
+     *
+     * @param sentences token ids for one or more non-empty sentences
+     * @return feed data with dropout fixed to 1.0 (no dropout at prediction time)
+     * @throws IllegalArgumentException if the batch contains no sentences or no tokens
+     */
+    public static FeedDictionary create(TokenIds sentences) {
+
+        int numberOfSentences = sentences.getWordIds().length;
+
+        int[][][] charIds = new int[numberOfSentences][][];
+        int[][] wordLengths = new int[numberOfSentences][];
+
+        int maxSentenceLength = Arrays.stream(sentences.getWordIds())
+                .mapToInt(s -> s.length)
+                .max()
+                .orElseThrow(() -> new IllegalArgumentException("at least one sentence is required"));
+        Padded paddedSentences = padArrays(sentences.getWordIds(), maxSentenceLength);
+        int[][] wordIds = paddedSentences.ids;
+        int[] sentenceLengths = paddedSentences.lengths;
+
+        int maxCharLength = Arrays.stream(sentences.getCharIds())
+                .flatMap(Arrays::stream)
+                .mapToInt(w -> w.length)
+                .max()
+                .orElseThrow(() -> new IllegalArgumentException("at least one token is required"));
+        for (int i = 0; i < numberOfSentences; i++) {
+            Padded paddedWords = padArrays(sentences.getCharIds()[i], maxCharLength);
+            charIds[i] = paddedWords.ids;
+            wordLengths[i] = paddedWords.lengths;
+        }
+
+        return new FeedDictionary(1.0f, charIds, wordLengths, wordIds, sentenceLengths,
+                maxSentenceLength, maxCharLength, numberOfSentences);
+    }
+
+    /**
+     * Right-pads each row of {@code ids} to {@code length} with {@link #PAD_VALUE}
+     * and records each row's original length.
+     */
+    private static Padded padArrays(int[][] ids, int length) {
+
+        int[][] paddedIds = new int[ids.length][length];
+        int[] lengths = new int[ids.length];
+
+        for (int i = 0; i < ids.length; i++) {
+            int[] src = ids[i];
+            // The tail of each freshly allocated row already holds 0 == PAD_VALUE.
+            System.arraycopy(src, 0, paddedIds[i], 0, src.length);
+            lengths[i] = src.length;
+        }
+
+        return new Padded(paddedIds, lengths);
+    }
+
+    /** Pair of padded id rows and their original (unpadded) lengths. */
+    private static class Padded {
+        Padded(int[][] ids, int[] lengths) {
+            this.ids = ids;
+            this.lengths = lengths;
+        }
+        private final int[][] ids;
+        private final int[] lengths;
+    }
+}
diff --git a/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/IndexTagger.java b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/IndexTagger.java
new file mode 100644
index 0000000..4dc92b7
--- /dev/null
+++ b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/IndexTagger.java
@@ -0,0 +1,41 @@
+package org.apache.opennlp.tf.guillaumegenthial;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.nio.charset.StandardCharsets;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Maps tag indices produced by the model back to their string labels, based on
+ * the tag vocabulary file (one tag per line; the line number is the tag id).
+ */
+public class IndexTagger {
+
+    private final Map<Integer, String> idx2Tag = new HashMap<>();
+
+    /**
+     * Reads the tag vocabulary.
+     *
+     * @param vocabTags one tag per line, UTF-8 encoded; the caller owns the stream
+     * @throws IOException if the stream cannot be read
+     */
+    public IndexTagger(InputStream vocabTags) throws IOException {
+        try (BufferedReader in = new BufferedReader(
+                new InputStreamReader(vocabTags, StandardCharsets.UTF_8))) {
+            String tag;
+            int idx = 0;
+            while ((tag = in.readLine()) != null) {
+                idx2Tag.put(idx, tag);
+                idx += 1;
+            }
+        }
+    }
+
+    /** @return the tag label for the given index, or {@code null} if unknown */
+    public String getTag(Integer idx) {
+        return idx2Tag.get(idx);
+    }
+
+    /** @return an unmodifiable view of the index-to-tag mapping */
+    public Map<Integer, String> getIdx2Tag() {
+        return Collections.unmodifiableMap(idx2Tag);
+    }
+
+    /** @return the number of distinct tags in the vocabulary */
+    public int getNumberOfTags() {
+        return idx2Tag.size();
+    }
+}
diff --git a/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/PredictionConfiguration.java b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/PredictionConfiguration.java
new file mode 100644
index 0000000..581c7a1
--- /dev/null
+++ b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/PredictionConfiguration.java
@@ -0,0 +1,40 @@
+package org.apache.opennlp.tf.guillaumegenthial;
+
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+
+/**
+ * Paths to the artifacts needed for prediction: the word, character and tag
+ * vocabulary files and the TensorFlow saved-model directory.
+ */
+public class PredictionConfiguration {
+
+    private final String vocabWords;
+    private final String vocabChars;
+    private final String vocabTags;
+    private final String savedModel;
+
+    public PredictionConfiguration(String vocabWords, String vocabChars, String vocabTags, String savedModel) {
+        this.vocabWords = vocabWords;
+        this.vocabChars = vocabChars;
+        this.vocabTags = vocabTags;
+        this.savedModel = savedModel;
+    }
+
+    public String getVocabWords() {
+        return vocabWords;
+    }
+
+    public String getVocabChars() {
+        return vocabChars;
+    }
+
+    public String getVocabTags() {
+        return vocabTags;
+    }
+
+    /** @return path to the saved-model directory (SavedModelBundle.load takes a path) */
+    public String getSavedModel() {
+        return savedModel;
+    }
+
+    public InputStream getVocabWordsInputStream() throws IOException {
+        return new FileInputStream(getVocabWords());
+    }
+
+    /** Opens the char vocabulary file; added for symmetry with {@link #getVocabWordsInputStream()}. */
+    public InputStream getVocabCharsInputStream() throws IOException {
+        return new FileInputStream(getVocabChars());
+    }
+
+    /** Opens the tag vocabulary file; added for symmetry with {@link #getVocabWordsInputStream()}. */
+    public InputStream getVocabTagsInputStream() throws IOException {
+        return new FileInputStream(getVocabTags());
+    }
+}
diff --git a/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/SequenceTagging.java b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/SequenceTagging.java
new file mode 100644
index 0000000..47de303
--- /dev/null
+++ b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/SequenceTagging.java
@@ -0,0 +1,75 @@
+package org.apache.opennlp.tf.guillaumegenthial;
+
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Session;
+import org.tensorflow.Tensor;
+
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+import java.util.zip.GZIPInputStream;
+
+public class SequenceTagging implements AutoCloseable {
+
+ private final SavedModelBundle model;
+ private final Session session;
+ private final WordIndexer wordIndexer;
+ private final IndexTagger indexTagger;
+
+ public SequenceTagging(PredictionConfiguration config) throws IOException {
+ this.model = SavedModelBundle.load(config.getSavedModel(), "serve");
+ this.session = model.session();
+ this.wordIndexer = new WordIndexer(new GZIPInputStream(new FileInputStream(config.getVocabWords())),
+ new GZIPInputStream(new FileInputStream(config.getVocabChars())));
+ this.indexTagger = new IndexTagger(new GZIPInputStream(new FileInputStream(config.getVocabTags())));
+ }
+
+ public String[] predict(String[] sentence) {
+ TokenIds tokenIds = wordIndexer.toTokenIds(sentence);
+ return predict(tokenIds)[0];
+ }
+
+ public String[][] predict(String[][] sentences) {
+ TokenIds tokenIds = wordIndexer.toTokenIds(sentences);
+ return predict(tokenIds);
+ }
+
+ private String[][] predict(TokenIds tokenIds) {
+ FeedDictionary fd = FeedDictionary.create(tokenIds);
+
+ List<Tensor<?>> run = session.runner()
+ .feed("char_ids:0", fd.getCharIdsTensor())
+ .feed("dropout:0", fd.getDropoutTensor())
+ .feed("sequence_lengths:0", fd.getSentenceLengthsTensor())
+ .feed("word_ids:0", fd.getWordIdsTensor())
+ .feed("word_lengths:0", fd.getWordLengthsTensor())
+ .fetch("proj/logits", 0)
+ .fetch("trans_params", 0).run();
+
+
+ float[][][] logits = new float[fd.getNumberOfSentences()][fd.getMaxSentenceLength()][indexTagger.getNumberOfTags()];
+ run.get(0).copyTo(logits);
+
+ float[][] trans_params = new float[indexTagger.getNumberOfTags()][indexTagger.getNumberOfTags()];
+ run.get(1).copyTo(trans_params);
+
+ //# iterate over the sentences because no batching in vitervi_decode
+ //for logit, sequence_length in zip(logits, sequence_lengths):
+ //List<List<Integer>> viterbi_sequences = new ArrayList<>();
+
+ String[][] returnValue = new String[fd.getNumberOfSentences()][];
+ for (int i=0; i < logits.length; i++) {
+ //logit = logit[:sequence_length] # keep only the valid steps
+ float[][] logit = Arrays.copyOf(logits[i], fd.getSentenceLengths()[i]);
+ returnValue[i] = Viterbi.decode(logit, trans_params).stream().map(indexTagger::getTag).toArray(String[]::new);
+ }
+
+ return returnValue;
+ }
+
+ @Override
+ public void close() throws Exception {
+ session.close();
+ }
+}
diff --git a/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/TokenIds.java b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/TokenIds.java
new file mode 100644
index 0000000..45f8e81
--- /dev/null
+++ b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/TokenIds.java
@@ -0,0 +1,20 @@
+package org.apache.opennlp.tf.guillaumegenthial;
+
+/**
+ * Immutable container pairing the per-token character ids with the word ids
+ * produced for a batch of sentences.
+ */
+public final class TokenIds {
+
+    // [sentence][token][character position] -> char vocabulary id
+    private final int[][][] charIds;
+    // [sentence][token position] -> word vocabulary id
+    private final int[][] wordIds;
+
+    public TokenIds(int[][][] charIds, int[][] wordIds) {
+        this.wordIds = wordIds;
+        this.charIds = charIds;
+    }
+
+    /** @return character ids, indexed by sentence, token and character position */
+    public int[][][] getCharIds() {
+        return charIds;
+    }
+
+    /** @return word ids, indexed by sentence and token position */
+    public int[][] getWordIds() {
+        return wordIds;
+    }
+}
diff --git a/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/Viterbi.java b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/Viterbi.java
new file mode 100644
index 0000000..7ea016f
--- /dev/null
+++ b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/Viterbi.java
@@ -0,0 +1,157 @@
+package org.apache.opennlp.tf.guillaumegenthial;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+public class Viterbi {
+
+ /*
+ """Viterbi the highest scoring sequence of tags outside of TensorFlow.
+ This should only be used at test time.
+ Args:
+ score: A [seq_len, num_tags] matrix of unary potentials.
+ transition_params: A [num_tags, num_tags] matrix of binary potentials.
+ Returns:
+ viterbi: A [seq_len] list of integers containing the highest scoring tag
+ indices.
+ viterbi_score: A float containing the score for the Viterbi sequence.
+ """
+ */
+
+ private static float[][] zeros_like(float[][] matrix) {
+ float[][] returnValue = new float[matrix.length][matrix[0].length];
+ for (int i=0; i<matrix.length; i++)
+ Arrays.fill(returnValue[i], 0.0f);
+ return returnValue;
+ }
+
+ private static int[][] zeros_like(int[] shape) {
+ int[][] returnValue = new int[shape[0]][shape[1]];
+ for (int i=0; i<shape[0]; i++)
+ Arrays.fill(returnValue[i], 0);
+ return returnValue;
+ }
+
+ private static int[] shape(float[][] var) {
+ return new int[] {var.length, var[0].length};
+ }
+
+ private static float[][] expand_dims_axis_one_plus_array(float[] array, float[][] plus) {
+ int[] plus_shape = shape(plus);
+ if (plus_shape[0] != array.length)
+ throw new RuntimeException("Not same shape");
+ float[][] returnValue = new float[plus_shape[0]][plus_shape[1]];
+ for (int i=0; i < array.length; i++) {
+ for (int j=0; j < plus_shape[1]; j++) {
+ returnValue[i][j] = array[i] + plus[i][j];
+ }
+ }
+ return returnValue;
+ }
+
+ private static float[] max_columnwise(float[][] array) {
+ float[] returnValue = new float[array[0].length];
+ for (int col=0; col < array[0].length; col++) {
+ returnValue[col] = Float.MIN_VALUE;
+ for (int row=0; row < array.length; row++) {
+ returnValue[col] = Float.max(returnValue[col],array[row][col]);
+ }
+ }
+
+ return returnValue;
+ }
+
+ private static float max(float[] array) {
+ float returnValue = Float.MIN_VALUE;
+ for (int col=0; col < array.length; col++) {
+ returnValue = Float.max(returnValue, array[col]);
+ }
+ return returnValue;
+ }
+
+ private static int[] argmax_columnwise(float[][] array) {
+ int[] returnValue = new int[array[0].length];
+ for (int col=0; col < array[0].length; col++) {
+ float max = Float.MIN_VALUE;
+ int idx = -1;
+ for (int row=0; row < array.length; row++) {
+ if (Float.compare(array[row][col], max) > 0) {
+ max = array[row][col];
+ idx = row;
+ }
+ }
+ returnValue[col] = idx;
+ }
+ return returnValue;
+ }
+
+ private static int argmax(float[] array) {
+ int returnValue = -1;
+ float max = Float.MIN_VALUE;
+ for (int col=0; col < array.length; col++) {
+ if (Float.compare(array[col], max) > 0) {
+ max = array[col];
+ returnValue = col;
+ }
+ }
+ return returnValue;
+ }
+
+ public static float[] plus(float[] a, float[] b) {
+ if (a.length == b.length) {
+ float[] returnValue = new float[a.length];
+ for(int i = 0; i < a.length; ++i) {
+ returnValue[i] = Float.sum(a[i], b[i]);
+ }
+ return returnValue;
+ } else {
+ throw new IllegalArgumentException("Arrays doesn't have same shape.");
+ }
+ }
+
+ public static List<Integer> decode(float[][] score, float[][] transition_params) {
+ // trellis = np.zeros_like(score)
+ float[][] trellis = zeros_like(score);
+
+ // backpointers = np.zeros_like(score, dtype=np.int32)
+ int[][] backpointers = zeros_like(shape(score));
+
+ // trellis[0] = score[0]
+ trellis[0] = score[0];
+
+ // for t in range(1, score.shape[0]):
+ for (int t=1; t < score.length; t++) {
+ //v = np.expand_dims(trellis[t - 1], 1) + transition_params
+ float[][] v = expand_dims_axis_one_plus_array(trellis[t - 1], transition_params);
+
+ //trellis[t] = score[t] + np.max(v, 0)
+ trellis[t] = plus(score[t], max_columnwise(v));
+
+ //backpointers[t] = np.argmax(v, 0)
+ backpointers[t] = argmax_columnwise(v);
+ }
+
+ // viterbi = [np.argmax(trellis[-1])]
+ List<Integer> viterbi = new ArrayList();
+ viterbi.add(argmax(trellis[trellis.length - 1]));
+
+ // for bp in reversed(backpointers[1:]):
+ for (int i=backpointers.length - 1; i >= 1; i--) {
+ // viterbi.append(bp[viterbi[-1]])
+ int[] bp = backpointers[i];
+ viterbi.add(bp[viterbi.get(viterbi.size() - 1)]);
+ }
+
+ // viterbi.reverse()
+ Collections.reverse(viterbi);
+
+ // viterbi_score = np.max(trellis[-1])
+ // float viterbi_score = max(trellis[trellis.length - 1])) not used!
+
+ // return viterbi, viterbi_score
+ return viterbi;
+ }
+
+}
diff --git a/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/WordIndexer.java b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/WordIndexer.java
new file mode 100644
index 0000000..f6a80e1
--- /dev/null
+++ b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/WordIndexer.java
@@ -0,0 +1,149 @@
+package org.apache.opennlp.tf.guillaumegenthial;
+
+import opennlp.tools.util.StringUtil;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.regex.Pattern;
+
+/**
+ * Converts tokenized sentences into the word and character ids expected by the
+ * model, using the word and character vocabulary files written at training time
+ * (one entry per line; the line number is the id).
+ */
+public class WordIndexer {
+
+    private final Map<Character, Integer> char2idx;
+    private final Map<String, Integer> word2idx;
+
+    /** Vocabulary entry used for words missing from the word vocabulary. */
+    public static final String UNK = "$UNK$";
+    /** Vocabulary entry substituted for words containing digits. */
+    public static final String NUM = "$NUM$";
+    /** Default "outside" tag. */
+    public static final String NONE = "O";
+
+    private final boolean lowerCase = true;
+    private final boolean allowUnk = true;
+
+    // Matches digit groups such as "1,000" or "3.14" anywhere in a word.
+    // NOTE(review): find() maps any word *containing* digits (e.g. "A4") to $NUM$;
+    // confirm this mirrors the preprocessing used at training time.
+    private static final Pattern DIGIT_PATTERN = Pattern.compile("\\d+(,\\d+)*(\\.\\d+)?");
+
+    /**
+     * Reads the word and character vocabularies; the caller owns both streams.
+     *
+     * @param vocabWords one word per line, UTF-8
+     * @param vocabChars one character per line, UTF-8
+     * @throws IOException if a vocabulary cannot be read
+     */
+    public WordIndexer(InputStream vocabWords, InputStream vocabChars) throws IOException {
+        this.word2idx = new HashMap<>();
+        try (BufferedReader in = new BufferedReader(new InputStreamReader(vocabWords, StandardCharsets.UTF_8))) {
+            String word;
+            int idx = 0;
+            while ((word = in.readLine()) != null) {
+                word2idx.put(word, idx);
+                idx += 1;
+            }
+        }
+
+        this.char2idx = new HashMap<>();
+        try (BufferedReader in = new BufferedReader(new InputStreamReader(vocabChars, StandardCharsets.UTF_8))) {
+            String ch;
+            int idx = 0;
+            while ((ch = in.readLine()) != null) {
+                // NOTE(review): an empty line throws here (charAt(0)); the vocab file
+                // is assumed to hold exactly one non-empty character per line.
+                char2idx.put(ch.charAt(0), idx);
+                idx += 1;
+            }
+        }
+    }
+
+    /** Converts a single tokenized sentence into word/char ids. */
+    public TokenIds toTokenIds(String[] tokens) {
+        String[][] sentences = new String[1][];
+        sentences[0] = tokens;
+        return toTokenIds(sentences);
+    }
+
+    /** Converts a batch of tokenized sentences into word/char ids. */
+    public TokenIds toTokenIds(String[][] sentences) {
+        int[][][] charIds = new int[sentences.length][][];
+        int[][] wordIds = new int[sentences.length][];
+
+        for (int i = 0; i < sentences.length; i++) {
+            String[] sentenceWords = sentences[i];
+
+            int[][] sentCharIds = new int[sentenceWords.length][];
+            int[] sentWordIds = new int[sentenceWords.length];
+
+            for (int j = 0; j < sentenceWords.length; j++) {
+                Ids ids = apply(sentenceWords[j]);
+
+                sentCharIds[j] = Arrays.copyOf(ids.getChars(), ids.getChars().length);
+                sentWordIds[j] = ids.getWord();
+            }
+
+            charIds[i] = sentCharIds;
+            wordIds[i] = sentWordIds;
+        }
+
+        return new TokenIds(charIds, wordIds);
+    }
+
+    /**
+     * Maps one token to its (char ids, word id) pair, mirroring the training-time
+     * preprocessing: out-of-vocabulary chars are dropped, the word is lower-cased,
+     * digit-bearing words become {@link #NUM}, unknown words become {@link #UNK}.
+     */
+    private Ids apply(String word) {
+        // 0. get chars of words
+        int[] charIds = new int[word.length()];
+        int skipChars = 0;
+        for (int i = 0; i < word.length(); i++) {
+            char ch = word.charAt(i);
+            // ignore chars out of vocabulary
+            if (char2idx.containsKey(ch))
+                charIds[i - skipChars] = char2idx.get(ch);
+            else
+                skipChars += 1;
+        }
+
+        // 1. preprocess word
+        if (lowerCase) {
+            word = StringUtil.toLowerCase(word);
+        }
+        if (DIGIT_PATTERN.matcher(word).find())
+            word = NUM;
+
+        // 2. get id of word
+        Integer wordId;
+        if (word2idx.containsKey(word)) {
+            wordId = word2idx.get(word);
+        } else {
+            if (allowUnk)
+                wordId = word2idx.get(UNK);
+            else
+                throw new RuntimeException("Unknown word '" + word + "' is not allowed.");
+        }
+
+        // 3. return tuple char ids, word id
+        Ids tokenIds = new Ids();
+        if (skipChars > 0) {
+            // Trim the unused tail left by skipped out-of-vocabulary chars.
+            tokenIds.setChars(Arrays.copyOf(charIds, charIds.length - skipChars));
+        } else {
+            tokenIds.setChars(charIds);
+        }
+        tokenIds.setWord(wordId);
+
+        return tokenIds;
+    }
+
+    /** (char ids, word id) pair for a single token; static - no enclosing-instance ref needed. */
+    public static class Ids {
+
+        private int[] chars;
+        private int word;
+
+        public int[] getChars() {
+            return chars;
+        }
+
+        public void setChars(int[] chars) {
+            this.chars = chars;
+        }
+
+        public int getWord() {
+            return word;
+        }
+
+        public void setWord(int word) {
+            this.word = word;
+        }
+    }
+
+}
diff --git a/tf-ner-poc/src/test/java/com/apache/opennlp/tf/PredictTest.java b/tf-ner-poc/src/test/java/com/apache/opennlp/tf/PredictTest.java
new file mode 100644
index 0000000..8de57c9
--- /dev/null
+++ b/tf-ner-poc/src/test/java/com/apache/opennlp/tf/PredictTest.java
@@ -0,0 +1,32 @@
+package com.apache.opennlp.tf;
+
+import org.apache.opennlp.tf.guillaumegenthial.PredictionConfiguration;
+import org.apache.opennlp.tf.guillaumegenthial.SequenceTagging;
+
+import java.io.IOException;
+
+/**
+ * Manual smoke test for SequenceTagging; run as a plain Java program.
+ * It is not picked up by JUnit (no @Test methods).
+ */
+public class PredictTest {
+
+    public static void main(String[] args) throws Exception {
+
+        // SavedModelBundle.load takes a directory path (String), not a stream.
+        // NOTE(review): getResource(...).getPath() breaks for paths containing
+        // spaces or resources packed in a jar - acceptable for a PoC only.
+        String model = PredictTest.class.getResource("/savedmodel").getPath();
+        // can be changed to File or InputStream
+        String words = PredictTest.class.getResource("/words.txt.gz").getPath();
+        String chars = PredictTest.class.getResource("/chars.txt.gz").getPath();
+        String tags = PredictTest.class.getResource("/tags.txt.gz").getPath();
+
+        PredictionConfiguration config = new PredictionConfiguration(words, chars, tags, model);
+
+        // SequenceTagging holds native TF resources - release them when done.
+        try (SequenceTagging tagger = new SequenceTagging(config)) {
+            String[] tokens = "Stormy Cars ' friend says she also plans to sue Michael Cohen .".split("\\s+");
+            String[] pred = tagger.predict(tokens);
+
+            for (int i = 0; i < tokens.length; i++) {
+                System.out.print(tokens[i] + "/" + pred[i] + " ");
+            }
+            System.out.println();
+        }
+    }
+}
diff --git a/tf-ner-poc/src/test/java/com/apache/opennlp/tf/guillaumegenthial/FeedDictionaryTest.java b/tf-ner-poc/src/test/java/com/apache/opennlp/tf/guillaumegenthial/FeedDictionaryTest.java
new file mode 100644
index 0000000..a2c26a8
--- /dev/null
+++ b/tf-ner-poc/src/test/java/com/apache/opennlp/tf/guillaumegenthial/FeedDictionaryTest.java
@@ -0,0 +1,44 @@
+package com.apache.opennlp.tf.guillaumegenthial;
+
+import org.apache.opennlp.tf.guillaumegenthial.TokenIds;
+import org.apache.opennlp.tf.guillaumegenthial.WordIndexer;
+import org.junit.Assume;
+import org.junit.BeforeClass;
+
+import java.io.InputStream;
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.zip.GZIPInputStream;
+
+// NOTE(review): this class only builds fixtures in @BeforeClass and contains no
+// @Test methods yet - presumably work in progress; add assertions (e.g. on
+// FeedDictionary.create(oneSentence)) or remove the class.
+// NOTE(review): reads "/words.txt" and "/chars.txt" through GZIPInputStream while
+// PredictTest uses "/words.txt.gz" - confirm which resource names actually ship;
+// if they are missing, Assume turns everything into a silent skip.
+public class FeedDictionaryTest {
+
+ private static TokenIds oneSentence;
+ private static TokenIds twoSentences;
+
+ @BeforeClass
+ public static void beforeClass() {
+
+ WordIndexer indexer;
+ try {
+ InputStream words = new GZIPInputStream(WordIndexerTest.class.getResourceAsStream("/words.txt"));
+ InputStream chars = new GZIPInputStream(WordIndexerTest.class.getResourceAsStream("/chars.txt"));
+ indexer = new WordIndexer(words, chars);
+ } catch (Exception ex) {
+ // Missing/unreadable resources -> indexer stays null and the tests are skipped.
+ indexer = null;
+ }
+ Assume.assumeNotNull(indexer);
+
+ String text1 = "Stormy Cars ' friend says she also plans to sue Michael Cohen .";
+ oneSentence = indexer.toTokenIds(text1.split("\\s+"));
+ Assume.assumeNotNull(oneSentence);
+
+ String[] text2 = new String[] {"I wish I was born in Copenhagen Denmark",
+ "Donald Trump died on his way to Tivoli Gardens in Denmark ."};
+ List<String[]> collect = Arrays.stream(text2).map(s -> s.split("\\s+")).collect(Collectors.toList());
+ twoSentences = indexer.toTokenIds(collect.toArray(new String[2][]));
+ Assume.assumeNotNull(twoSentences);
+
+ }
+
+}
diff --git a/tf-ner-poc/src/test/java/com/apache/opennlp/tf/guillaumegenthial/WordIndexerTest.java b/tf-ner-poc/src/test/java/com/apache/opennlp/tf/guillaumegenthial/WordIndexerTest.java
new file mode 100644
index 0000000..b5812fd
--- /dev/null
+++ b/tf-ner-poc/src/test/java/com/apache/opennlp/tf/guillaumegenthial/WordIndexerTest.java
@@ -0,0 +1,132 @@
+package com.apache.opennlp.tf.guillaumegenthial;
+
+import org.apache.opennlp.tf.guillaumegenthial.TokenIds;
+import org.apache.opennlp.tf.guillaumegenthial.WordIndexer;
+import org.junit.Assert;
+import org.junit.Assume;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import java.io.InputStream;
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.zip.GZIPInputStream;
+
+// Tests WordIndexer against the word/char vocabulary shipped as test resources.
+// NOTE(review): all expected id values below are tied to the exact contents of
+// "/words.txt" and "/chars.txt" - regenerating the vocabularies invalidates them.
+// NOTE(review): PredictTest loads "/words.txt.gz"; confirm the resource names.
+public class WordIndexerTest {
+
+ private static WordIndexer indexer;
+
+ @BeforeClass
+ public static void beforeClass() {
+ try {
+ InputStream words = new GZIPInputStream(WordIndexerTest.class.getResourceAsStream("/words.txt"));
+ InputStream chars = new GZIPInputStream(WordIndexerTest.class.getResourceAsStream("/chars.txt"));
+ indexer = new WordIndexer(words, chars);
+ } catch (Exception ex) {
+ // Missing/unreadable resources -> indexer stays null and the tests are skipped.
+ indexer = null;
+ }
+ Assume.assumeNotNull(indexer);
+ }
+
+ @Test
+ public void testToTokenIds_OneSentence() {
+
+ String text = "Stormy Cars ' friend says she also plans to sue Michael Cohen .";
+
+ TokenIds ids = indexer.toTokenIds(text.split("\\s+"));
+
+ Assert.assertEquals("Expect 13 tokenIds", 13, ids.getWordIds()[0].length);
+
+ // Per-token char ids (indexed [sentence][token]).
+ Assert.assertArrayEquals(new int[] {7, 30, 34, 80, 42, 3}, ids.getCharIds()[0][0]);
+ Assert.assertArrayEquals(new int[] {51, 41, 80, 54}, ids.getCharIds()[0][1]);
+ Assert.assertArrayEquals(new int[] {64}, ids.getCharIds()[0][2]);
+ Assert.assertArrayEquals(new int[] {47, 80, 82, 83, 31, 23}, ids.getCharIds()[0][3]);
+ Assert.assertArrayEquals(new int[] {54, 41, 3, 54}, ids.getCharIds()[0][4]);
+ Assert.assertArrayEquals(new int[] {54, 76, 83}, ids.getCharIds()[0][5]);
+ Assert.assertArrayEquals(new int[] {41, 55, 54, 34}, ids.getCharIds()[0][6]);
+ Assert.assertArrayEquals(new int[] {46, 55, 41, 31, 54}, ids.getCharIds()[0][7]);
+ Assert.assertArrayEquals(new int[] {30, 34}, ids.getCharIds()[0][8]);
+ Assert.assertArrayEquals(new int[] {54, 50, 83}, ids.getCharIds()[0][9]);
+ Assert.assertArrayEquals(new int[] {39, 82, 20, 76, 41, 83, 55}, ids.getCharIds()[0][10]);
+ Assert.assertArrayEquals(new int[] {51, 34, 76, 83, 31}, ids.getCharIds()[0][11]);
+ Assert.assertArrayEquals(new int[] {65}, ids.getCharIds()[0][12]);
+
+ // Per-token word ids (indexed [sentence][token]).
+ Assert.assertEquals(2720, ids.getWordIds()[0][0]);
+ Assert.assertEquals(15275,ids.getWordIds()[0][1]);
+ Assert.assertEquals(3256, ids.getWordIds()[0][2]);
+ Assert.assertEquals(11348, ids.getWordIds()[0][3]);
+ Assert.assertEquals(21054, ids.getWordIds()[0][4]);
+ Assert.assertEquals(18337, ids.getWordIds()[0][5]);
+ Assert.assertEquals(7885, ids.getWordIds()[0][6]);
+ Assert.assertEquals(7697, ids.getWordIds()[0][7]);
+ Assert.assertEquals(16601, ids.getWordIds()[0][8]);
+ Assert.assertEquals(2720, ids.getWordIds()[0][9]);
+ Assert.assertEquals(17408, ids.getWordIds()[0][10]);
+ Assert.assertEquals(11541, ids.getWordIds()[0][11]);
+ Assert.assertEquals(2684, ids.getWordIds()[0][12]);
+
+ }
+
+ @Test
+ public void testToTokenIds_TwoSentences() {
+
+ String[] text = new String[] {"I wish I was born in Copenhagen Denmark",
+ "Donald Trump died on his way to Tivoli Gardens in Denmark ."};
+
+ List<String[]> collect = Arrays.stream(text).map(s -> s.split("\\s+")).collect(Collectors.toList());
+
+ TokenIds ids = indexer.toTokenIds(collect.toArray(new String[2][]));
+
+ Assert.assertEquals(8, ids.getWordIds()[0].length);
+ Assert.assertEquals(12, ids.getWordIds()[1].length);
+
+ Assert.assertArrayEquals(new int[] {4}, ids.getCharIds()[0][0]);
+ Assert.assertArrayEquals(new int[] {6, 82, 54, 76}, ids.getCharIds()[0][1]);
+ Assert.assertArrayEquals(new int[] {4}, ids.getCharIds()[0][2]);
+ Assert.assertArrayEquals(new int[] {6, 41, 54}, ids.getCharIds()[0][3]);
+ Assert.assertArrayEquals(new int[] {59, 34, 80, 31}, ids.getCharIds()[0][4]);
+ Assert.assertArrayEquals(new int[] {82, 31}, ids.getCharIds()[0][5]);
+ Assert.assertArrayEquals(new int[] {51, 34, 46, 83, 31, 76, 41, 28, 83, 31}, ids.getCharIds()[0][6]);
+ Assert.assertArrayEquals(new int[] {36, 83, 31, 42, 41, 80, 49}, ids.getCharIds()[0][7]);
+
+ Assert.assertArrayEquals(new int[] {36, 34, 31, 41, 55, 23}, ids.getCharIds()[1][0]);
+ Assert.assertArrayEquals(new int[] {52, 80, 50, 42, 46}, ids.getCharIds()[1][1]);
+ Assert.assertArrayEquals(new int[] {23, 82, 83, 23}, ids.getCharIds()[1][2]);
+ Assert.assertArrayEquals(new int[] {34, 31}, ids.getCharIds()[1][3]);
+ Assert.assertArrayEquals(new int[] {76, 82, 54}, ids.getCharIds()[1][4]);
+ Assert.assertArrayEquals(new int[] {6, 41, 3}, ids.getCharIds()[1][5]);
+ Assert.assertArrayEquals(new int[] {30, 34}, ids.getCharIds()[1][6]);
+ Assert.assertArrayEquals(new int[] {52, 82, 11, 34, 55, 82}, ids.getCharIds()[1][7]);
+ Assert.assertArrayEquals(new int[] {74, 41, 80, 23, 83, 31, 54}, ids.getCharIds()[1][8]);
+ Assert.assertArrayEquals(new int[] {82, 31}, ids.getCharIds()[1][9]);
+ Assert.assertArrayEquals(new int[] {36, 83, 31, 42, 41, 80, 49}, ids.getCharIds()[1][10]);
+ Assert.assertArrayEquals(new int[] {65}, ids.getCharIds()[1][11]);
+
+ Assert.assertEquals(21931, ids.getWordIds()[0][0]);
+ Assert.assertEquals(20473, ids.getWordIds()[0][1]);
+ Assert.assertEquals(21931, ids.getWordIds()[0][2]);
+ Assert.assertEquals(5477, ids.getWordIds()[0][3]);
+ Assert.assertEquals(11538, ids.getWordIds()[0][4]);
+ Assert.assertEquals(21341, ids.getWordIds()[0][5]);
+ Assert.assertEquals(14024, ids.getWordIds()[0][6]);
+ Assert.assertEquals(7420, ids.getWordIds()[0][7]);
+
+ Assert.assertEquals(12492, ids.getWordIds()[1][0]);
+ Assert.assertEquals(2720, ids.getWordIds()[1][1]);
+ Assert.assertEquals(9476, ids.getWordIds()[1][2]);
+ Assert.assertEquals(16537, ids.getWordIds()[1][3]);
+ Assert.assertEquals(18966, ids.getWordIds()[1][4]);
+ Assert.assertEquals(21088, ids.getWordIds()[1][5]);
+ Assert.assertEquals(16601, ids.getWordIds()[1][6]);
+ Assert.assertEquals(2720, ids.getWordIds()[1][7]);
+ Assert.assertEquals(2720, ids.getWordIds()[1][8]);
+ Assert.assertEquals(21341, ids.getWordIds()[1][9]);
+ Assert.assertEquals(7420, ids.getWordIds()[1][10]);
+ Assert.assertEquals(2684, ids.getWordIds()[1][11]);
+
+ }
+
+
+
+}