Merge pull request #10 from thygesen/tfnerpoc

Added a TensorFlow NER prediction PoC
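
The PoC ports the prediction side of the guillaumegenthial sequence
tagging model to Java: a TensorFlow SavedModel is loaded, tokens are
mapped to word and char ids via the bundled vocabularies, and the
logits are decoded into tags with a Java port of viterbi_decode.

A minimal usage sketch (the vocabulary and model paths below are
placeholders; see PredictTest for resolving them from the classpath):

    PredictionConfiguration config = new PredictionConfiguration(
            "words.txt.gz", "chars.txt.gz", "tags.txt.gz", "savedmodel");
    try (SequenceTagging tagger = new SequenceTagging(config)) {
      String[] tags = tagger.predict(new String[] {"I", "live", "in", "Copenhagen", "."});
    }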
diff --git a/tf-ner-poc/pom.xml b/tf-ner-poc/pom.xml
new file mode 100644
index 0000000..71e9620
--- /dev/null
+++ b/tf-ner-poc/pom.xml
@@ -0,0 +1,49 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <modelVersion>4.0.0</modelVersion>
+
+    <groupId>org.apache.opennlp</groupId>
+    <artifactId>tf-ner</artifactId>
+    <version>1.0-SNAPSHOT</version>
+
+    <properties>
+        <tensorflow.version>1.7.0</tensorflow.version>
+    </properties>
+
+    <dependencies>
+        <dependency>
+            <groupId>org.tensorflow</groupId>
+            <artifactId>tensorflow</artifactId>
+            <version>${tensorflow.version}</version>
+        </dependency>
+
+        <dependency>
+            <groupId>org.apache.opennlp</groupId>
+            <artifactId>opennlp-tools</artifactId>
+            <version>[1.8.4,)</version>
+        </dependency>
+
+        <dependency>
+            <groupId>junit</groupId>
+            <artifactId>junit</artifactId>
+            <version>4.12</version>
+            <scope>test</scope>
+        </dependency>
+    </dependencies>
+
+    <build>
+        <plugins>
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-compiler-plugin</artifactId>
+                <configuration>
+                    <source>1.8</source>
+                    <target>1.8</target>
+                </configuration>
+            </plugin>
+        </plugins>
+    </build>
+
+</project>
\ No newline at end of file
diff --git a/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/FeedDictionary.java b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/FeedDictionary.java
new file mode 100644
index 0000000..39265b7
--- /dev/null
+++ b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/FeedDictionary.java
@@ -0,0 +1,144 @@
+package org.apache.opennlp.tf.guillaumegenthial;
+
+import org.tensorflow.Tensor;
+
+import java.util.Arrays;
+
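+/**
+ * Batch of padded word and char ids plus the true sentence and word lengths,
+ * exposed as TensorFlow Tensors matching the model's input placeholders.
+ */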
+public class FeedDictionary {
+
+  static final int PAD_VALUE = 0;
+
+  private final float dropout;
+  private final int[][][] charIds;
+  private final int[][] wordLengths;
+  private final int[][] wordIds;
+  private final int[] sentenceLengths;
+  private final int maxSentenceLength;
+  private final int maxCharLength;
+  private final int numberOfSentences;
+
+  public float getDropout() {
+    return dropout;
+  }
+
+  public int[][][] getCharIds() {
+    return charIds;
+  }
+
+  public int[][] getWordLengths() {
+    return wordLengths;
+  }
+
+  public int[][] getWordIds() {
+    return wordIds;
+  }
+
+  public int[] getSentenceLengths() {
+    return sentenceLengths;
+  }
+
+  public int getMaxSentenceLength() {
+    return maxSentenceLength;
+  }
+
+  public int getMaxCharLength() {
+    return maxCharLength;
+  }
+
+  public int getNumberOfSentences() {
+    return numberOfSentences;
+  }
+
+  public Tensor<Integer> getSentenceLengthsTensor() {
+    return Tensor.create(sentenceLengths, Integer.class);
+  }
+
+  public Tensor<Float> getDropoutTensor() {
+    return Tensor.create(dropout, Float.class);
+  }
+
+  public Tensor<Integer> getCharIdsTensor() {
+    return Tensor.create(charIds, Integer.class);
+  }
+
+  public Tensor<Integer> getWordLengthsTensor() {
+    return Tensor.create(wordLengths, Integer.class);
+  }
+
+  public Tensor<Integer> getWordIdsTensor() {
+    return Tensor.create(wordIds, Integer.class);
+  }
+
+  private FeedDictionary(final float dropout,
+                         final int[][][] charIds,
+                         final int[][] wordLengths,
+                         final int[][] wordIds,
+                         final int[] sentenceLengths,
+                         final int maxSentenceLength,
+                         final int maxCharLength,
+                         final int numberOfSentences) {
+
+    this.dropout = dropout;
+    this.charIds = charIds;
+    this.wordLengths = wordLengths;
+    this.wordIds = wordIds;
+    this.sentenceLengths = sentenceLengths;
+    this.maxSentenceLength = maxSentenceLength;
+    this.maxCharLength = maxCharLength;
+    this.numberOfSentences = numberOfSentences;
+
+  }
+
+  // Builds a feed dictionary for a batch of sentences, padding word and char ids.
+  public static FeedDictionary create(TokenIds sentences) {
+
+    int numberOfSentences = sentences.getWordIds().length;
+
+    int[][][] charIds = new int[numberOfSentences][][];
+    int[][] wordLengths = new int[numberOfSentences][];
+
+    int maxSentenceLength = Arrays.stream(sentences.getWordIds()).map(s -> s.length).reduce(Integer::max).get();
+    Padded paddedSentences = padArrays(sentences.getWordIds(), maxSentenceLength);
+    int[][] wordIds = paddedSentences.ids;
+    int[] sentenceLengths = paddedSentences.lengths;
+
+    int maxCharLength = Arrays.stream(sentences.getCharIds()).flatMap(s -> Arrays.stream(s).map(c -> c.length)).reduce(Integer::max).get();
+    for (int i=0; i < numberOfSentences; i++) {
+      Padded paddedWords = padArrays(sentences.getCharIds()[i], maxCharLength);
+      charIds[i] = paddedWords.ids;
+      wordLengths[i] = paddedWords.lengths;
+    }
+
+    return new FeedDictionary(1.0f, charIds, wordLengths, wordIds, sentenceLengths, maxSentenceLength, maxCharLength, numberOfSentences);
+
+  }
+
+  private static Padded padArrays(int[][] ids, int length) {
+
+    int[][] paddedIds = new int[ids.length][];
+    int[] lengths = new int[ids.length];
+
+    for (int i = 0; i < ids.length; i++) {
+      int[] src = ids[i];
+      int[] dest = new int[length];
+      System.arraycopy(src, 0, dest, 0, src.length);
+      // new int[] is already zero-filled; be explicit in case PAD_VALUE changes
+      Arrays.fill(dest, src.length, length, PAD_VALUE);
+      paddedIds[i] = dest;
+      lengths[i] = src.length;
+    }
+
+    return new Padded(paddedIds, lengths);
+  }
+
+  private static class Padded {
+
+    private final int[][] ids;
+    private final int[] lengths;
+
+    Padded(int[][] ids, int[] lengths) {
+      this.ids = ids;
+      this.lengths = lengths;
+    }
+  }
+}
diff --git a/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/IndexTagger.java b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/IndexTagger.java
new file mode 100644
index 0000000..4dc92b7
--- /dev/null
+++ b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/IndexTagger.java
@@ -0,0 +1,41 @@
+package org.apache.opennlp.tf.guillaumegenthial;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.nio.charset.StandardCharsets;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
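+/**
+ * Maps tag indices predicted by the model back to tag names, read one
+ * tag per line from the tags vocabulary file.
+ */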
+public class IndexTagger {
+
+  private final Map<Integer, String> idx2Tag = new HashMap<>();
+
+  public IndexTagger(InputStream vocabTags) throws IOException {
+    try (BufferedReader in = new BufferedReader(
+            new InputStreamReader(vocabTags, StandardCharsets.UTF_8))) {
+      String tag;
+      int idx = 0;
+      while ((tag = in.readLine()) != null) {
+        idx2Tag.put(idx, tag);
+        idx += 1;
+      }
+    }
+
+  }
+
+  public String getTag(Integer idx) {
+    return idx2Tag.get(idx);
+  }
+
+  public Map<Integer, String> getIdx2Tag() {
+    return Collections.unmodifiableMap(idx2Tag);
+  }
+
+  public int getNumberOfTags() {
+    return idx2Tag.size();
+  }
+
+}
diff --git a/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/PredictionConfiguration.java b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/PredictionConfiguration.java
new file mode 100644
index 0000000..581c7a1
--- /dev/null
+++ b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/PredictionConfiguration.java
@@ -0,0 +1,40 @@
+package org.apache.opennlp.tf.guillaumegenthial;
+
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+
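+/**
+ * Paths to the word, char and tag vocabularies and to the exported
+ * TensorFlow SavedModel directory.
+ */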
+public class PredictionConfiguration {
+
+  private String vocabWords;
+  private String vocabChars;
+  private String vocabTags;
+  private String savedModel;
+
+  public PredictionConfiguration(String vocabWords, String vocabChars, String vocabTags, String savedModel) {
+    this.vocabWords = vocabWords;
+    this.vocabChars = vocabChars;
+    this.vocabTags = vocabTags;
+    this.savedModel = savedModel;
+  }
+
+  public String getVocabWords() {
+    return vocabWords;
+  }
+
+  public String getVocabChars() {
+    return vocabChars;
+  }
+
+  public String getVocabTags() {
+    return vocabTags;
+  }
+
+  public String getSavedModel() {
+    return savedModel;
+  }
+
+  public InputStream getVocabWordsInputStream() throws IOException {
+    return new FileInputStream(getVocabWords());
+  }
+}
diff --git a/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/SequenceTagging.java b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/SequenceTagging.java
new file mode 100644
index 0000000..47de303
--- /dev/null
+++ b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/SequenceTagging.java
@@ -0,0 +1,75 @@
+package org.apache.opennlp.tf.guillaumegenthial;
+
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Session;
+import org.tensorflow.Tensor;
+
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+import java.util.zip.GZIPInputStream;
+
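+/**
+ * Tags token sequences with a TensorFlow SavedModel: tokens are mapped to
+ * word/char ids, fed to the network, and the fetched logits are decoded
+ * into tag names with {@link Viterbi}.
+ */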
+public class SequenceTagging implements AutoCloseable {
+
+  private final SavedModelBundle model;
+  private final Session session;
+  private final WordIndexer wordIndexer;
+  private final IndexTagger indexTagger;
+
+  public SequenceTagging(PredictionConfiguration config) throws IOException {
+    this.model = SavedModelBundle.load(config.getSavedModel(), "serve");
+    this.session = model.session();
+    this.wordIndexer = new WordIndexer(new GZIPInputStream(new FileInputStream(config.getVocabWords())),
+            new GZIPInputStream(new FileInputStream(config.getVocabChars())));
+    this.indexTagger = new IndexTagger(new GZIPInputStream(new FileInputStream(config.getVocabTags())));
+  }
+
+  public String[] predict(String[] sentence) {
+    TokenIds tokenIds = wordIndexer.toTokenIds(sentence);
+    return predict(tokenIds)[0];
+  }
+
+  public String[][] predict(String[][] sentences) {
+    TokenIds tokenIds = wordIndexer.toTokenIds(sentences);
+    return predict(tokenIds);
+  }
+
+  private String[][] predict(TokenIds tokenIds) {
+    FeedDictionary fd = FeedDictionary.create(tokenIds);
+
+    // Tensors wrap native memory, so close them once the run is done.
+    try (Tensor<Integer> charIds = fd.getCharIdsTensor();
+         Tensor<Float> dropout = fd.getDropoutTensor();
+         Tensor<Integer> sentenceLengths = fd.getSentenceLengthsTensor();
+         Tensor<Integer> wordIds = fd.getWordIdsTensor();
+         Tensor<Integer> wordLengths = fd.getWordLengthsTensor()) {
+
+      List<Tensor<?>> run = session.runner()
+              .feed("char_ids:0", charIds)
+              .feed("dropout:0", dropout)
+              .feed("sequence_lengths:0", sentenceLengths)
+              .feed("word_ids:0", wordIds)
+              .feed("word_lengths:0", wordLengths)
+              .fetch("proj/logits", 0)
+              .fetch("trans_params", 0).run();
+
+      float[][][] logits = new float[fd.getNumberOfSentences()][fd.getMaxSentenceLength()][indexTagger.getNumberOfTags()];
+      run.get(0).copyTo(logits);
+
+      float[][] trans_params = new float[indexTagger.getNumberOfTags()][indexTagger.getNumberOfTags()];
+      run.get(1).copyTo(trans_params);
+
+      run.forEach(Tensor::close);
+
+      // # iterate over the sentences because no batching in viterbi_decode
+      // for logit, sequence_length in zip(logits, sequence_lengths):
+      String[][] returnValue = new String[fd.getNumberOfSentences()][];
+      for (int i = 0; i < logits.length; i++) {
+        // logit = logit[:sequence_length] # keep only the valid steps
+        float[][] logit = Arrays.copyOf(logits[i], fd.getSentenceLengths()[i]);
+        returnValue[i] = Viterbi.decode(logit, trans_params).stream().map(indexTagger::getTag).toArray(String[]::new);
+      }
+
+      return returnValue;
+    }
+  }
+
+  @Override
+  public void close() {
+    // SavedModelBundle.close() releases both the session and the graph.
+    model.close();
+  }
+}
diff --git a/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/TokenIds.java b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/TokenIds.java
new file mode 100644
index 0000000..45f8e81
--- /dev/null
+++ b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/TokenIds.java
@@ -0,0 +1,20 @@
+package org.apache.opennlp.tf.guillaumegenthial;
+
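+/**
+ * Word ids and per-word char ids for a batch of sentences, indexed as
+ * [sentence][word] and [sentence][word][char] respectively.
+ */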
+public final class TokenIds {
+
+  private final int[][][] charIds;
+  private final int[][] wordIds;
+
+  public TokenIds(int[][][] charIds, int[][] wordIds) {
+    this.charIds = charIds;
+    this.wordIds = wordIds;
+  }
+
+  public int[][][] getCharIds() {
+    return charIds;
+  }
+
+  public int[][] getWordIds() {
+    return wordIds;
+  }
+}
diff --git a/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/Viterbi.java b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/Viterbi.java
new file mode 100644
index 0000000..7ea016f
--- /dev/null
+++ b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/Viterbi.java
@@ -0,0 +1,157 @@
+package org.apache.opennlp.tf.guillaumegenthial;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
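+/**
+ * Plain Java port of the NumPy viterbi_decode used by the original Python
+ * implementation; the quoted docstring below describes its contract.
+ */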
+public class Viterbi {
+
+  /*
+  """Viterbi the highest scoring sequence of tags outside of TensorFlow.
+  This should only be used at test time.
+  Args:
+    score: A [seq_len, num_tags] matrix of unary potentials.
+    transition_params: A [num_tags, num_tags] matrix of binary potentials.
+  Returns:
+    viterbi: A [seq_len] list of integers containing the highest scoring tag
+        indices.
+    viterbi_score: A float containing the score for the Viterbi sequence.
+  """
+   */
+
+  private static float[][] zeros_like(float[][] matrix) {
+    // Java arrays are zero-initialized on allocation.
+    return new float[matrix.length][matrix[0].length];
+  }
+
+  private static int[][] zeros_like(int[] shape) {
+    return new int[shape[0]][shape[1]];
+  }
+
+  private static int[] shape(float[][] var) {
+    return new int[] {var.length, var[0].length};
+  }
+
+  private static float[][] expand_dims_axis_one_plus_array(float[] array, float[][] plus) {
+    int[] plus_shape = shape(plus);
+    if (plus_shape[0] != array.length)
+      throw new IllegalArgumentException("Arrays don't have the same shape.");
+    float[][] returnValue = new float[plus_shape[0]][plus_shape[1]];
+    for (int i=0; i < array.length; i++) {
+      for (int j=0; j < plus_shape[1]; j++) {
+        returnValue[i][j] = array[i] + plus[i][j];
+      }
+    }
+    return returnValue;
+  }
+
+  private static float[] max_columnwise(float[][] array) {
+    float[] returnValue = new float[array[0].length];
+    for (int col=0; col < array[0].length; col++) {
+      // MIN_VALUE is the smallest positive float, not the smallest value
+      returnValue[col] = Float.NEGATIVE_INFINITY;
+      for (int row=0; row < array.length; row++) {
+        returnValue[col] = Float.max(returnValue[col],array[row][col]);
+      }
+    }
+
+    return returnValue;
+  }
+
+  private static float max(float[] array) {
+    float returnValue = Float.NEGATIVE_INFINITY;
+    for (int col=0; col < array.length; col++) {
+        returnValue = Float.max(returnValue, array[col]);
+    }
+    return returnValue;
+  }
+
+  private static int[] argmax_columnwise(float[][] array) {
+    int[] returnValue = new int[array[0].length];
+    for (int col=0; col < array[0].length; col++) {
+      float max = Float.NEGATIVE_INFINITY;
+      int idx = -1;
+      for (int row=0; row < array.length; row++) {
+        if (Float.compare(array[row][col], max) > 0) {
+          max = array[row][col];
+          idx = row;
+        }
+      }
+      returnValue[col] = idx;
+    }
+    return returnValue;
+  }
+
+  private static int argmax(float[] array) {
+    int returnValue = -1;
+    float max = Float.NEGATIVE_INFINITY;
+    for (int col=0; col < array.length; col++) {
+      if (Float.compare(array[col], max) > 0) {
+        max = array[col];
+        returnValue = col;
+      }
+    }
+    return returnValue;
+  }
+
+  public static float[] plus(float[] a, float[] b) {
+    if (a.length == b.length) {
+      float[] returnValue = new float[a.length];
+      for(int i = 0; i < a.length; ++i) {
+        returnValue[i] = Float.sum(a[i], b[i]);
+      }
+      return returnValue;
+    } else {
+      throw new IllegalArgumentException("Arrays don't have the same shape.");
+    }
+  }
+
+  public static List<Integer> decode(float[][] score, float[][] transition_params) {
+    // trellis = np.zeros_like(score)
+    float[][] trellis = zeros_like(score);
+
+    // backpointers = np.zeros_like(score, dtype=np.int32)
+    int[][] backpointers = zeros_like(shape(score));
+
+    // trellis[0] = score[0]
+    trellis[0] = score[0];
+
+    // for t in range(1, score.shape[0]):
+    for (int t=1; t < score.length; t++) {
+      //v = np.expand_dims(trellis[t - 1], 1) + transition_params
+      float[][] v = expand_dims_axis_one_plus_array(trellis[t - 1], transition_params);
+
+      //trellis[t] = score[t] + np.max(v, 0)
+      trellis[t] = plus(score[t], max_columnwise(v));
+
+      //backpointers[t] = np.argmax(v, 0)
+      backpointers[t] = argmax_columnwise(v);
+    }
+
+    // viterbi = [np.argmax(trellis[-1])]
+    List<Integer> viterbi = new ArrayList<>();
+    viterbi.add(argmax(trellis[trellis.length - 1]));
+
+    // for bp in reversed(backpointers[1:]):
+    for (int i=backpointers.length - 1; i >= 1; i--) {
+      // viterbi.append(bp[viterbi[-1]])
+      int[] bp = backpointers[i];
+      viterbi.add(bp[viterbi.get(viterbi.size() - 1)]);
+    }
+
+    // viterbi.reverse()
+    Collections.reverse(viterbi);
+
+    // viterbi_score = np.max(trellis[-1])
+    // float viterbi_score = max(trellis[trellis.length - 1]); // not used
+
+    // return viterbi, viterbi_score
+    return viterbi;
+  }
+
+}
diff --git a/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/WordIndexer.java b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/WordIndexer.java
new file mode 100644
index 0000000..f6a80e1
--- /dev/null
+++ b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/WordIndexer.java
@@ -0,0 +1,149 @@
+package org.apache.opennlp.tf.guillaumegenthial;
+
+import opennlp.tools.util.StringUtil;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.regex.Pattern;
+
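+/**
+ * Maps tokens to word and char ids using the word and char vocabularies.
+ * Words are lowercased, number tokens are replaced by $NUM$, and unknown
+ * words fall back to $UNK$ (if allowed).
+ */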
+public class WordIndexer {
+
+  private final Map<Character, Integer> char2idx;
+  private final Map<String, Integer> word2idx;
+
+  public static final String UNK = "$UNK$";
+  public static final String NUM = "$NUM$";
+  public static final String NONE = "O";
+
+  //private boolean useChars = true;
+  private boolean lowerCase = true;
+  private boolean allowUnk = true;
+
+  private Pattern digitPattern = Pattern.compile("\\d+(,\\d+)*(\\.\\d+)?");
+
+  public WordIndexer(InputStream vocabWords, InputStream vocabChars) throws IOException {
+    this.word2idx = new HashMap<>();
+    try (BufferedReader in = new BufferedReader(new InputStreamReader(vocabWords, StandardCharsets.UTF_8))) {
+      String word;
+      int idx = 0;
+      while ((word = in.readLine()) != null) {
+        word2idx.put(word, idx);
+        idx += 1;
+      }
+    }
+
+    this.char2idx = new HashMap<>();
+    try (BufferedReader in = new BufferedReader(new InputStreamReader(vocabChars, StandardCharsets.UTF_8))) {
+      String ch;
+      int idx = 0;
+      while ((ch = in.readLine()) != null) {
+        char2idx.put(ch.charAt(0), idx);
+        idx += 1;
+      }
+    }
+
+  }
+
+  public TokenIds toTokenIds(String[] tokens) {
+    String[][] sentences = new String[1][];
+    sentences[0] = tokens;
+    return toTokenIds(sentences);
+  }
+
+  public TokenIds toTokenIds(String[][] sentences) {
+    int[][][] charIds = new int[sentences.length][][];
+    int[][] wordIds = new int[sentences.length][];
+
+    for (int i = 0; i < sentences.length; i++) {
+      String[] sentenceWords = sentences[i];
+
+      int[][] sentcharIds = new int[sentenceWords.length][];
+      int[] sentwordIds = new int[sentenceWords.length];
+
+      for (int j=0; j < sentenceWords.length; j++) {
+        Ids ids = apply(sentenceWords[j]);
+
+        sentcharIds[j] = Arrays.copyOf(ids.getChars(), ids.getChars().length);
+        sentwordIds[j] = ids.getWord();
+      }
+
+      charIds[i] = sentcharIds;
+      wordIds[i] = sentwordIds;
+    }
+
+    return new TokenIds(charIds, wordIds);
+  }
+
+
+  private Ids apply(String word) {
+    // 0. get chars of words
+    int[] charIds = new int[word.length()];
+    int skipChars = 0;
+    for (int i = 0; i < word.length(); i++) {
+      char ch = word.charAt(i);
+      // ignore chars out of vocabulary
+      if (char2idx.containsKey(ch))
+        charIds[i - skipChars] = char2idx.get(ch);
+      else
+        skipChars += 1;
+    }
+
+    // 1. preprocess word
+    if (lowerCase) {
+      word = StringUtil.toLowerCase(word);
+    }
+    // map pure number tokens (e.g. "1,000.50") to the NUM pseudo-word
+    if (digitPattern.matcher(word).matches())
+      word = NUM;
+
+    // 2. get id of word
+    Integer wordId;
+    if (word2idx.containsKey(word)) {
+      wordId = word2idx.get(word);
+    } else {
+      if (allowUnk)
+        wordId = word2idx.get(UNK);
+      else
+        throw new RuntimeException("Unknown word '" + word + "' is not allowed.");
+    }
+
+    // 3. return tuple char ids, word id
+    Ids tokenIds = new Ids();
+    if (skipChars > 0) {
+      tokenIds.setChars(Arrays.copyOf(charIds, charIds.length - skipChars));
+    } else {
+      tokenIds.setChars(charIds);
+    }
+    tokenIds.setWord(wordId);
+
+    return tokenIds;
+  }
+
+
+  public static class Ids {
+
+    private int[] chars;
+    private int word;
+
+    public int[] getChars() {
+      return chars;
+    }
+
+    public void setChars(int[] chars) {
+      this.chars = chars;
+    }
+
+    public int getWord() {
+      return word;
+    }
+
+    public void setWord(int word) {
+      this.word = word;
+    }
+  }
+
+}
diff --git a/tf-ner-poc/src/test/java/com/apache/opennlp/tf/PredictTest.java b/tf-ner-poc/src/test/java/com/apache/opennlp/tf/PredictTest.java
new file mode 100644
index 0000000..8de57c9
--- /dev/null
+++ b/tf-ner-poc/src/test/java/com/apache/opennlp/tf/PredictTest.java
@@ -0,0 +1,32 @@
+package com.apache.opennlp.tf;
+
+import org.apache.opennlp.tf.guillaumegenthial.PredictionConfiguration;
+import org.apache.opennlp.tf.guillaumegenthial.SequenceTagging;
+
+import java.io.IOException;
+
+public class PredictTest {
+
+  public static void main(String[] args) throws IOException {
+
+    // SavedModelBundle.load takes a String path, so resolve the resources to file system paths.
+    String model = PredictTest.class.getResource("/savedmodel").getPath();
+    // The vocabularies could also be passed as File or InputStream.
+    String words = PredictTest.class.getResource("/words.txt.gz").getPath();
+    String chars = PredictTest.class.getResource("/chars.txt.gz").getPath();
+    String tags = PredictTest.class.getResource("/tags.txt.gz").getPath();
+
+    PredictionConfiguration config = new PredictionConfiguration(words, chars, tags, model);
+
+    // close the tagger to release the native TensorFlow resources
+    try (SequenceTagging tagger = new SequenceTagging(config)) {
+      String[] tokens = "Stormy Cars ' friend says she also plans to sue Michael Cohen .".split("\\s+");
+      String[] pred = tagger.predict(tokens);
+
+      for (int i = 0; i < tokens.length; i++) {
+        System.out.print(tokens[i] + "/" + pred[i] + " ");
+      }
+      System.out.println();
+    }
+  }
+}
diff --git a/tf-ner-poc/src/test/java/com/apache/opennlp/tf/guillaumegenthial/FeedDictionaryTest.java b/tf-ner-poc/src/test/java/com/apache/opennlp/tf/guillaumegenthial/FeedDictionaryTest.java
new file mode 100644
index 0000000..a2c26a8
--- /dev/null
+++ b/tf-ner-poc/src/test/java/com/apache/opennlp/tf/guillaumegenthial/FeedDictionaryTest.java
@@ -0,0 +1,44 @@
+package com.apache.opennlp.tf.guillaumegenthial;
+
+import org.apache.opennlp.tf.guillaumegenthial.TokenIds;
+import org.apache.opennlp.tf.guillaumegenthial.WordIndexer;
+import org.junit.Assume;
+import org.junit.BeforeClass;
+
+import java.io.InputStream;
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.zip.GZIPInputStream;
+
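+/**
+ * Fixture setup for FeedDictionary: builds TokenIds for a single sentence
+ * and for a two-sentence batch. Note: this class has no @Test methods yet.
+ */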
+public class FeedDictionaryTest {
+
+  private static TokenIds oneSentence;
+  private static TokenIds twoSentences;
+
+  @BeforeClass
+  public static void beforeClass() {
+
+    WordIndexer indexer;
+    try {
+      InputStream words = new GZIPInputStream(FeedDictionaryTest.class.getResourceAsStream("/words.txt"));
+      InputStream chars = new GZIPInputStream(FeedDictionaryTest.class.getResourceAsStream("/chars.txt"));
+      indexer = new WordIndexer(words, chars);
+    } catch (Exception ex) {
+      indexer = null;
+    }
+    Assume.assumeNotNull(indexer);
+
+    String text1 = "Stormy Cars ' friend says she also plans to sue Michael Cohen .";
+    oneSentence = indexer.toTokenIds(text1.split("\\s+"));
+    Assume.assumeNotNull(oneSentence);
+
+    String[] text2 = new String[] {"I wish I was born in Copenhagen Denmark",
+            "Donald Trump died on his way to Tivoli Gardens in Denmark ."};
+    List<String[]> collect = Arrays.stream(text2).map(s -> s.split("\\s+")).collect(Collectors.toList());
+    twoSentences = indexer.toTokenIds(collect.toArray(new String[2][]));
+    Assume.assumeNotNull(twoSentences);
+
+  }
+
+}
diff --git a/tf-ner-poc/src/test/java/com/apache/opennlp/tf/guillaumegenthial/WordIndexerTest.java b/tf-ner-poc/src/test/java/com/apache/opennlp/tf/guillaumegenthial/WordIndexerTest.java
new file mode 100644
index 0000000..b5812fd
--- /dev/null
+++ b/tf-ner-poc/src/test/java/com/apache/opennlp/tf/guillaumegenthial/WordIndexerTest.java
@@ -0,0 +1,132 @@
+package com.apache.opennlp.tf.guillaumegenthial;
+
+import org.apache.opennlp.tf.guillaumegenthial.TokenIds;
+import org.apache.opennlp.tf.guillaumegenthial.WordIndexer;
+import org.junit.Assert;
+import org.junit.Assume;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import java.io.InputStream;
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.zip.GZIPInputStream;
+
+public class WordIndexerTest {
+
+  private static WordIndexer indexer;
+
+  @BeforeClass
+  public static void beforeClass() {
+    try {
+      InputStream words = new GZIPInputStream(WordIndexerTest.class.getResourceAsStream("/words.txt"));
+      InputStream chars = new GZIPInputStream(WordIndexerTest.class.getResourceAsStream("/chars.txt"));
+      indexer = new WordIndexer(words, chars);
+    } catch (Exception ex) {
+      indexer = null;
+    }
+    Assume.assumeNotNull(indexer);
+  }
+
+  @Test
+  public void testToTokenIds_OneSentence() {
+
+    String text = "Stormy Cars ' friend says she also plans to sue Michael Cohen .";
+
+    TokenIds ids = indexer.toTokenIds(text.split("\\s+"));
+
+    Assert.assertEquals("Expect 13 tokenIds", 13, ids.getWordIds()[0].length);
+
+    Assert.assertArrayEquals(new int[] {7, 30, 34, 80, 42, 3}, ids.getCharIds()[0][0]);
+    Assert.assertArrayEquals(new int[] {51, 41, 80, 54}, ids.getCharIds()[0][1]);
+    Assert.assertArrayEquals(new int[] {64}, ids.getCharIds()[0][2]);
+    Assert.assertArrayEquals(new int[] {47, 80, 82, 83, 31, 23}, ids.getCharIds()[0][3]);
+    Assert.assertArrayEquals(new int[] {54, 41, 3, 54}, ids.getCharIds()[0][4]);
+    Assert.assertArrayEquals(new int[] {54, 76, 83}, ids.getCharIds()[0][5]);
+    Assert.assertArrayEquals(new int[] {41, 55, 54, 34}, ids.getCharIds()[0][6]);
+    Assert.assertArrayEquals(new int[] {46, 55, 41, 31, 54}, ids.getCharIds()[0][7]);
+    Assert.assertArrayEquals(new int[] {30, 34}, ids.getCharIds()[0][8]);
+    Assert.assertArrayEquals(new int[] {54, 50, 83}, ids.getCharIds()[0][9]);
+    Assert.assertArrayEquals(new int[] {39, 82, 20, 76, 41, 83, 55}, ids.getCharIds()[0][10]);
+    Assert.assertArrayEquals(new int[] {51, 34, 76, 83, 31}, ids.getCharIds()[0][11]);
+    Assert.assertArrayEquals(new int[] {65}, ids.getCharIds()[0][12]);
+
+    Assert.assertEquals(2720, ids.getWordIds()[0][0]);
+    Assert.assertEquals(15275,ids.getWordIds()[0][1]);
+    Assert.assertEquals(3256, ids.getWordIds()[0][2]);
+    Assert.assertEquals(11348, ids.getWordIds()[0][3]);
+    Assert.assertEquals(21054, ids.getWordIds()[0][4]);
+    Assert.assertEquals(18337, ids.getWordIds()[0][5]);
+    Assert.assertEquals(7885, ids.getWordIds()[0][6]);
+    Assert.assertEquals(7697, ids.getWordIds()[0][7]);
+    Assert.assertEquals(16601, ids.getWordIds()[0][8]);
+    Assert.assertEquals(2720, ids.getWordIds()[0][9]);
+    Assert.assertEquals(17408, ids.getWordIds()[0][10]);
+    Assert.assertEquals(11541, ids.getWordIds()[0][11]);
+    Assert.assertEquals(2684, ids.getWordIds()[0][12]);
+
+  }
+
+  @Test
+  public void testToTokenIds_TwoSentences() {
+
+    String[] text = new String[] {"I wish I was born in Copenhagen Denmark",
+            "Donald Trump died on his way to Tivoli Gardens in Denmark ."};
+
+    List<String[]> collect = Arrays.stream(text).map(s -> s.split("\\s+")).collect(Collectors.toList());
+
+    TokenIds ids = indexer.toTokenIds(collect.toArray(new String[2][]));
+
+    Assert.assertEquals(8, ids.getWordIds()[0].length);
+    Assert.assertEquals(12, ids.getWordIds()[1].length);
+
+    Assert.assertArrayEquals(new int[] {4}, ids.getCharIds()[0][0]);
+    Assert.assertArrayEquals(new int[] {6, 82, 54, 76}, ids.getCharIds()[0][1]);
+    Assert.assertArrayEquals(new int[] {4}, ids.getCharIds()[0][2]);
+    Assert.assertArrayEquals(new int[] {6, 41, 54}, ids.getCharIds()[0][3]);
+    Assert.assertArrayEquals(new int[] {59, 34, 80, 31}, ids.getCharIds()[0][4]);
+    Assert.assertArrayEquals(new int[] {82, 31}, ids.getCharIds()[0][5]);
+    Assert.assertArrayEquals(new int[] {51, 34, 46, 83, 31, 76, 41, 28, 83, 31}, ids.getCharIds()[0][6]);
+    Assert.assertArrayEquals(new int[] {36, 83, 31, 42, 41, 80, 49}, ids.getCharIds()[0][7]);
+
+    Assert.assertArrayEquals(new int[] {36, 34, 31, 41, 55, 23}, ids.getCharIds()[1][0]);
+    Assert.assertArrayEquals(new int[] {52, 80, 50, 42, 46}, ids.getCharIds()[1][1]);
+    Assert.assertArrayEquals(new int[] {23, 82, 83, 23}, ids.getCharIds()[1][2]);
+    Assert.assertArrayEquals(new int[] {34, 31}, ids.getCharIds()[1][3]);
+    Assert.assertArrayEquals(new int[] {76, 82, 54}, ids.getCharIds()[1][4]);
+    Assert.assertArrayEquals(new int[] {6, 41, 3}, ids.getCharIds()[1][5]);
+    Assert.assertArrayEquals(new int[] {30, 34}, ids.getCharIds()[1][6]);
+    Assert.assertArrayEquals(new int[] {52, 82, 11, 34, 55, 82}, ids.getCharIds()[1][7]);
+    Assert.assertArrayEquals(new int[] {74, 41, 80, 23, 83, 31, 54}, ids.getCharIds()[1][8]);
+    Assert.assertArrayEquals(new int[] {82, 31}, ids.getCharIds()[1][9]);
+    Assert.assertArrayEquals(new int[] {36, 83, 31, 42, 41, 80, 49}, ids.getCharIds()[1][10]);
+    Assert.assertArrayEquals(new int[] {65}, ids.getCharIds()[1][11]);
+
+    Assert.assertEquals(21931, ids.getWordIds()[0][0]);
+    Assert.assertEquals(20473, ids.getWordIds()[0][1]);
+    Assert.assertEquals(21931, ids.getWordIds()[0][2]);
+    Assert.assertEquals(5477, ids.getWordIds()[0][3]);
+    Assert.assertEquals(11538, ids.getWordIds()[0][4]);
+    Assert.assertEquals(21341, ids.getWordIds()[0][5]);
+    Assert.assertEquals(14024, ids.getWordIds()[0][6]);
+    Assert.assertEquals(7420, ids.getWordIds()[0][7]);
+
+    Assert.assertEquals(12492, ids.getWordIds()[1][0]);
+    Assert.assertEquals(2720, ids.getWordIds()[1][1]);
+    Assert.assertEquals(9476, ids.getWordIds()[1][2]);
+    Assert.assertEquals(16537, ids.getWordIds()[1][3]);
+    Assert.assertEquals(18966, ids.getWordIds()[1][4]);
+    Assert.assertEquals(21088, ids.getWordIds()[1][5]);
+    Assert.assertEquals(16601, ids.getWordIds()[1][6]);
+    Assert.assertEquals(2720, ids.getWordIds()[1][7]);
+    Assert.assertEquals(2720, ids.getWordIds()[1][8]);
+    Assert.assertEquals(21341, ids.getWordIds()[1][9]);
+    Assert.assertEquals(7420, ids.getWordIds()[1][10]);
+    Assert.assertEquals(2684, ids.getWordIds()[1][11]);
+
+  }
+
+}