blob: 6f855fc617012c42687aa63f1c5628f7e2a878e8 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package opennlp.tools.languagemodel;
import java.math.BigDecimal;
import java.math.MathContext;
import java.util.Collection;
import java.util.LinkedList;
import java.util.Random;
import opennlp.tools.ngram.NGramUtils;
import opennlp.tools.util.StringList;
import org.junit.Ignore;
/**
* Utility class for language models tests
*/
@Ignore
public class LanguageModelTestUtils {
private static final java.math.MathContext CONTEXT = MathContext.DECIMAL128;
private static Random r = new Random();
private static final char[] chars = new char[]{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'};
public static Collection<StringList> generateRandomVocabulary(int size) {
Collection<StringList> vocabulary = new LinkedList<>();
for (int i = 0; i < size; i++) {
StringList sentence = generateRandomSentence();
vocabulary.add(sentence);
}
return vocabulary;
}
public static StringList generateRandomSentence() {
int dimension = r.nextInt(10) + 1;
String[] sentence = new String[dimension];
for (int j = 0; j < dimension; j++) {
int i = r.nextInt(10);
char c = chars[i];
sentence[j] = c + "-" + c + "-" + c;
}
return new StringList(sentence);
}
public static double getPerplexity(LanguageModel lm, Collection<StringList> testSet, int ngramSize) throws ArithmeticException {
BigDecimal perplexity = new BigDecimal(1d);
for (StringList sentence : testSet) {
for (StringList ngram : NGramUtils.getNGrams(sentence, ngramSize)) {
double ngramProbability = lm.calculateProbability(ngram);
perplexity = perplexity.multiply(new BigDecimal(1d).divide(new BigDecimal(ngramProbability), CONTEXT));
}
}
double p = Math.log(perplexity.doubleValue());
if (Double.isInfinite(p) || Double.isNaN(p)) {
return Double.POSITIVE_INFINITY; // over/underflow -> too high perplexity
} else {
BigDecimal log = new BigDecimal(p);
return Math.pow(Math.E, log.divide(new BigDecimal(testSet.size()), CONTEXT).doubleValue());
}
}
}