blob: a10b0e22d811bc8b1d1e3aa80bcc42efb4abaf4a [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package opennlp.tools.langdetect;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import opennlp.tools.ml.AbstractEventTrainer;
import opennlp.tools.ml.EventTrainer;
import opennlp.tools.ml.TrainerFactory;
import opennlp.tools.ml.model.MaxentModel;
import opennlp.tools.util.MutableInt;
import opennlp.tools.util.ObjectStream;
import opennlp.tools.util.TrainingParameters;
/**
* Implements a learnable language detector.
*
* <p>
* This will process the entire string when called with
* {@link #predictLanguage(CharSequence)} or
* {@link #predictLanguages(CharSequence)}.
* </p>
* <p>
* If you want this to stop early, use {@link #probingPredictLanguages(CharSequence)}
* or {@link #probingPredictLanguages(CharSequence, LanguageDetectorConfig)}.
* When run in probing mode, this starts at the beginning of the charsequence
* and runs language detection on chunks of text. If the end of the
* string is reached or there are {@link LanguageDetectorConfig#getMinConsecImprovements()}
* consecutive predictions for the best language and the confidence
* increases over those last predictions and if the difference
* in confidence between the highest confidence language
* and the second highest confidence language is greater than
* {@link LanguageDetectorConfig#getMinDiff()}, the language detector will
* stop and report the results.
* </p>
* <p>
* The authors wish to thank Ken Krugler and
* <a href="https://github.com/kkrugler/yalder">Yalder</a>
* for the inspiration for many of the design
* components of this detector.
* </p>
*
*/
public class LanguageDetectorME implements LanguageDetector {

  protected LanguageDetectorModel model;
  private final LanguageDetectorContextGenerator contextGenerator;

  /**
   * Initializes the current instance with a language detector model. Default feature
   * generation is used.
   *
   * @param model the language detector model
   */
  public LanguageDetectorME(LanguageDetectorModel model) {
    this.model = model;
    this.contextGenerator = model.getFactory().getContextGenerator();
  }

  /**
   * Predicts the languages of {@code content}. This will process the full
   * content length.
   *
   * @param content the text whose language is to be detected
   * @return the predicted languages, sorted by descending confidence
   */
  @Override
  public Language[] predictLanguages(CharSequence content) {
    return predict(arrayToCounts(
        contextGenerator.getContext(content)));
  }

  /**
   * Predicts the single most likely language of {@code content}. This will
   * process the full content length.
   *
   * @param content the text whose language is to be detected
   * @return the language with the highest confidence
   */
  @Override
  public Language predictLanguage(CharSequence content) {
    return predictLanguages(content)[0];
  }

  /**
   * @return the languages (model outcomes) this detector can predict,
   *     in model outcome order
   */
  @Override
  public String[] getSupportedLanguages() {
    // hoist the model lookup out of the loop
    MaxentModel maxentModel = model.getMaxentModel();
    int numberLanguages = maxentModel.getNumOutcomes();
    String[] languages = new String[numberLanguages];
    for (int i = 0; i < numberLanguages; i++) {
      languages[i] = maxentModel.getOutcome(i);
    }
    return languages;
  }

  /**
   * Runs probing language detection; this will stop processing early if the
   * stopping criteria specified in
   * {@link LanguageDetectorConfig#DEFAULT_LANGUAGE_DETECTOR_CONFIG} are met.
   *
   * @param content content to be processed
   * @return the detection result, including how many codepoints were processed
   */
  public ProbingLanguageDetectionResult probingPredictLanguages(CharSequence content) {
    return probingPredictLanguages(content,
        LanguageDetectorConfig.DEFAULT_LANGUAGE_DETECTOR_CONFIG);
  }

  /**
   * Runs probing language detection; this will stop processing early if the
   * stopping criteria specified in the given {@code config} are met.
   *
   * @param content content to process
   * @param config config to customize detection
   * @return the detection result, including how many codepoints were processed
   */
  public ProbingLanguageDetectionResult probingPredictLanguages(CharSequence content,
                                                                LanguageDetectorConfig config) {
    //list of the languages that received the highest
    //confidence over the last n chunk detections
    List<Language[]> predictions = new LinkedList<>();
    int start = 0;//where to start the next chunk in codepoints
    Language[] currPredictions = null;
    //cache ngram counts across chunks
    Map<String, MutableInt> ngramCounts = new HashMap<>();
    while (true) {
      //cap the chunk so that no more than maxLength codepoints are read overall
      int actualChunkSize =
          (start + config.getChunkSize() > config.getMaxLength()) ?
              config.getMaxLength() - start : config.getChunkSize();
      StringCPLengthPair chunk = chunk(content, start, actualChunkSize);
      if (chunk.length() == 0) {
        //end of content (or maxLength) reached
        if (currPredictions == null) {
          //no chunk was ever processed (e.g. empty content);
          //predict on whatever counts exist (possibly none)
          return new ProbingLanguageDetectionResult(predict(ngramCounts), start);
        } else {
          return new ProbingLanguageDetectionResult(currPredictions, start);
        }
      }
      start += chunk.length();
      updateCounts(contextGenerator.getContext(chunk.getString()), ngramCounts);
      currPredictions = predict(ngramCounts);
      if (seenEnough(predictions, currPredictions, ngramCounts, config)) {
        return new ProbingLanguageDetectionResult(currPredictions, start);
      }
    }
  }

  /**
   * Adds the ngrams in {@code context} to the running {@code ngrams} counts.
   */
  private void updateCounts(String[] context, Map<String, MutableInt> ngrams) {
    for (String ngram : context) {
      MutableInt i = ngrams.get(ngram);
      if (i == null) {
        ngrams.put(ngram, new MutableInt(1));
      } else {
        i.increment();
      }
    }
  }

  /**
   * Builds a fresh ngram count map from a context array.
   */
  private Map<String, MutableInt> arrayToCounts(String[] context) {
    Map<String, MutableInt> ngrams = new HashMap<>();
    updateCounts(context, ngrams);
    return ngrams;
  }

  /**
   * Evaluates the maxent model against the observed ngrams and returns all
   * languages sorted by descending confidence.
   */
  private Language[] predict(Map<String, MutableInt> ngramCounts) {
    // hoist the model lookup out of the loops
    MaxentModel maxentModel = model.getMaxentModel();
    String[] allGrams = new String[ngramCounts.size()];
    float[] counts = new float[ngramCounts.size()];
    int i = 0;
    for (Map.Entry<String, MutableInt> e : ngramCounts.entrySet()) {
      allGrams[i] = e.getKey();
      // TODO -- once OPENNLP-1261 is fixed,
      // change this to e.getValue().getValue().
      counts[i] = 1;
      i++;
    }
    double[] eval = maxentModel.eval(allGrams, counts);
    Language[] arr = new Language[eval.length];
    for (int j = 0; j < eval.length; j++) {
      arr[j] = new Language(maxentModel.getOutcome(j), eval[j]);
    }
    Arrays.sort(arr, (o1, o2) -> Double.compare(o2.getConfidence(), o1.getConfidence()));
    return arr;
  }

  /**
   * Override this for different behavior to determine if there is enough
   * confidence in the predictions to stop.
   *
   * @param predictionsQueue queue of earlier predictions; mutated to keep a
   *                         sliding window of the most recent predictions
   * @param newPredictions most recent predictions
   * @param ngramCounts not currently used, but might be useful for overrides
   * @param config configuration with the stopping criteria
   * @return whether or not enough text has been processed to make a determination
   */
  boolean seenEnough(List<Language[]> predictionsQueue, Language[] newPredictions,
                     Map<String, MutableInt> ngramCounts, LanguageDetectorConfig config) {
    if (predictionsQueue.size() < config.getMinConsecImprovements()) {
      //not enough history yet to make a determination
      predictionsQueue.add(newPredictions);
      return false;
    } else if (predictionsQueue.size() > config.getMinConsecImprovements()) {
      //keep the sliding window at minConsecImprovements entries;
      //size() > minConsecImprovements >= 0 already implies a non-empty queue
      predictionsQueue.remove(0);
    }
    predictionsQueue.add(newPredictions);
    //NOTE(review): assumes the model has at least two outcomes when
    //minDiff > 0 -- newPredictions[1] would otherwise throw
    if (config.getMinDiff() > 0.0 &&
        newPredictions[0].getConfidence() -
            newPredictions[1].getConfidence() < config.getMinDiff()) {
      //the margin between the two best languages is still too small
      return false;
    }
    //iterate through the last predictions
    //and check that the lang with the highest confidence
    //hasn't changed, and that the confidence in it
    //hasn't decreased
    String lastLang = null;
    double lastConf = -1.0;
    for (Language[] predictions : predictionsQueue) {
      String lang = predictions[0].getLang();
      double conf = predictions[0].getConfidence();
      if (lastLang != null) {
        if (!lastLang.equals(lang)) {
          return false;
        }
        if (lastConf > conf) {
          return false;
        }
      }
      lastLang = lang;
      lastConf = conf;
    }
    return true;
  }

  /**
   * Extracts up to {@code chunkSize} codepoints beginning at codepoint
   * offset {@code start}.
   */
  private StringCPLengthPair chunk(CharSequence content, int start, int chunkSize) {
    if (start == 0 && chunkSize > content.length()) {
      //fast path: the char length bounds the codepoint length from above,
      //so the whole content fits into a single chunk
      String s = content.toString();
      int codePointLength = s.codePointCount(0, s.length());
      return
          new StringCPLengthPair(s, codePointLength);
    }
    int[] codepoints = content.codePoints().skip(start).limit(chunkSize).toArray();
    return
        new StringCPLengthPair(
            new String(codepoints, 0, codepoints.length),
            codepoints.length);
  }

  /**
   * Trains a {@link LanguageDetectorModel} from the given samples.
   *
   * @param samples the training samples
   * @param mlParams training parameters; a one-pass data indexer is applied
   *                 unless one is already configured
   * @param factory the factory supplying the context generator
   * @return the trained model
   * @throws IOException if reading the samples fails
   */
  public static LanguageDetectorModel train(ObjectStream<LanguageSample> samples,
                                            TrainingParameters mlParams,
                                            LanguageDetectorFactory factory)
      throws IOException {
    Map<String, String> manifestInfoEntries = new HashMap<>();
    mlParams.putIfAbsent(AbstractEventTrainer.DATA_INDEXER_PARAM,
        AbstractEventTrainer.DATA_INDEXER_ONE_PASS_VALUE);
    EventTrainer trainer = TrainerFactory.getEventTrainer(
        mlParams, manifestInfoEntries);
    MaxentModel model = trainer.train(
        new LanguageDetectorEventStream(samples, factory.getContextGenerator()));
    return new LanguageDetectorModel(model, manifestInfoEntries, factory);
  }

  /**
   * Pairs a String with its precomputed codepoint length.
   */
  private static class StringCPLengthPair {
    private final String s;
    private final int length;

    StringCPLengthPair(String s, int length) {
      this.s = s;
      this.length = length;
    }

    int length() {
      return length;
    }

    String getString() {
      return s;
    }
  }
}