blob: eee5cbfe853c96baed83250d52f91f584ae8079b [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package opennlp.tools.ml.naivebayes;
import java.util.ArrayList;
import java.util.Map;
import java.util.Map.Entry;
/**
* Class implementing the probability distribution over labels returned by
* a classifier, with each probability stored as its natural logarithm.
* This is necessary because floating-point precision in Java does not allow high-accuracy
* representation of very low probabilities, such as those that occur in a text categorizer.
*
* @param <T> the label (category) class
*
*/
public class LogProbabilities<T> extends Probabilities<T> {

  /**
   * Assigns a probability to a label, discarding any previously assigned probability.
   * The value is stored as its natural logarithm.
   *
   * @param t the label to which the probability is being assigned
   * @param probability the probability to assign
   */
  public void set(T t, double probability) {
    isNormalised = false;
    map.put(t, log(probability));
  }

  /**
   * Assigns a probability to a label, discarding any previously assigned probability.
   * The stored value is whatever {@code probability.getLog()} returns.
   *
   * @param t the label to which the probability is being assigned
   * @param probability the probability to assign
   */
  public void set(T t, Probability<T> probability) {
    isNormalised = false;
    map.put(t, probability.getLog());
  }

  /**
   * Assigns a probability to a label, discarding any previously assigned probability,
   * if the new probability is greater than the old one. A label with no stored value
   * always receives the new probability.
   *
   * @param t the label to which the probability is being assigned
   * @param probability the probability to assign
   */
  public void setIfLarger(T t, double probability) {
    double logProbability = log(probability);
    Double p = map.get(t);
    if (p == null || logProbability > p) {
      // Only invalidate the cached distribution when the stored value actually changes.
      isNormalised = false;
      map.put(t, logProbability);
    }
  }

  /**
   * Assigns a log probability to a label, discarding any previously assigned probability.
   * The value is stored as given, without further conversion.
   *
   * @param t the label to which the log probability is being assigned
   * @param probability the log probability to assign
   */
  public void setLog(T t, double probability) {
    isNormalised = false;
    map.put(t, probability);
  }

  /**
   * Compounds the existing probability mass on the label with the new probability passed in to the method.
   * In log space this adds {@code count * log(probability)} to the stored value; a label
   * with no stored value starts from {@code 0.0}.
   *
   * @param t the label whose probability mass is being updated
   * @param probability the probability weight to add
   * @param count the amplifying factor for the probability compounding
   */
  public void addIn(T t, double probability, int count) {
    isNormalised = false;
    Double p = map.get(t);
    if (p == null) {
      p = 0.0;
    }
    probability = log(probability) * count;
    map.put(t, p + probability);
  }

  /**
   * Converts the stored log values into a distribution of plain probabilities and caches it.
   * The cached map is returned as long as {@code isNormalised} is still set.
   *
   * @return the map of labels to normalised probabilities
   */
  private Map<T, Double> normalize() {
    if (isNormalised) {
      return normalised;
    }
    Map<T, Double> temp = createMapDataStructure();

    // Find the largest stored log value; it is subtracted below so that the
    // largest exponent handed to exp() is 0.
    double highestLogProbability = Double.NEGATIVE_INFINITY;
    for (Double p : map.values()) {
      if (p != null && p > highestLogProbability) {
        highestLogProbability = p;
      }
    }

    double sum = 0;
    for (Entry<T, Double> entry : map.entrySet()) {
      Double p = entry.getValue();
      if (p != null) {
        double shifted = StrictMath.exp(p - highestLogProbability);
        // NaN results (e.g. -Infinity minus -Infinity) are left out of the distribution.
        if (!Double.isNaN(shifted)) {
          sum += shifted;
          temp.put(entry.getKey(), shifted);
        }
      }
    }

    // The sum does not change inside the loop, so it is tested once up front.
    if (sum > Double.MIN_VALUE) {
      for (Entry<T, Double> entry : temp.entrySet()) {
        final Double p = entry.getValue();
        if (p != null) {
          // Overwrites the value of a key that is already present; no keys are added or removed.
          temp.put(entry.getKey(), p / sum);
        }
      }
    }

    normalised = temp;
    isNormalised = true;
    return temp;
  }

  /**
   * Returns the natural logarithm of the given value.
   *
   * @param prob the value to convert
   * @return {@code StrictMath.log(prob)}
   */
  private double log(double prob) {
    return StrictMath.log(prob);
  }

  /**
   * Returns the probability associated with a label
   *
   * @param t the label whose probability needs to be returned
   * @return the normalised probability associated with the label,
   *         or {@code 0.0} if the normalised distribution has no value for it
   */
  public Double get(T t) {
    Double d = normalize().get(t);
    if (d == null) {
      return 0.0;
    }
    return d;
  }

  /**
   * Returns the log probability associated with a label
   *
   * @param t the label whose log probability needs to be returned
   * @return the stored log probability associated with the label,
   *         or {@link Double#NEGATIVE_INFINITY} if no value is stored for it
   */
  public Double getLog(T t) {
    Double d = map.get(t);
    if (d == null) {
      return Double.NEGATIVE_INFINITY;
    }
    return d;
  }

  /**
   * Removes every label whose stored log value is strictly below the logarithm of the
   * given threshold. A label mapped to {@code null} is treated as
   * {@link Double#NEGATIVE_INFINITY}.
   *
   * @param i the threshold, given as a plain (non-log) value
   */
  public void discardCountsBelow(double i) {
    i = StrictMath.log(i);
    ArrayList<T> labelsToRemove = new ArrayList<>();
    for (Entry<T, Double> entry : map.entrySet()) {
      Double sum = entry.getValue();
      if (sum == null) {
        sum = Double.NEGATIVE_INFINITY;
      }
      if (sum < i) {
        labelsToRemove.add(entry.getKey());
      }
    }
    // Removal is deferred until after the iteration so the map is not modified while being traversed.
    for (T label : labelsToRemove) {
      map.remove(label);
    }
    if (!labelsToRemove.isEmpty()) {
      // The cached normalised distribution still contains the removed labels; force a recompute.
      isNormalised = false;
    }
  }

  /**
   * Returns the probabilities associated with all labels
   *
   * @return the map of labels and their normalised probabilities
   */
  public Map<T, Double> getAll() {
    return normalize();
  }

  /**
   * Returns the most likely label
   *
   * @return the label that has the highest stored log probability, or {@code null} if
   *         no label qualifies (for example when nothing has been stored)
   */
  public T getMax() {
    double max = Double.NEGATIVE_INFINITY;
    T maxT = null;
    for (Entry<T, Double> entry : map.entrySet()) {
      final Double logProbability = entry.getValue();
      // Null guard prevents a NullPointerException on auto-unboxing. ">=" is kept so that
      // a label stored at negative infinity can still be returned, and a later entry wins a tie.
      if (logProbability != null && logProbability >= max) {
        max = logProbability;
        maxT = entry.getKey();
      }
    }
    return maxT;
  }
}