blob: bd9a0d9bd3cdeda414e21efa3148481c24104fa4 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.classification.utils;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.Classifier;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TermRangeQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.NamedThreadFactory;
/**
* Utility class to generate the confusion matrix of a {@link Classifier}
*/
public class ConfusionMatrixGenerator {

  /**
   * Maximum time, in milliseconds, that a single classification may take before it is given up on.
   * The same amount is charged to the overall time budget whenever a classification times out.
   */
  private static final long CLASSIFICATION_TIMEOUT_MS = 5000L;

  private ConfusionMatrixGenerator() {
    // utility class, not meant to be instantiated
  }

  /**
   * get the {@link ConfusionMatrix} of a given {@link Classifier},
   * generated on the given {@link IndexReader}, class and text fields.
   *
   * @param reader              the {@link IndexReader} containing the index used for creating the {@link Classifier}
   * @param classifier          the {@link Classifier} whose confusion matrix has to be generated
   * @param classFieldName      the name of the Lucene field used as the classifier's output
   * @param textFieldName       the name of the Lucene field used as the classifier's input
   * @param timeoutMilliseconds budget for the accumulated classification time, in milliseconds: once the accumulated
   *                            time reaches this value no further documents are evaluated; a value {@code <= 0}
   *                            disables the limit
   * @param <T>                 the return type of the {@link ClassificationResult} returned by the given {@link Classifier}
   * @return a {@link ConfusionMatrix}; its average classification time is {@code 0} when no document was evaluated
   * @throws IOException      if problems occur while reading the index
   * @throws RuntimeException if the classifier fails on a document, or if the calling thread is interrupted while
   *                          waiting for a classification (the interrupt status is restored before throwing)
   */
  public static <T> ConfusionMatrix getConfusionMatrix(IndexReader reader, Classifier<T> classifier, String classFieldName,
                                                       String textFieldName, long timeoutMilliseconds) throws IOException {
    ExecutorService executorService = Executors.newFixedThreadPool(1, new NamedThreadFactory("confusion-matrix-gen-"));
    try {
      Map<String, Map<String, Long>> counts = new HashMap<>();
      IndexSearcher indexSearcher = new IndexSearcher(reader);
      // range query with both bounds open on the class field, used to collect the documents to evaluate
      TopDocs topDocs = indexSearcher.search(new TermRangeQuery(classFieldName, null, null, true, true), Integer.MAX_VALUE);
      double time = 0d;
      int counter = 0;
      for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
        if (timeoutMilliseconds > 0 && time >= timeoutMilliseconds) {
          break;
        }
        Document doc = reader.document(scoreDoc.doc);
        String[] correctAnswers = doc.getValues(classFieldName);
        String text = doc.get(textFieldName);
        if (correctAnswers == null || correctAnswers.length == 0 || text == null) {
          // nothing to compare against, or nothing to classify
          continue;
        }
        // sorted so that the assigned class can be looked up with a binary search
        Arrays.sort(correctAnswers);

        ClassificationResult<T> result;
        try {
          // fail if a single classification takes longer than CLASSIFICATION_TIMEOUT_MS
          long start = System.currentTimeMillis();
          result = executorService.submit(() -> classifier.assignClass(text))
              .get(CLASSIFICATION_TIMEOUT_MS, TimeUnit.MILLISECONDS);
          time += System.currentTimeMillis() - start;
        } catch (TimeoutException timeoutException) {
          // charge the whole per-document timeout to the time budget and skip the document.
          // NOTE(review): the timed out task is not cancelled, so it keeps the single worker thread busy
          // until it completes on its own.
          time += CLASSIFICATION_TIMEOUT_MS;
          continue;
        } catch (ExecutionException executionException) {
          throw new RuntimeException(executionException);
        } catch (InterruptedException interruptedException) {
          // restore the interrupt status so that callers up the stack can still observe it
          Thread.currentThread().interrupt();
          throw new RuntimeException(interruptedException);
        }

        T assignedClass = result == null ? null : result.getAssignedClass();
        if (assignedClass == null) {
          continue;
        }
        counter++;
        String classified = assignedClass instanceof BytesRef ? ((BytesRef) assignedClass).utf8ToString() : assignedClass.toString();
        // a document can carry several correct classes: the classification counts as correct if it hits any
        // of them, otherwise it is recorded as a miss against the first (smallest) correct class
        String correctAnswer = Arrays.binarySearch(correctAnswers, classified) >= 0 ? classified : correctAnswers[0];
        counts.computeIfAbsent(correctAnswer, key -> new HashMap<>()).merge(classified, 1L, Long::sum);
      }
      // avoid NaN / Infinity when no document could be evaluated
      double avgClassificationTime = counter > 0 ? time / counter : 0d;
      return new ConfusionMatrix(counts, avgClassificationTime, counter);
    } finally {
      executorService.shutdown();
    }
  }

  /**
   * a confusion matrix, backed by a {@link Map} representing the linearized matrix
   */
  public static class ConfusionMatrix {

    private final Map<String, Map<String, Long>> linearizedMatrix;
    private final double avgClassificationTime;
    private final int numberOfEvaluatedDocs;
    // lazily computed by getAccuracy(); -1 means "not computed yet"
    private double accuracy = -1d;

    private ConfusionMatrix(Map<String, Map<String, Long>> linearizedMatrix, double avgClassificationTime, int numberOfEvaluatedDocs) {
      this.linearizedMatrix = linearizedMatrix;
      this.avgClassificationTime = avgClassificationTime;
      this.numberOfEvaluatedDocs = numberOfEvaluatedDocs;
    }

    /**
     * get the linearized confusion matrix as a {@link Map}
     *
     * @return a {@link Map} whose keys are the correct classification answers and whose values are the actual answers'
     * counts
     */
    public Map<String, Map<String, Long>> getLinearizedMatrix() {
      return Collections.unmodifiableMap(linearizedMatrix);
    }

    /**
     * calculate precision on the given class
     *
     * @param klass the class to calculate the precision for
     * @return the precision for the given class, {@code 0} if the class was never correctly assigned
     */
    public double getPrecision(String klass) {
      Map<String, Long> classifications = linearizedMatrix.get(klass);
      if (classifications == null) {
        return 0;
      }
      double tp = classifications.getOrDefault(klass, 0L);
      if (tp <= 0) {
        return 0;
      }
      // tp + fp: every document that was assigned to klass, whatever its correct class
      double den = 0;
      for (Map<String, Long> values : linearizedMatrix.values()) {
        den += values.getOrDefault(klass, 0L);
      }
      return tp / den;
    }

    /**
     * calculate recall on the given class
     *
     * @param klass the class to calculate the recall for
     * @return the recall for the given class, {@code 0} if no document with that correct class was evaluated
     */
    public double getRecall(String klass) {
      Map<String, Long> classifications = linearizedMatrix.get(klass);
      double tp = 0;
      double fn = 0;
      if (classifications != null) {
        for (Map.Entry<String, Long> entry : classifications.entrySet()) {
          if (klass.equals(entry.getKey())) {
            tp += entry.getValue();
          } else {
            fn += entry.getValue();
          }
        }
      }
      return tp + fn > 0 ? tp / (tp + fn) : 0;
    }

    /**
     * get the F-1 measure of the given class
     *
     * @param klass the class to calculate the F-1 measure for
     * @return the F-1 measure for the given class
     */
    public double getF1Measure(String klass) {
      double recall = getRecall(klass);
      double precision = getPrecision(klass);
      return precision > 0 && recall > 0 ? 2 * precision * recall / (precision + recall) : 0;
    }

    /**
     * get the F-1 measure on this confusion matrix
     *
     * @return the F-1 measure
     */
    public double getF1Measure() {
      double recall = getRecall();
      double precision = getPrecision();
      return precision > 0 && recall > 0 ? 2 * precision * recall / (precision + recall) : 0;
    }

    /**
     * Calculate accuracy on this confusion matrix using the formula:
     * {@literal accuracy = correctly-classified / (correctly-classified + wrongly-classified)}
     *
     * @return the accuracy, {@code 0} if the matrix is empty
     */
    public double getAccuracy() {
      if (this.accuracy == -1) {
        double tp = 0d;
        double tn = 0d;
        double tfp = 0d; // tp + fp
        double fn = 0d;
        for (Map.Entry<String, Map<String, Long>> classification : linearizedMatrix.entrySet()) {
          String klass = classification.getKey();
          for (Map.Entry<String, Long> entry : classification.getValue().entrySet()) {
            if (klass.equals(entry.getKey())) {
              tp += entry.getValue();
            } else {
              fn += entry.getValue();
            }
          }
          for (Map<String, Long> values : linearizedMatrix.values()) {
            if (values.containsKey(klass)) {
              tfp += values.get(klass);
            } else {
              // NOTE(review): this adds one per row that never got klass assigned, i.e. it counts rows
              // rather than documents - confirm this is the intended definition of "true negatives"
              tn++;
            }
          }
        }
        double denominator = tfp + fn + tn;
        // an empty matrix would otherwise yield 0 / 0 = NaN
        this.accuracy = denominator > 0 ? (tp + tn) / denominator : 0d;
      }
      return this.accuracy;
    }

    /**
     * get the macro averaged precision (see {@link #getPrecision(String)}) over all the classes.
     *
     * @return the macro averaged precision as computed from the confusion matrix, {@code 0} if the matrix is empty
     */
    public double getPrecision() {
      if (linearizedMatrix.isEmpty()) {
        return 0;
      }
      double p = 0;
      for (String klass : linearizedMatrix.keySet()) {
        p += getPrecision(klass);
      }
      return p / linearizedMatrix.size();
    }

    /**
     * get the macro averaged recall (see {@link #getRecall(String)}) over all the classes
     *
     * @return the recall as computed from the confusion matrix, {@code 0} if the matrix is empty
     */
    public double getRecall() {
      if (linearizedMatrix.isEmpty()) {
        return 0;
      }
      double r = 0;
      for (String klass : linearizedMatrix.keySet()) {
        r += getRecall(klass);
      }
      return r / linearizedMatrix.size();
    }

    @Override
    public String toString() {
      return "ConfusionMatrix{" +
          "linearizedMatrix=" + linearizedMatrix +
          ", avgClassificationTime=" + avgClassificationTime +
          ", numberOfEvaluatedDocs=" + numberOfEvaluatedDocs +
          '}';
    }

    /**
     * get the average classification time in milliseconds
     *
     * @return the avg classification time
     */
    public double getAvgClassificationTime() {
      return avgClassificationTime;
    }

    /**
     * get the no. of documents evaluated while generating this confusion matrix
     *
     * @return the no. of documents evaluated
     */
    public int getNumberOfEvaluatedDocs() {
      return numberOfEvaluatedDocs;
    }
  }
}