/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package opennlp.dl.namefinder;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.nio.LongBuffer;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import opennlp.dl.InferenceOptions;
import opennlp.dl.SpanEnd;
import opennlp.dl.Tokens;
import opennlp.tools.namefind.TokenNameFinder;
import opennlp.tools.tokenize.Tokenizer;
import opennlp.tools.tokenize.WordpieceTokenizer;
import opennlp.tools.util.Span;
/**
* An implementation of {@link TokenNameFinder} that uses ONNX models.
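* <p>
* A minimal usage sketch (the model path, vocabulary file, and label map below are
* illustrative assumptions; the label ids must mirror the model's own configuration):
* <pre>{@code
* Map<Integer, String> ids2Labels = new HashMap<>();
* ids2Labels.put(0, "O");
* ids2Labels.put(1, "B-PER");
* ids2Labels.put(2, "I-PER");
* NameFinderDL nameFinder =
*     new NameFinderDL(new File("namefinder.onnx"), new File("vocab.txt"), ids2Labels);
* Span[] spans = nameFinder.find(new String[] {"George", "Washington", "was", "president"});
* }</pre>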
*/
public class NameFinderDL implements TokenNameFinder {
public static final String INPUT_IDS = "input_ids";
public static final String ATTENTION_MASK = "attention_mask";
public static final String TOKEN_TYPE_IDS = "token_type_ids";
public static final String I_PER = "I-PER";
public static final String B_PER = "B-PER";
protected final OrtSession session;
private final Map<Integer, String> ids2Labels;
private final Tokenizer tokenizer;
private final Map<String, Integer> vocab;
private final InferenceOptions inferenceOptions;
protected final OrtEnvironment env;
public NameFinderDL(File model, File vocabulary, Map<Integer, String> ids2Labels) throws Exception {
this(model, vocabulary, ids2Labels, new InferenceOptions());
}
public NameFinderDL(File model, File vocabulary, Map<Integer, String> ids2Labels,
InferenceOptions inferenceOptions) throws Exception {
this.env = OrtEnvironment.getEnvironment();
final OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
if (inferenceOptions.isGpu()) {
sessionOptions.addCUDA(inferenceOptions.getGpuDeviceId());
}
this.session = env.createSession(model.getPath(), sessionOptions);
this.ids2Labels = ids2Labels;
this.vocab = loadVocab(vocabulary);
this.tokenizer = new WordpieceTokenizer(vocab.keySet());
this.inferenceOptions = inferenceOptions;
}
@Override
public Span[] find(String[] input) {
// Note: inference is performed on the wordpiece tokens, but the returned spans are
// character offsets into the whitespace-joined input text.
final List<Span> spans = new LinkedList<>();
// Join the tokens here because they will be tokenized using Wordpiece during inference.
final String text = String.join(" ", input);
// The WordPiece tokenized text. This changes the spacing in the text.
final List<Tokens> wordpieceTokens = tokenize(text);
for (final Tokens tokens : wordpieceTokens) {
try {
// The inputs to the ONNX model.
final Map<String, OnnxTensor> inputs = new HashMap<>();
inputs.put(INPUT_IDS, OnnxTensor.createTensor(env, LongBuffer.wrap(tokens.getIds()),
new long[] {1, tokens.getIds().length}));
if (inferenceOptions.isIncludeAttentionMask()) {
inputs.put(ATTENTION_MASK, OnnxTensor.createTensor(env,
LongBuffer.wrap(tokens.getMask()), new long[] {1, tokens.getMask().length}));
}
if (inferenceOptions.isIncludeTokenTypeIds()) {
inputs.put(TOKEN_TYPE_IDS, OnnxTensor.createTensor(env,
LongBuffer.wrap(tokens.getTypes()), new long[] {1, tokens.getTypes().length}));
}
// The outputs from the model.
final float[][][] v = (float[][][]) session.run(inputs).get(0).getValue();
// Find consecutive B-PER and I-PER labels and combine them into single spans where necessary.
// (There are also B-LOC and I-LOC labels for locations that might be useful at some point.)
// characterStart tracks where the previous span ended so that, when the same span text
// appears more than once, we match its next occurrence instead of the first one each time.
int characterStart = 0;
// Loop over the score vector for each token, find the index with the maximum value,
// and map that index to its token classification (label).
for (int x = 0; x < v[0].length; x++) {
final float[] arr = v[0][x];
final int maxIndex = maxIndex(arr);
final String label = ids2Labels.get(maxIndex);
// TODO: Need to make sure this value is between 0 and 1?
// Can we do thresholding without it between 0 and 1?
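// A softmax over arr would map these raw scores to probabilities in [0, 1] if thresholding
// on a normalized score is needed.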
final double confidence = arr[maxIndex]; // / 10;
// Show each token and its label per the model.
// System.out.println(tokens.getTokens()[x] + " : " + label);
// Check whether this is the start of a person entity.
if (B_PER.equals(label)) {
final String spanText;
// Find the end index of the span in the array (where the label is not I-PER).
final SpanEnd spanEnd = findSpanEnd(v, x, ids2Labels, tokens.getTokens());
// If the end index is -1 this is a single-token span.
// If the end index is not -1 this is a multi-token span.
if (spanEnd.getIndex() != -1) {
final StringBuilder sb = new StringBuilder();
// We have to concatenate the tokens.
// Add each token in the array and separate them with a space.
// We'll separate each with a single space because later we'll find the original span
// in the text and ignore spacing between individual tokens in findByRegex().
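// For example, the consecutive wordpiece tokens [george, wash, ##ington] would be
// rebuilt here as "george washington".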
int end = spanEnd.getIndex();
for (int i = x; i <= end; i++) {
// If the next token starts with ##, combine it with this token.
if (tokens.getTokens()[i + 1].startsWith("##")) {
sb.append(tokens.getTokens()[i]);
sb.append(tokens.getTokens()[i + 1].replace("##", ""));
// Append a space unless there is no following token or it also starts with ##.
if (i + 2 >= tokens.getTokens().length || !tokens.getTokens()[i + 2].startsWith("##")) {
sb.append(" ");
}
// Skip the next token since we just included it in this iteration.
i++;
} else {
sb.append(tokens.getTokens()[i].replace("##", ""));
// Append a space unless the next token is a period.
if (!".".equals(tokens.getTokens()[i + 1])) {
sb.append(" ");
}
}
}
// This is the text of the span. We use the whole original input text and not one
// of the splits. This gives us accurate character positions.
spanText = findByRegex(text, sb.toString().trim()).trim();
} else {
// This is a single-token span so there is nothing else to do except grab the token.
spanText = tokens.getTokens()[x];
}
// Find the span's character offsets by taking the first occurrence of its text at or
// after the end of the previous span; later duplicates are resolved by subsequent spans.
characterStart = text.indexOf(spanText, characterStart);
final int characterEnd = characterStart + spanText.length();
spans.add(new Span(characterStart, characterEnd, spanText, confidence));
characterStart = characterEnd;
}
}
} catch (OrtException ex) {
throw new RuntimeException("Error performing namefinder inference: " + ex.getMessage(), ex);
}
}
return spans.toArray(new Span[0]);
}
@Override
public void clearAdaptiveData() {
// Not needed for this implementation.
}
private SpanEnd findSpanEnd(float[][][] v, int startIndex, Map<Integer, String> id2Labels,
String[] tokens) {
// -1 means there is no follow-up token, so it is a single-token span.
int index = -1;
int characterEnd = 0;
// Starting just after the span start, walk forward while the tokens are labeled I-PER.
// When a token is no longer I-PER, the span ends at the previous index.
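// For example, if the labels for the tokens after startIndex are [I-PER, I-PER, O],
// the loop stops at the O token and the span ends at the second I-PER token.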
for (int x = startIndex + 1; x < v[0].length; x++) {
// Get the next item.
final float[] arr = v[0][x];
// See if the next token has an I-PER label.
final String nextTokenClassification = id2Labels.get(maxIndex(arr));
if (!I_PER.equals(nextTokenClassification)) {
index = x - 1;
break;
}
}
// Find where the span ends based on the tokens.
for (int x = 1; x <= index && x < tokens.length; x++) {
characterEnd += tokens[x].length();
}
// Account for the single space separating each pair of tokens.
characterEnd += index - 1;
return new SpanEnd(index, characterEnd);
}
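// Returns the index of the largest value in the array, i.e. the argmax over the label scores.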
private int maxIndex(float[] arr) {
float max = Float.NEGATIVE_INFINITY;
int index = -1;
for (int x = 0; x < arr.length; x++) {
if (arr[x] > max) {
index = x;
max = arr[x];
}
}
return index;
}
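// Builds a regex from the reconstructed span text so the span can be located in the original
// text regardless of how the wordpiece tokens were re-joined: spaces become \s+ and
// parentheses are escaped because they are regex metacharacters.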
private static String findByRegex(String text, String span) {
final String regex = span
.replaceAll(" ", "\\\\s+")
.replaceAll("\\)", "\\\\)")
.replaceAll("\\(", "\\\\(");
final Pattern pattern = Pattern.compile(regex, Pattern.CASE_INSENSITIVE);
final Matcher matcher = pattern.matcher(text);
if (matcher.find()) {
return matcher.group(0);
}
// For some reason the regex match wasn't found. Just return the original span.
return span;
}
private List<Tokens> tokenize(final String text) {
final List<Tokens> t = new LinkedList<>();
// Following the approach described in the article below, segment the input into smaller
// chunks and feed each chunk to BERT separately, since BERT cannot accept arbitrarily
// long inputs:
// https://medium.com/analytics-vidhya/text-classification-with-bert-using-transformers-for-long-text-inputs-f54833994dfd
// Split the input text into chunks of documentSplitSize words with splitOverlapSize words
// overlapping between consecutive chunks (for example, 200-word chunks with a 50-word overlap).
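// For example, with a documentSplitSize of 200 and a splitOverlapSize of 50 (illustrative
// values), the chunks cover whitespace tokens 0-199, 150-349, 300-499, and so on.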
final String[] whitespaceTokenized = text.split("\\s+");
for (int start = 0; start < whitespaceTokenized.length;
start = start + inferenceOptions.getDocumentSplitSize()) {
// Compute the end index of this chunk, clamping it to the end of the token array.
int end = start + inferenceOptions.getDocumentSplitSize();
if (end > whitespaceTokenized.length) {
end = whitespaceTokenized.length;
}
// The group is the text of this chunk.
final String group = String.join(" ", Arrays.copyOfRange(whitespaceTokenized, start, end));
// Overlap the chunks: move the start back by the overlap size so the next chunk
// repeats the last splitOverlapSize words of this one.
start = start - inferenceOptions.getSplitOverlapSize();
// Now we can tokenize the group and continue.
final String[] tokens = tokenizer.tokenize(group);
final int[] ids = new int[tokens.length];
for (int x = 0; x < tokens.length; x++) {
ids[x] = vocab.get(tokens[x]);
}
final long[] lids = Arrays.stream(ids).mapToLong(i -> i).toArray();
final long[] mask = new long[ids.length];
Arrays.fill(mask, 1);
final long[] types = new long[ids.length];
Arrays.fill(types, 0);
t.add(new Tokens(tokens, lids, mask, types));
}
return t;
}
/**
* Loads a vocabulary file from disk.
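* The file is expected to contain one wordpiece token per line; a token's id is its
* zero-based line number.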
* @param vocab The vocabulary file.
* @return A map of vocabulary words to integer IDs.
* @throws IOException Thrown if the vocabulary file cannot be opened and read.
*/
private Map<String, Integer> loadVocab(File vocab) throws IOException {
final Map<String, Integer> v = new HashMap<>();
try (final BufferedReader br = new BufferedReader(new FileReader(vocab.getPath()))) {
String line;
int x = 0;
while ((line = br.readLine()) != null) {
v.put(line, x++);
}
}
return v;
}
}