/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package opennlp.dl.doccat;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.nio.LongBuffer;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.stream.IntStream;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import opennlp.dl.InferenceOptions;
import opennlp.dl.Tokens;
import opennlp.dl.doccat.scoring.ClassificationScoringStrategy;
import opennlp.tools.doccat.DocumentCategorizer;
import opennlp.tools.tokenize.Tokenizer;
import opennlp.tools.tokenize.WordpieceTokenizer;
/**
* An implementation of {@link DocumentCategorizer} that performs document classification
* using ONNX models.
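* <p>
* A minimal usage sketch (illustrative only; the file names and category labels are
* placeholders, and the availability of {@code AverageClassificationScoringStrategy} and a
* default {@code InferenceOptions} constructor is assumed):
* <pre>{@code
* Map<Integer, String> categories = new HashMap<>();
* categories.put(0, "negative");
* categories.put(1, "positive");
*
* DocumentCategorizerDL categorizer = new DocumentCategorizerDL(
*     new File("model.onnx"), new File("vocab.txt"), categories,
*     new AverageClassificationScoringStrategy(), new InferenceOptions());
*
* double[] scores = categorizer.categorize(new String[] {"The text to categorize."});
* String bestCategory = categorizer.getBestCategory(scores);
* }</pre>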
*/
public class DocumentCategorizerDL implements DocumentCategorizer {
public static final String INPUT_IDS = "input_ids";
public static final String ATTENTION_MASK = "attention_mask";
public static final String TOKEN_TYPE_IDS = "token_type_ids";
private final Tokenizer tokenizer;
private final Map<String, Integer> vocabulary;
private final Map<Integer, String> categories;
private final ClassificationScoringStrategy classificationScoringStrategy;
private final InferenceOptions inferenceOptions;
protected final OrtEnvironment env;
protected final OrtSession session;
/**
* Creates a new document categorizer using ONNX models.
* @param model The ONNX model file.
* @param vocab The model's vocabulary file.
* @param categories The categories.
* @param classificationScoringStrategy Implementation of {@link ClassificationScoringStrategy} used
* to calculate the classification scores given the score of each
* individual document part.
* @param inferenceOptions {@link InferenceOptions} to control the inference.
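* <p>
* For example, to run inference on a GPU (a sketch assuming the usual
* {@code InferenceOptions} setters exist):
* <pre>{@code
* InferenceOptions options = new InferenceOptions();
* options.setGpu(true);
* options.setGpuDeviceId(0);
* }</pre>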
*/
public DocumentCategorizerDL(File model, File vocab, Map<Integer, String> categories,
ClassificationScoringStrategy classificationScoringStrategy,
InferenceOptions inferenceOptions)
throws IOException, OrtException {
this.env = OrtEnvironment.getEnvironment();
final OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
if (inferenceOptions.isGpu()) {
sessionOptions.addCUDA(inferenceOptions.getGpuDeviceId());
}
this.session = env.createSession(model.getPath(), sessionOptions);
this.vocabulary = loadVocab(vocab);
this.tokenizer = new WordpieceTokenizer(vocabulary.keySet());
this.categories = categories;
this.classificationScoringStrategy = classificationScoringStrategy;
this.inferenceOptions = inferenceOptions;
}
@Override
public double[] categorize(String[] strings) {
try {
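// Only the first element of the input array is used; it should contain the text to categorize.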
final List<Tokens> tokens = tokenize(strings[0]);
final List<double[]> scores = new LinkedList<>();
for (final Tokens t : tokens) {
final Map<String, OnnxTensor> inputs = new HashMap<>();
inputs.put(INPUT_IDS, OnnxTensor.createTensor(env,
LongBuffer.wrap(t.getIds()), new long[] {1, t.getIds().length}));
if (inferenceOptions.isIncludeAttentionMask()) {
inputs.put(ATTENTION_MASK, OnnxTensor.createTensor(env,
LongBuffer.wrap(t.getMask()), new long[] {1, t.getMask().length}));
}
if (inferenceOptions.isIncludeTokenTypeIds()) {
inputs.put(TOKEN_TYPE_IDS, OnnxTensor.createTensor(env,
LongBuffer.wrap(t.getTypes()), new long[] {1, t.getTypes().length}));
}
// Run inference. The model is expected to output a [1, numberOfCategories] array of logits.
final float[][] v = (float[][]) session.run(inputs).get(0).getValue();
// Convert the logits to probabilities and keep the scores for this chunk.
final double[] categoryScoresForTokens = softmax(v[0]);
scores.add(categoryScoresForTokens);
}
return classificationScoringStrategy.score(scores);
} catch (Exception ex) {
System.err.println("Unable to perform document classification inference: " + ex.getMessage());
}
return new double[]{};
}
@Override
public double[] categorize(String[] strings, Map<String, Object> map) {
return categorize(strings);
}
@Override
public String getBestCategory(double[] doubles) {
return categories.get(maxIndex(doubles));
}
@Override
public int getIndex(String s) {
return getKey(s);
}
@Override
public String getCategory(int i) {
return categories.get(i);
}
@Override
public int getNumberOfCategories() {
return categories.size();
}
@Override
public String getAllResults(double[] doubles) {
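// Not supported; this implementation does not produce a formatted results string.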
return null;
}
@Override
public Map<String, Double> scoreMap(String[] strings) {
final double[] scores = categorize(strings);
final Map<String, Double> scoreMap = new HashMap<>();
for (int x : categories.keySet()) {
scoreMap.put(categories.get(x), scores[x]);
}
return scoreMap;
}
@Override
public SortedMap<Double, Set<String>> sortedScoreMap(String[] strings) {
final double[] scores = categorize(strings);
final SortedMap<Double, Set<String>> scoreMap = new TreeMap<>();
for (int x : categories.keySet()) {
if (scoreMap.get(scores[x]) == null) {
scoreMap.put(scores[x], new HashSet<>());
}
scoreMap.get(scores[x]).add(categories.get(x));
}
return scoreMap;
}
private int getKey(String value) {
for (Map.Entry<Integer, String> entry : categories.entrySet()) {
if (entry.getValue().equals(value)) {
return entry.getKey();
}
}
// The String wasn't found as a value in the map.
return -1;
}
/**
* Loads a vocabulary file from disk.
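* <p>The file is expected to contain one vocabulary token per line; a token's zero-based
* line number is used as its ID.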
* @param vocab The vocabulary file.
* @return A map of vocabulary words to integer IDs.
* @throws IOException Thrown if the vocabulary file cannot be opened and read.
*/
private Map<String, Integer> loadVocab(File vocab) throws IOException {
final Map<String, Integer> v = new HashMap<>();
// Each line of the file is a token; its zero-based line number becomes its ID.
try (BufferedReader br = new BufferedReader(new FileReader(vocab.getPath()))) {
String line = br.readLine();
int x = 0;
while (line != null) {
v.put(line, x);
line = br.readLine();
x++;
}
}
return v;
}
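// Note: this method is not called by this class; tokenize(String) below performs the
// chunked tokenization used for inference.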
private Tokens oldTokenize(String text) {
final String[] tokens = tokenizer.tokenize(text);
final int[] ids = new int[tokens.length];
for (int x = 0; x < tokens.length; x++) {
ids[x] = vocabulary.get(tokens[x]);
}
final long[] lids = Arrays.stream(ids).mapToLong(i -> i).toArray();
final long[] mask = new long[ids.length];
Arrays.fill(mask, 1);
final long[] types = new long[ids.length];
Arrays.fill(types, 0);
return new Tokens(tokens, lids, mask, types);
}
private List<Tokens> tokenize(final String text) {
final List<Tokens> t = new LinkedList<>();
// Following the approach described in the article below, segment the input into smaller
// chunks of text and feed each chunk into BERT separately.
// https://medium.com/analytics-vidhya/text-classification-with-bert-using-transformers-for-long-text-inputs-f54833994dfd
// Split the input text into chunks of getDocumentSplitSize() words, with
// getSplitOverlapSize() words of overlap between consecutive chunks (200-word chunks with
// a 50-word overlap in the referenced approach).
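// As a worked example (using 200-word chunks with a 50-word overlap), a 400-word document
// produces three chunks covering words [0, 200), [150, 350), and [300, 400).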
final String[] whitespaceTokenized = text.split("\\s+");
for (int start = 0; start < whitespaceTokenized.length;
start = start + inferenceOptions.getDocumentSplitSize()) {
// Clamp the end index so the chunk does not run past the end of the array.
int end = start + inferenceOptions.getDocumentSplitSize();
if (end > whitespaceTokenized.length) {
end = whitespaceTokenized.length;
}
// The group is the chunk of text covering words [start, end).
final String group = String.join(" ", Arrays.copyOfRange(whitespaceTokenized, start, end));
// Overlap consecutive chunks by getSplitOverlapSize() words, so move start back before the next iteration.
start = start - inferenceOptions.getSplitOverlapSize();
// Now we can tokenize the group and continue.
final String[] tokens = tokenizer.tokenize(group);
final int[] ids = new int[tokens.length];
for (int x = 0; x < tokens.length; x++) {
ids[x] = vocabulary.get(tokens[x]);
}
final long[] lids = Arrays.stream(ids).mapToLong(i -> i).toArray();
final long[] mask = new long[ids.length];
Arrays.fill(mask, 1);
final long[] types = new long[ids.length];
Arrays.fill(types, 0);
t.add(new Tokens(tokens, lids, mask, types));
}
return t;
}
/**
* Applies softmax to an array of values.
* @param input An array of values.
* @return The output array.
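* <p>For example, softmax of {@code {1.0f, 2.0f, 3.0f}} is approximately
* {@code {0.090, 0.245, 0.665}}.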
*/
private double[] softmax(final float[] input) {
final double[] t = new double[input.length];
double sum = 0.0;
for (int x = 0; x < input.length; x++) {
double val = Math.exp(input[x]);
sum += val;
t[x] = val;
}
final double[] output = new double[input.length];
for (int x = 0; x < output.length; x++) {
output[x] = t[x] / sum;
}
return output;
}
private int maxIndex(double[] arr) {
return IntStream.range(0, arr.length)
.reduce((i, j) -> arr[i] > arr[j] ? i : j)
.orElse(-1);
}
}