blob: db1d2bca2a8737c70ebf388872e76f9cea5391a0 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package opennlp.tools.ml.naivebayes;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
/**
 * Class implementing the probability distribution over labels returned by a classifier.
 * <p>
 * Raw (unnormalised) weights are kept in {@link #map}. A normalised view, in which every
 * weight is divided by the sum of all weights, is computed lazily and cached in
 * {@link #normalised} until the next mutation made through this class.
 *
 * @param <T> the label (category) class
 *
 */
public abstract class Probabilities<T> {
  /** Raw, unnormalised weight per label. */
  protected Map<T, Double> map = new HashMap<>();

  /** Whether {@link #normalised} currently reflects the contents of {@link #map}. */
  protected transient boolean isNormalised = false;
  /** Cached result of the last normalisation; only meaningful while {@link #isNormalised} is set. */
  protected Map<T, Double> normalised;

  protected double confidence = 0.0;

  /**
   * Assigns a probability to a label, discarding any previously assigned probability.
   *
   * @param t           the label to which the probability is being assigned
   * @param probability the probability to assign
   */
  public void set(T t, double probability) {
    isNormalised = false;
    map.put(t, probability);
  }

  /**
   * Assigns a probability to a label, discarding any previously assigned probability.
   *
   * @param t           the label to which the probability is being assigned
   * @param probability the probability to assign
   */
  public void set(T t, Probability<T> probability) {
    isNormalised = false;
    map.put(t, probability.get());
  }

  /**
   * Assigns a probability to a label, discarding any previously assigned probability,
   * if the new probability is greater than the old one. A label without a previously
   * assigned probability always receives the new one.
   *
   * @param t           the label to which the probability is being assigned
   * @param probability the probability to assign
   */
  public void setIfLarger(T t, double probability) {
    Double p = map.get(t);
    if (p == null || probability > p) {
      isNormalised = false;
      map.put(t, probability);
    }
  }

  /**
   * Assigns a log probability to a label, discarding any previously assigned probability.
   * The value is exponentiated before it is stored.
   *
   * @param t           the label to which the log probability is being assigned
   * @param probability the log probability to assign
   */
  public void setLog(T t, double probability) {
    set(t, StrictMath.exp(probability));
  }

  /**
   * Compounds the existing probability mass on the label with the new probability passed in to the method.
   * The stored weight is multiplied by {@code probability} raised to the power {@code count};
   * a label without a stored weight starts from {@code 1.0}.
   *
   * @param t           the label whose probability mass is being updated
   * @param probability the probability weight to add
   * @param count       the amplifying factor for the probability compounding
   */
  public void addIn(T t, double probability, int count) {
    isNormalised = false;
    Double p = map.get(t);
    if (p == null) {
      // Neutral element of multiplication for labels seen for the first time
      p = 1.0;
    }
    probability = StrictMath.pow(probability, count);
    map.put(t, p * probability);
  }

  /**
   * Returns the normalised probability associated with a label.
   *
   * @param t the label whose probability needs to be returned
   * @return the probability associated with the label, or {@code 0.0} if the label has none
   */
  public Double get(T t) {
    Double d = normalize().get(t);
    if (d == null) {
      return 0.0;
    }
    return d;
  }

  /**
   * Returns the log probability associated with a label.
   *
   * @param t the label whose log probability needs to be returned
   * @return the natural logarithm of {@link #get(Object)} for the label
   */
  public Double getLog(T t) {
    return StrictMath.log(get(t));
  }

  /**
   * Returns the labels that currently have a probability assigned.
   *
   * @return the key set of the underlying label map (not a copy)
   */
  public Set<T> getKeys() {
    return map.keySet();
  }

  /**
   * Returns the normalised probabilities associated with all labels.
   *
   * @return the map of labels and their normalised probabilities
   */
  public Map<T, Double> getAll() {
    return normalize();
  }

  /**
   * Computes, or returns the cached, normalised distribution: every non-null raw weight
   * divided by the sum of all non-null raw weights.
   *
   * @return the normalised map of labels and probabilities
   */
  private Map<T, Double> normalize() {
    if (isNormalised) {
      return normalised;
    }
    Map<T, Double> temp = createMapDataStructure();
    double sum = 0;
    for (Entry<T, Double> entry : map.entrySet()) {
      Double p = entry.getValue();
      if (p != null) {
        sum += p;
      }
    }
    // Iterate over the raw weights (not over the still-empty result map) so that
    // every label actually ends up in the normalised distribution.
    // Note: a total of zero is not special-cased, the division then follows
    // IEEE 754 rules (NaN / infinity).
    for (Entry<T, Double> entry : map.entrySet()) {
      Double p = entry.getValue();
      if (p != null) {
        temp.put(entry.getKey(), p / sum);
      }
    }
    normalised = temp;
    isNormalised = true;
    return temp;
  }

  /**
   * Creates the map instance that receives the normalised distribution.
   *
   * @return a new, empty map
   */
  protected Map<T, Double> createMapDataStructure() {
    return new HashMap<>();
  }

  /**
   * Returns the most likely label.
   * <p>
   * Labels without a value are ignored. Because the search starts from {@code 0} and uses
   * {@code >=}, labels with a negative weight are never selected and, among equal weights,
   * the label encountered last in the map's iteration order wins.
   *
   * @return the label that has the highest associated probability,
   *         or {@code null} if no label qualifies
   */
  public T getMax() {
    double max = 0;
    T maxT = null;
    for (Entry<T, Double> entry : map.entrySet()) {
      final Double weight = entry.getValue();
      // Skip missing values instead of failing while unboxing them
      if (weight != null && weight >= max) {
        max = weight;
        maxT = entry.getKey();
      }
    }
    return maxT;
  }

  /**
   * Returns the probability of the most likely label.
   *
   * @return the normalised probability of the label returned by {@link #getMax()}
   */
  public double getMaxValue() {
    return get(getMax());
  }

  /**
   * Removes every label whose raw weight is below the given threshold. Labels without
   * a value are treated as having a weight of {@code 0.0}.
   *
   * @param i the threshold; labels with a weight strictly below it are removed
   */
  public void discardCountsBelow(double i) {
    List<T> labelsToRemove = new ArrayList<>();
    for (Entry<T, Double> entry : map.entrySet()) {
      Double value = entry.getValue();
      double weight = (value == null) ? 0.0 : value;
      if (weight < i) {
        labelsToRemove.add(entry.getKey());
      }
    }
    if (!labelsToRemove.isEmpty()) {
      for (T label : labelsToRemove) {
        map.remove(label);
      }
      // The cached distribution still contains the removed labels; force a recompute
      isNormalised = false;
    }
  }

  /**
   * Returns the best confidence with which this set of probabilities has been calculated.
   * This is a function of the amount of data that supports the assertion.
   * It is also a measure of the accuracy of the estimator of the probability.
   *
   * @return the best confidence of the probabilities
   */
  public double getConfidence() {
    return confidence;
  }

  /**
   * Sets the best confidence with which this set of probabilities has been calculated.
   * This is a function of the amount of data that supports the assertion.
   * It is also a measure of the accuracy of the estimator of the probability.
   *
   * @param confidence the confidence in the probabilities
   */
  public void setConfidence(double confidence) {
    this.confidence = confidence;
  }

  /**
   * Returns the string form of the normalised distribution.
   *
   * @return the result of {@code toString()} on the map returned by {@link #getAll()}
   */
  @Override
  public String toString() {
    return getAll().toString();
  }
}