blob: 7ec98b3b9a8c67a90bb412fce0d05c8edfb87f7e [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.opennlp.namecat;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.opennlp.ModelUtil;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
public class NameCategorizer {
private final Session session;
private final Map<Character, Integer> charMap = new HashMap<>();
private final Map<Integer, String> labelMap;
public NameCategorizer(InputStream vocabChars, InputStream labelDict,
InputStream modelZipPackage) throws IOException {
try (BufferedReader in = new BufferedReader(new InputStreamReader(vocabChars,
StandardCharsets.UTF_8))) {
in.lines().forEach(ch -> charMap.put(ch.charAt(0), charMap.size()));
}
labelMap = new HashMap<>();
try (BufferedReader in = new BufferedReader(new InputStreamReader(labelDict,
StandardCharsets.UTF_8))) {
in.lines().forEach(label -> labelMap.put(labelMap.size(), label));
}
Path tmpModelPath = ModelUtil.writeModelToTmpDir(modelZipPackage);
SavedModelBundle model = SavedModelBundle.load(tmpModelPath.toString(), "serve");
session = model.session();
}
private static int argmax(float[] x) {
if (x == null || x.length == 0) {
throw new IllegalArgumentException("Vector x is null or empty");
}
int maxIdx = 0;
for (int i = 1; i < x.length; i++) {
if (x[maxIdx] < x[i])
maxIdx = i;
}
return maxIdx;
}
public String[] categorize(String[] names) {
if (names.length == 0) {
return new String[0];
}
int maxLength = Arrays.stream(names).mapToInt(String::length).max().getAsInt();
int charIds[][] = new int[names.length][maxLength];
int nameLengths[] = new int[names.length];
for (int nameIndex = 0; nameIndex < names.length; nameIndex++) {
for (int charIndex = 0; charIndex < names[nameIndex].length(); charIndex++) {
charIds[nameIndex][charIndex] = charMap.get(names[nameIndex].charAt(charIndex));
}
nameLengths[nameIndex] = names[nameIndex].length();
}
try (Tensor<?> dropout = Tensor.create(1f, Float.class);
Tensor<?> charTensor = Tensor.create(charIds);
Tensor<?> nameLength = Tensor.create(nameLengths)) {
List<Tensor<?>> result = session.runner()
.feed("dropout_keep_prop", dropout)
.feed("char_ids", charTensor)
.feed("name_lengths", nameLength)
.fetch("norm_probs", 0).run();
try (Tensor<?> probTensor = result.get(0)) {
float[][] probs = probTensor.copyTo(new float[names.length][labelMap.size()]);
return Arrays.stream(probs).map(prob -> labelMap.get(argmax(prob))).toArray(String[]::new);
}
}
}
}