blob: f05a09d2f23f7f268fab0c8c63dd7577590aac45 [file] [log] [blame]
package org.apache.opennlp.tf.guillaumegenthial;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.zip.GZIPInputStream;
public class SequenceTagging implements AutoCloseable {
private final SavedModelBundle model;
private final Session session;
private final WordIndexer wordIndexer;
private final IndexTagger indexTagger;
public SequenceTagging(PredictionConfiguration config) throws IOException {
this.model = SavedModelBundle.load(config.getSavedModel(), "serve");
this.session = model.session();
this.wordIndexer = new WordIndexer(new GZIPInputStream(new FileInputStream(config.getVocabWords())),
new GZIPInputStream(new FileInputStream(config.getVocabChars())));
this.indexTagger = new IndexTagger(new GZIPInputStream(new FileInputStream(config.getVocabTags())));
}
public String[] predict(String[] sentence) {
TokenIds tokenIds = wordIndexer.toTokenIds(sentence);
return predict(tokenIds)[0];
}
public String[][] predict(String[][] sentences) {
TokenIds tokenIds = wordIndexer.toTokenIds(sentences);
return predict(tokenIds);
}
private String[][] predict(TokenIds tokenIds) {
FeedDictionary fd = FeedDictionary.create(tokenIds);
List<Tensor<?>> run = session.runner()
.feed("char_ids:0", fd.getCharIdsTensor())
.feed("dropout:0", fd.getDropoutTensor())
.feed("sequence_lengths:0", fd.getSentenceLengthsTensor())
.feed("word_ids:0", fd.getWordIdsTensor())
.feed("word_lengths:0", fd.getWordLengthsTensor())
.fetch("proj/logits", 0)
.fetch("trans_params", 0).run();
float[][][] logits = new float[fd.getNumberOfSentences()][fd.getMaxSentenceLength()][indexTagger.getNumberOfTags()];
run.get(0).copyTo(logits);
float[][] trans_params = new float[indexTagger.getNumberOfTags()][indexTagger.getNumberOfTags()];
run.get(1).copyTo(trans_params);
String[][] returnValue = new String[fd.getNumberOfSentences()][];
for (int i=0; i < logits.length; i++) {
//logit = logit[:sequence_length] # keep only the valid steps
float[][] logit = Arrays.copyOf(logits[i], fd.getSentenceLengths()[i]);
returnValue[i] = Viterbi.decode(logit, trans_params).stream().map(indexTagger::getTag).toArray(String[]::new);
}
return returnValue;
}
@Override
public void close() throws Exception {
session.close();
}
}