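Add gzipped test resources (chars, tags, words) and remove leftover Python reference comments from SequenceTagging and Viterbi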
diff --git a/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/SequenceTagging.java b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/SequenceTagging.java
index 47de303..f05a09d 100644
--- a/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/SequenceTagging.java
+++ b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/SequenceTagging.java
@@ -54,10 +54,6 @@
float[][] trans_params = new float[indexTagger.getNumberOfTags()][indexTagger.getNumberOfTags()];
run.get(1).copyTo(trans_params);
- //# iterate over the sentences because no batching in vitervi_decode
- //for logit, sequence_length in zip(logits, sequence_lengths):
- //List<List<Integer>> viterbi_sequences = new ArrayList<>();
-
String[][] returnValue = new String[fd.getNumberOfSentences()][];
for (int i=0; i < logits.length; i++) {
//logit = logit[:sequence_length] # keep only the valid steps
diff --git a/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/Viterbi.java b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/Viterbi.java
index 7ea016f..0942d4c 100644
--- a/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/Viterbi.java
+++ b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/Viterbi.java
@@ -112,45 +112,29 @@
}
public static List<Integer> decode(float[][] score, float[][] transition_params) {
- // trellis = np.zeros_like(score)
+
float[][] trellis = zeros_like(score);
- // backpointers = np.zeros_like(score, dtype=np.int32)
int[][] backpointers = zeros_like(shape(score));
- // trellis[0] = score[0]
trellis[0] = score[0];
- // for t in range(1, score.shape[0]):
for (int t=1; t < score.length; t++) {
- //v = np.expand_dims(trellis[t - 1], 1) + transition_params
float[][] v = expand_dims_axis_one_plus_array(trellis[t - 1], transition_params);
-
- //trellis[t] = score[t] + np.max(v, 0)
trellis[t] = plus(score[t], max_columnwise(v));
-
- //backpointers[t] = np.argmax(v, 0)
backpointers[t] = argmax_columnwise(v);
}
- // viterbi = [np.argmax(trellis[-1])]
List<Integer> viterbi = new ArrayList();
viterbi.add(argmax(trellis[trellis.length - 1]));
- // for bp in reversed(backpointers[1:]):
for (int i=backpointers.length - 1; i >= 1; i--) {
- // viterbi.append(bp[viterbi[-1]])
int[] bp = backpointers[i];
viterbi.add(bp[viterbi.get(viterbi.size() - 1)]);
}
- // viterbi.reverse()
Collections.reverse(viterbi);
- // viterbi_score = np.max(trellis[-1])
- // float viterbi_score = max(trellis[trellis.length - 1])) not used!
-
- // return viterbi, viterbi_score
return viterbi;
}
diff --git a/tf-ner-poc/src/test/resources/chars.txt.gz b/tf-ner-poc/src/test/resources/chars.txt.gz
new file mode 100644
index 0000000..c31b81a
--- /dev/null
+++ b/tf-ner-poc/src/test/resources/chars.txt.gz
Binary files differ
diff --git a/tf-ner-poc/src/test/resources/tags.txt.gz b/tf-ner-poc/src/test/resources/tags.txt.gz
new file mode 100644
index 0000000..0f0ceda
--- /dev/null
+++ b/tf-ner-poc/src/test/resources/tags.txt.gz
Binary files differ
diff --git a/tf-ner-poc/src/test/resources/words.txt.gz b/tf-ner-poc/src/test/resources/words.txt.gz
new file mode 100644
index 0000000..5f55ec0
--- /dev/null
+++ b/tf-ner-poc/src/test/resources/words.txt.gz
Binary files differ