| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package opennlp.tools.cmdline; |
| |
| import java.io.OutputStream; |
| import java.io.PrintStream; |
| import java.text.MessageFormat; |
| import java.util.ArrayList; |
| import java.util.Arrays; |
| import java.util.Collections; |
| import java.util.Comparator; |
| import java.util.HashMap; |
| import java.util.HashSet; |
| import java.util.Iterator; |
| import java.util.LinkedList; |
| import java.util.List; |
| import java.util.Map; |
| import java.util.Map.Entry; |
| import java.util.Set; |
| import java.util.SortedSet; |
| import java.util.TreeSet; |
| |
| import opennlp.tools.util.Span; |
| import opennlp.tools.util.eval.FMeasure; |
| import opennlp.tools.util.eval.Mean; |
| |
| public abstract class FineGrainedReportListener { |
| |
| private static final char[] alpha = { 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', |
| 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', |
| 'w', 'x', 'y', 'z' }; |
| private final PrintStream printStream; |
| private final Stats stats = new Stats(); |
| |
/**
 * Creates a listener whose report sections are appended to the given stream.
 *
 * @param printStream
 *          the stream that receives the report text
 */
public FineGrainedReportListener(PrintStream printStream) {
  this.printStream = printStream;
}
| |
/**
 * Creates a listener that reports to the given {@link OutputStream}. The
 * stream is wrapped in a new {@link PrintStream}.
 *
 * @param outputStream
 *          the stream that receives the report text
 */
public FineGrainedReportListener(OutputStream outputStream) {
  this(new PrintStream(outputStream));
}
| |
| private static String generateAlphaLabel(int index) { |
| |
| char[] labelChars = new char[3]; |
| int i; |
| |
| for (i = 2; i >= 0; i--) { |
| if (index >= 0) { |
| labelChars[i] = alpha[index % alpha.length]; |
| index = index / alpha.length - 1; |
| } else { |
| labelChars[i] = ' '; |
| } |
| } |
| return new String(labelChars); |
| } |
| |
| public abstract void writeReport(); |
| |
| |
| // api methods |
| // general stats |
| |
| protected Stats getStats() { |
| return this.stats; |
| } |
| |
| private long getNumberOfSentences() { |
| return stats.getNumberOfSentences(); |
| } |
| |
| private double getAverageSentenceSize() { |
| return stats.getAverageSentenceSize(); |
| } |
| |
| private int getMinSentenceSize() { |
| return stats.getMinSentenceSize(); |
| } |
| |
| private int getMaxSentenceSize() { |
| return stats.getMaxSentenceSize(); |
| } |
| |
| private int getNumberOfTags() { |
| return stats.getNumberOfTags(); |
| } |
| |
| // token stats |
| |
| private double getAccuracy() { |
| return stats.getAccuracy(); |
| } |
| |
| private double getTokenAccuracy(String token) { |
| return stats.getTokenAccuracy(token); |
| } |
| |
| private SortedSet<String> getTokensOrderedByFrequency() { |
| return stats.getTokensOrderedByFrequency(); |
| } |
| |
| private int getTokenFrequency(String token) { |
| return stats.getTokenFrequency(token); |
| } |
| |
| private int getTokenErrors(String token) { |
| return stats.getTokenErrors(token); |
| } |
| |
| private SortedSet<String> getTokensOrderedByNumberOfErrors() { |
| return stats.getTokensOrderedByNumberOfErrors(); |
| } |
| |
| private SortedSet<String> getTagsOrderedByErrors() { |
| return stats.getTagsOrderedByErrors(); |
| } |
| |
| private int getTagFrequency(String tag) { |
| return stats.getTagFrequency(tag); |
| } |
| |
| private int getTagErrors(String tag) { |
| return stats.getTagErrors(tag); |
| } |
| |
| private double getTagPrecision(String tag) { |
| return stats.getTagPrecision(tag); |
| } |
| |
| private double getTagRecall(String tag) { |
| return stats.getTagRecall(tag); |
| } |
| |
| private double getTagFMeasure(String tag) { |
| return stats.getTagFMeasure(tag); |
| } |
| |
| private SortedSet<String> getConfusionMatrixTagset() { |
| return stats.getConfusionMatrixTagset(); |
| } |
| |
| private SortedSet<String> getConfusionMatrixTagset(String token) { |
| return stats.getConfusionMatrixTagset(token); |
| } |
| |
| private double[][] getConfusionMatrix() { |
| return stats.getConfusionMatrix(); |
| } |
| |
| private double[][] getConfusionMatrix(String token) { |
| return stats.getConfusionMatrix(token); |
| } |
| |
| private String matrixToString(SortedSet<String> tagset, double[][] data, |
| boolean filter) { |
| // we dont want to print trivial cases (acc=1) |
| int initialIndex = 0; |
| String[] tags = tagset.toArray(new String[tagset.size()]); |
| StringBuilder sb = new StringBuilder(); |
| int minColumnSize = Integer.MIN_VALUE; |
| String[][] matrix = new String[data.length][data[0].length]; |
| for (int i = 0; i < data.length; i++) { |
| int j = 0; |
| for (; j < data[i].length - 1; j++) { |
| matrix[i][j] = data[i][j] > 0 ? Integer.toString((int) data[i][j]) |
| : "."; |
| if (minColumnSize < matrix[i][j].length()) { |
| minColumnSize = matrix[i][j].length(); |
| } |
| } |
| matrix[i][j] = MessageFormat.format("{0,number,#.##%}", data[i][j]); |
| if (data[i][j] == 1 && filter) { |
| initialIndex = i + 1; |
| } |
| } |
| |
| final String headerFormat = "%" + (minColumnSize + 2) + "s "; // | 1234567 | |
| final String cellFormat = "%" + (minColumnSize + 2) + "s "; // | 12345 | |
| final String diagFormat = " %" + (minColumnSize + 2) + "s"; |
| for (int i = initialIndex; i < tagset.size(); i++) { |
| sb.append(String.format(headerFormat, |
| generateAlphaLabel(i - initialIndex).trim())); |
| } |
| sb.append("| Accuracy | <-- classified as\n"); |
| for (int i = initialIndex; i < data.length; i++) { |
| int j = initialIndex; |
| for (; j < data[i].length - 1; j++) { |
| if (i == j) { |
| String val = "<" + matrix[i][j] + ">"; |
| sb.append(String.format(diagFormat, val)); |
| } else { |
| sb.append(String.format(cellFormat, matrix[i][j])); |
| } |
| } |
| sb.append( |
| String.format("| %-6s | %3s = ", matrix[i][j], |
| generateAlphaLabel(i - initialIndex))).append(tags[i]); |
| sb.append("\n"); |
| } |
| return sb.toString(); |
| } |
| |
| protected void printGeneralStatistics() { |
| printHeader("Evaluation summary"); |
| printStream.append( |
| String.format("%21s: %6s", "Number of sentences", |
| Long.toString(getNumberOfSentences()))).append("\n"); |
| printStream.append( |
| String.format("%21s: %6s", "Min sentence size", getMinSentenceSize())) |
| .append("\n"); |
| printStream.append( |
| String.format("%21s: %6s", "Max sentence size", getMaxSentenceSize())) |
| .append("\n"); |
| printStream.append( |
| String.format("%21s: %6s", "Average sentence size", |
| MessageFormat.format("{0,number,#.##}", getAverageSentenceSize()))) |
| .append("\n"); |
| printStream.append( |
| String.format("%21s: %6s", "Tags count", getNumberOfTags())).append( |
| "\n"); |
| printStream.append( |
| String.format("%21s: %6s", "Accuracy", |
| MessageFormat.format("{0,number,#.##%}", getAccuracy()))).append( |
| "\n"); |
| printFooter("Evaluation Corpus Statistics"); |
| } |
| |
| protected void printTokenOcurrenciesRank() { |
| printHeader("Most frequent tokens"); |
| |
| SortedSet<String> toks = getTokensOrderedByFrequency(); |
| final int maxLines = 20; |
| |
| int maxTokSize = 5; |
| |
| int count = 0; |
| Iterator<String> tokIterator = toks.iterator(); |
| while (tokIterator.hasNext() && count++ < maxLines) { |
| String tok = tokIterator.next(); |
| if (tok.length() > maxTokSize) { |
| maxTokSize = tok.length(); |
| } |
| } |
| |
| int tableSize = maxTokSize + 19; |
| String format = "| %3s | %6s | %" + maxTokSize + "s |"; |
| |
| printLine(tableSize); |
| printStream.append(String.format(format, "Pos", "Count", "Token")).append( |
| "\n"); |
| printLine(tableSize); |
| |
| // get the first 20 errors |
| count = 0; |
| tokIterator = toks.iterator(); |
| while (tokIterator.hasNext() && count++ < maxLines) { |
| String tok = tokIterator.next(); |
| int ocurrencies = getTokenFrequency(tok); |
| |
| printStream.append(String.format(format, count, ocurrencies, tok) |
| |
| ).append("\n"); |
| } |
| printLine(tableSize); |
| printFooter("Most frequent tokens"); |
| } |
| |
| protected void printTokenErrorRank() { |
| printHeader("Tokens with the highest number of errors"); |
| printStream.append("\n"); |
| |
| SortedSet<String> toks = getTokensOrderedByNumberOfErrors(); |
| int maxTokenSize = 5; |
| |
| int count = 0; |
| Iterator<String> tokIterator = toks.iterator(); |
| while (tokIterator.hasNext() && count++ < 20) { |
| String tok = tokIterator.next(); |
| if (tok.length() > maxTokenSize) { |
| maxTokenSize = tok.length(); |
| } |
| } |
| |
| int tableSize = 31 + maxTokenSize; |
| |
| String format = "| %" + maxTokenSize + "s | %6s | %5s | %7s |\n"; |
| |
| printLine(tableSize); |
| printStream.append(String.format(format, "Token", "Errors", "Count", |
| "% Err")); |
| printLine(tableSize); |
| |
| // get the first 20 errors |
| count = 0; |
| tokIterator = toks.iterator(); |
| while (tokIterator.hasNext() && count++ < 20) { |
| String tok = tokIterator.next(); |
| int ocurrencies = getTokenFrequency(tok); |
| int errors = getTokenErrors(tok); |
| String rate = MessageFormat.format("{0,number,#.##%}", (double) errors |
| / ocurrencies); |
| |
| printStream.append(String.format(format, tok, errors, ocurrencies, rate) |
| |
| ); |
| } |
| printLine(tableSize); |
| printFooter("Tokens with the highest number of errors"); |
| } |
| |
| protected void printTagsErrorRank() { |
| printHeader("Detailed Accuracy By Tag"); |
| SortedSet<String> tags = getTagsOrderedByErrors(); |
| printStream.append("\n"); |
| |
| int maxTagSize = 3; |
| |
| for (String t : tags) { |
| if (t.length() > maxTagSize) { |
| maxTagSize = t.length(); |
| } |
| } |
| |
| int tableSize = 65 + maxTagSize; |
| |
| String headerFormat = "| %" + maxTagSize |
| + "s | %6s | %6s | %7s | %9s | %6s | %9s |\n"; |
| String format = "| %" + maxTagSize |
| + "s | %6s | %6s | %-7s | %-9s | %-6s | %-9s |\n"; |
| |
| printLine(tableSize); |
| printStream.append(String.format(headerFormat, "Tag", "Errors", "Count", |
| "% Err", "Precision", "Recall", "F-Measure")); |
| printLine(tableSize); |
| |
| for (String tag : tags) { |
| int ocurrencies = getTagFrequency(tag); |
| int errors = getTagErrors(tag); |
| String rate = MessageFormat.format("{0,number,#.###}", (double) errors |
| / ocurrencies); |
| |
| double p = getTagPrecision(tag); |
| double r = getTagRecall(tag); |
| double f = getTagFMeasure(tag); |
| |
| printStream.append(String.format(format, tag, errors, ocurrencies, rate, |
| MessageFormat.format("{0,number,#.###}", p > 0 ? p : 0), |
| MessageFormat.format("{0,number,#.###}", r > 0 ? r : 0), |
| MessageFormat.format("{0,number,#.###}", f > 0 ? f : 0)) |
| |
| ); |
| } |
| printLine(tableSize); |
| |
| printFooter("Tags with the highest number of errors"); |
| } |
| |
| protected void printGeneralConfusionTable() { |
| printHeader("Confusion matrix"); |
| |
| SortedSet<String> labels = getConfusionMatrixTagset(); |
| |
| double[][] confusionMatrix = getConfusionMatrix(); |
| |
| printStream.append("\nTags with 100% accuracy: "); |
| int line = 0; |
| for (String label : labels) { |
| if (confusionMatrix[line][confusionMatrix[0].length - 1] == 1) { |
| printStream.append(label).append(" (") |
| .append(Integer.toString((int) confusionMatrix[line][line])) |
| .append(") "); |
| } |
| line++; |
| } |
| |
| printStream.append("\n\n"); |
| |
| printStream.append(matrixToString(labels, confusionMatrix, true)); |
| |
| printFooter("Confusion matrix"); |
| } |
| |
| protected void printDetailedConfusionMatrix() { |
| printHeader("Confusion matrix for tokens"); |
| printStream.append(" sorted by number of errors\n"); |
| SortedSet<String> toks = getTokensOrderedByNumberOfErrors(); |
| |
| for (String t : toks) { |
| double acc = getTokenAccuracy(t); |
| if (acc < 1) { |
| printStream |
| .append("\n[") |
| .append(t) |
| .append("]\n") |
| .append( |
| String.format("%12s: %-8s", "Accuracy", |
| MessageFormat.format("{0,number,#.##%}", acc))) |
| .append("\n"); |
| printStream.append( |
| String.format("%12s: %-8s", "Ocurrencies", |
| Integer.toString(getTokenFrequency(t)))).append("\n"); |
| printStream.append( |
| String.format("%12s: %-8s", "Errors", |
| Integer.toString(getTokenErrors(t)))).append("\n"); |
| |
| SortedSet<String> labels = getConfusionMatrixTagset(t); |
| |
| double[][] confusionMatrix = getConfusionMatrix(t); |
| |
| printStream.append(matrixToString(labels, confusionMatrix, false)); |
| } |
| } |
| printFooter("Confusion matrix for tokens"); |
| } |
| |
| /** Auxiliary method that prints a emphasised report header */ |
| private void printHeader(String text) { |
| printStream.append("=== ").append(text).append(" ===\n"); |
| } |
| |
| /** Auxiliary method that prints a marker to the end of a report */ |
| private void printFooter(String text) { |
| printStream.append("\n<-end> ").append(text).append("\n\n"); |
| } |
| |
| /** Auxiliary method that prints a horizontal line of a given size */ |
| private void printLine(int size) { |
| for (int i = 0; i < size; i++) { |
| printStream.append("-"); |
| } |
| printStream.append("\n"); |
| } |
| |
| /** |
| * A comparator that sorts the confusion matrix labels according to the |
| * accuracy of each line |
| */ |
| public static class MatrixLabelComparator implements Comparator<String> { |
| |
| private Map<String, ConfusionMatrixLine> confusionMatrix; |
| |
| public MatrixLabelComparator(Map<String, ConfusionMatrixLine> confusionMatrix) { |
| this.confusionMatrix = confusionMatrix; |
| } |
| |
| public int compare(String o1, String o2) { |
| if (o1.equals(o2)) { |
| return 0; |
| } |
| ConfusionMatrixLine t1 = confusionMatrix.get(o1); |
| ConfusionMatrixLine t2 = confusionMatrix.get(o2); |
| if (t1 == null || t2 == null) { |
| if (t1 == null) { |
| return 1; |
| } else { |
| return -1; |
| } |
| } |
| double r1 = t1.getAccuracy(); |
| double r2 = t2.getAccuracy(); |
| if (r1 == r2) { |
| return o1.compareTo(o2); |
| } |
| if (r2 > r1) { |
| return 1; |
| } |
| return -1; |
| } |
| } |
| |
| public static class GroupedMatrixLabelComparator implements Comparator<String> { |
| |
| private final HashMap<String, Double> categoryAccuracy; |
| private Map<String, ConfusionMatrixLine> confusionMatrix; |
| |
| public GroupedMatrixLabelComparator(Map<String, ConfusionMatrixLine> confusionMatrix) { |
| this.confusionMatrix = confusionMatrix; |
| this.categoryAccuracy = new HashMap<>(); |
| |
| // compute grouped categories |
| for (Entry<String, ConfusionMatrixLine> entry : confusionMatrix.entrySet()) { |
| final String key = entry.getKey(); |
| final ConfusionMatrixLine confusionMatrixLine = entry.getValue(); |
| final String category; |
| if (key.contains("-")) { |
| category = key.split("-")[0]; |
| } else { |
| category = key; |
| } |
| double currentAccuracy = categoryAccuracy.getOrDefault(category, 0.0d); |
| categoryAccuracy.put(category, currentAccuracy + confusionMatrixLine.getAccuracy()); |
| } |
| } |
| |
| public int compare(String o1, String o2) { |
| if (o1.equals(o2)) { |
| return 0; |
| } |
| String c1 = o1; |
| String c2 = o2; |
| |
| if (o1.contains("-")) { |
| c1 = o1.split("-")[0]; |
| } |
| if (o2.contains("-")) { |
| c2 = o2.split("-")[0]; |
| } |
| |
| if (c1.equals(c2)) { // same category - sort by confusion matrix |
| |
| ConfusionMatrixLine t1 = confusionMatrix.get(o1); |
| ConfusionMatrixLine t2 = confusionMatrix.get(o2); |
| if (t1 == null || t2 == null) { |
| if (t1 == null) { |
| return 1; |
| } else { |
| return -1; |
| } |
| } |
| double r1 = t1.getAccuracy(); |
| double r2 = t2.getAccuracy(); |
| if (r1 == r2) { |
| return o1.compareTo(o2); |
| } |
| if (r2 > r1) { |
| return 1; |
| } |
| return -1; |
| } else { // different category - sort by category |
| Double t1 = categoryAccuracy.get(c1); |
| Double t2 = categoryAccuracy.get(c2); |
| if (t1 == null || t2 == null) { |
| if (t1 == null) { |
| return 1; |
| } else { |
| return -1; |
| } |
| } |
| if (t1.equals(t2)) { |
| return o1.compareTo(o2); |
| } |
| if (t2 > t1) { |
| return 1; |
| } |
| return -1; |
| } |
| } |
| } |
| |
| public Comparator<String> getMatrixLabelComparator(Map<String, ConfusionMatrixLine> confusionMatrix) { |
| return new MatrixLabelComparator(confusionMatrix); |
| } |
| |
| public static class SimpleLabelComparator implements Comparator<String> { |
| |
| private Map<String, Counter> map; |
| |
| public SimpleLabelComparator(Map<String, Counter> map) { |
| this.map = map; |
| } |
| |
| @Override |
| public int compare(String o1, String o2) { |
| if (o1.equals(o2)) { |
| return 0; |
| } |
| int e1 = 0, e2 = 0; |
| if (map.containsKey(o1)) |
| e1 = map.get(o1).value(); |
| if (map.containsKey(o2)) |
| e2 = map.get(o2).value(); |
| if (e1 == e2) { |
| return o1.compareTo(o2); |
| } |
| return e2 - e1; |
| } |
| } |
| |
| public Comparator<String> getLabelComparator(Map<String, Counter> map) { |
| return new SimpleLabelComparator(map); |
| } |
| |
| public static class GroupedLabelComparator implements Comparator<String> { |
| |
| private final HashMap<String, Integer> categoryCounter; |
| private Map<String, Counter> labelCounter; |
| |
| public GroupedLabelComparator(Map<String, Counter> map) { |
| this.labelCounter = map; |
| this.categoryCounter = new HashMap<>(); |
| |
| // compute grouped categories |
| for (Entry<String, Counter> entry : labelCounter.entrySet()) { |
| final String key = entry.getKey(); |
| final Counter value = entry.getValue(); |
| final String category; |
| if (key.contains("-")) { |
| category = key.split("-")[0]; |
| } else { |
| category = key; |
| } |
| int currentCount = categoryCounter.getOrDefault(category, 0); |
| categoryCounter.put(category, currentCount + value.value()); |
| } |
| } |
| |
| public int compare(String o1, String o2) { |
| if (o1.equals(o2)) { |
| return 0; |
| } |
| String c1 = o1; |
| String c2 = o2; |
| |
| if (o1.contains("-")) { |
| c1 = o1.split("-")[0]; |
| } |
| if (o2.contains("-")) { |
| c2 = o2.split("-")[0]; |
| } |
| |
| if (c1.equals(c2)) { // same category - sort by confusion matrix |
| |
| Counter t1 = labelCounter.get(o1); |
| Counter t2 = labelCounter.get(o2); |
| if (t1 == null || t2 == null) { |
| if (t1 == null) { |
| return 1; |
| } else { |
| return -1; |
| } |
| } |
| int r1 = t1.value(); |
| int r2 = t2.value(); |
| if (r1 == r2) { |
| return o1.compareTo(o2); |
| } |
| if (r2 > r1) { |
| return 1; |
| } |
| return -1; |
| } else { // different category - sort by category |
| Integer t1 = categoryCounter.get(c1); |
| Integer t2 = categoryCounter.get(c2); |
| if (t1 == null || t2 == null) { |
| if (t1 == null) { |
| return 1; |
| } else { |
| return -1; |
| } |
| } |
| if (t1.equals(t2)) { |
| return o1.compareTo(o2); |
| } |
| if (t2 > t1) { |
| return 1; |
| } |
| return -1; |
| } |
| } |
| } |
| |
| /** |
| * Represents a line in the confusion table. |
| */ |
| public static class ConfusionMatrixLine { |
| |
| private Map<String, Counter> line = new HashMap<>(); |
| private String ref; |
| private int total = 0; |
| private int correct = 0; |
| private double acc = -1; |
| |
| /** |
| * Creates a new {@link ConfusionMatrixLine} |
| * |
| * @param ref |
| * the reference column |
| */ |
| private ConfusionMatrixLine(String ref) { |
| this.ref = ref; |
| } |
| |
| /** |
| * Increments the counter for the given column and updates the statistics. |
| * |
| * @param column |
| * the column to be incremented |
| */ |
| private void increment(String column) { |
| total++; |
| if (column.equals(ref)) |
| correct++; |
| if (!line.containsKey(column)) { |
| line.put(column, new Counter()); |
| } |
| line.get(column).increment(); |
| } |
| |
| /** |
| * Gets the calculated accuracy of this element |
| * |
| * @return the accuracy |
| */ |
| public double getAccuracy() { |
| // we save the accuracy because it is frequently used by the comparator |
| if (StrictMath.abs(acc - 1.0d) < 0.0000000001) { |
| if (total == 0) |
| acc = 0.0d; |
| acc = (double) correct / (double) total; |
| } |
| return acc; |
| } |
| |
| /** |
| * Gets the value given a column |
| * |
| * @param column |
| * the column |
| * @return the counter value |
| */ |
| public int getValue(String column) { |
| Counter c = line.get(column); |
| if (c == null) |
| return 0; |
| return c.value(); |
| } |
| } |
| |
| /** |
| * Implements a simple counter |
| */ |
| public static class Counter { |
| private int c = 0; |
| |
| private void increment() { |
| c++; |
| } |
| |
| public int value() { |
| return c; |
| } |
| } |
| |
| public class Stats { |
| |
| // general statistics |
| private final Mean accuracy = new Mean(); |
| private final Mean averageSentenceLength = new Mean(); |
| // token statistics |
| private final Map<String, Mean> tokAccuracies = new HashMap<>(); |
| private final Map<String, Counter> tokOcurrencies = new HashMap<>(); |
| private final Map<String, Counter> tokErrors = new HashMap<>(); |
| // tag statistics |
| private final Map<String, Counter> tagOcurrencies = new HashMap<>(); |
| private final Map<String, Counter> tagErrors = new HashMap<>(); |
| private final Map<String, FMeasure> tagFMeasure = new HashMap<>(); |
| // represents a Confusion Matrix that aggregates all tokens |
| private final Map<String, ConfusionMatrixLine> generalConfusionMatrix = new HashMap<>(); |
| // represents a set of Confusion Matrix for each token |
| private final Map<String, Map<String, ConfusionMatrixLine>> tokenConfusionMatrix = new HashMap<>(); |
| private int minimalSentenceLength = Integer.MAX_VALUE; |
| private int maximumSentenceLength = Integer.MIN_VALUE; |
| |
| public void add(String[] toks, String[] refs, String[] preds) { |
| int length = toks.length; |
| averageSentenceLength.add(length); |
| |
| if (minimalSentenceLength > length) { |
| minimalSentenceLength = length; |
| } |
| if (maximumSentenceLength < length) { |
| maximumSentenceLength = length; |
| } |
| |
| updateTagFMeasure(refs, preds); |
| |
| for (int i = 0; i < toks.length; i++) { |
| commit(toks[i], refs[i], preds[i]); |
| } |
| } |
| |
| public void add(int length, String ref, String pred) { |
| |
| averageSentenceLength.add(length); |
| |
| if (minimalSentenceLength > length) { |
| minimalSentenceLength = length; |
| } |
| if (maximumSentenceLength < length) { |
| maximumSentenceLength = length; |
| } |
| |
| // String[] toks = reference.getSentence(); |
| String[] refs = { ref }; |
| String[] preds = { pred }; |
| |
| updateTagFMeasure(refs, preds); |
| |
| commit("", ref, pred); |
| } |
| |
| public void add(String[] text, String ref, String pred) { |
| int length = text.length; |
| this.add(length, ref, pred); |
| } |
| |
| public void add(CharSequence text, String ref, String pred) { |
| int length = text.length(); |
| this.add(length, ref, pred); |
| } |
| |
| |
| /** |
| * Includes a new evaluation data |
| * |
| * @param tok |
| * the evaluated token |
| * @param ref |
| * the reference pos tag |
| * @param pred |
| * the predicted pos tag |
| */ |
| private void commit(String tok, String ref, String pred) { |
| // token stats |
| if (!tokAccuracies.containsKey(tok)) { |
| tokAccuracies.put(tok, new Mean()); |
| tokOcurrencies.put(tok, new Counter()); |
| tokErrors.put(tok, new Counter()); |
| } |
| tokOcurrencies.get(tok).increment(); |
| |
| // tag stats |
| if (!tagOcurrencies.containsKey(ref)) { |
| tagOcurrencies.put(ref, new Counter()); |
| tagErrors.put(ref, new Counter()); |
| } |
| tagOcurrencies.get(ref).increment(); |
| |
| // updates general, token and tag error stats |
| if (ref.equals(pred)) { |
| tokAccuracies.get(tok).add(1); |
| accuracy.add(1); |
| } else { |
| tokAccuracies.get(tok).add(0); |
| tokErrors.get(tok).increment(); |
| tagErrors.get(ref).increment(); |
| accuracy.add(0); |
| } |
| |
| // populate confusion matrixes |
| if (!generalConfusionMatrix.containsKey(ref)) { |
| generalConfusionMatrix.put(ref, new ConfusionMatrixLine(ref)); |
| } |
| generalConfusionMatrix.get(ref).increment(pred); |
| |
| if (!tokenConfusionMatrix.containsKey(tok)) { |
| tokenConfusionMatrix.put(tok, new HashMap<>()); |
| } |
| if (!tokenConfusionMatrix.get(tok).containsKey(ref)) { |
| tokenConfusionMatrix.get(tok).put(ref, new ConfusionMatrixLine(ref)); |
| } |
| tokenConfusionMatrix.get(tok).get(ref).increment(pred); |
| } |
| |
| private void updateTagFMeasure(String[] refs, String[] preds) { |
| // create a set with all tags |
| Set<String> tags = new HashSet<>(Arrays.asList(refs)); |
| tags.addAll(Arrays.asList(preds)); |
| |
| // create samples for each tag |
| for (String tag : tags) { |
| List<Span> reference = new ArrayList<>(); |
| List<Span> prediction = new ArrayList<>(); |
| for (int i = 0; i < refs.length; i++) { |
| if (refs[i].equals(tag)) { |
| reference.add(new Span(i, i + 1)); |
| } |
| if (preds[i].equals(tag)) { |
| prediction.add(new Span(i, i + 1)); |
| } |
| } |
| if (!this.tagFMeasure.containsKey(tag)) { |
| this.tagFMeasure.put(tag, new FMeasure()); |
| } |
| // populate the fmeasure |
| this.tagFMeasure.get(tag).updateScores( |
| reference.toArray(new Span[reference.size()]), |
| prediction.toArray(new Span[prediction.size()])); |
| } |
| } |
| |
| private double getAccuracy() { |
| return accuracy.mean(); |
| } |
| |
| private int getNumberOfTags() { |
| return this.tagOcurrencies.keySet().size(); |
| } |
| |
| private long getNumberOfSentences() { |
| return this.averageSentenceLength.count(); |
| } |
| |
| private double getAverageSentenceSize() { |
| return this.averageSentenceLength.mean(); |
| } |
| |
| private int getMinSentenceSize() { |
| return this.minimalSentenceLength; |
| } |
| |
| private int getMaxSentenceSize() { |
| return this.maximumSentenceLength; |
| } |
| |
| private double getTokenAccuracy(String token) { |
| return tokAccuracies.get(token).mean(); |
| } |
| |
| private int getTokenErrors(String token) { |
| return tokErrors.get(token).value(); |
| } |
| |
| private int getTokenFrequency(String token) { |
| return tokOcurrencies.get(token).value(); |
| } |
| |
| private SortedSet<String> getTokensOrderedByFrequency() { |
| SortedSet<String> toks = new TreeSet<>(new SimpleLabelComparator(tokOcurrencies)); |
| toks.addAll(tokOcurrencies.keySet()); |
| return Collections.unmodifiableSortedSet(toks); |
| } |
| |
| private SortedSet<String> getTokensOrderedByNumberOfErrors() { |
| SortedSet<String> toks = new TreeSet<>(new SimpleLabelComparator(tokErrors)); |
| toks.addAll(tokErrors.keySet()); |
| return toks; |
| } |
| |
| private int getTagFrequency(String tag) { |
| return tagOcurrencies.get(tag).value(); |
| } |
| |
| private int getTagErrors(String tag) { |
| return tagErrors.get(tag).value(); |
| } |
| |
| private double getTagFMeasure(String tag) { |
| return tagFMeasure.get(tag).getFMeasure(); |
| } |
| |
| private double getTagRecall(String tag) { |
| return tagFMeasure.get(tag).getRecallScore(); |
| } |
| |
| private double getTagPrecision(String tag) { |
| return tagFMeasure.get(tag).getPrecisionScore(); |
| } |
| |
| private SortedSet<String> getTagsOrderedByErrors() { |
| SortedSet<String> tags = new TreeSet<>(getLabelComparator(tagErrors)); |
| tags.addAll(tagErrors.keySet()); |
| return Collections.unmodifiableSortedSet(tags); |
| } |
| |
| private SortedSet<String> getConfusionMatrixTagset() { |
| return getConfusionMatrixTagset(generalConfusionMatrix); |
| } |
| |
| private double[][] getConfusionMatrix() { |
| return createConfusionMatrix(getConfusionMatrixTagset(), |
| generalConfusionMatrix); |
| } |
| |
| private SortedSet<String> getConfusionMatrixTagset(String token) { |
| return getConfusionMatrixTagset(tokenConfusionMatrix.get(token)); |
| } |
| |
| private double[][] getConfusionMatrix(String token) { |
| return createConfusionMatrix(getConfusionMatrixTagset(token), |
| tokenConfusionMatrix.get(token)); |
| } |
| |
| /** |
| * Creates a matrix with N lines and N + 1 columns with the data from |
| * confusion matrix. The last column is the accuracy. |
| */ |
| private double[][] createConfusionMatrix(SortedSet<String> tagset, |
| Map<String, ConfusionMatrixLine> data) { |
| int size = tagset.size(); |
| double[][] matrix = new double[size][size + 1]; |
| int line = 0; |
| for (String ref : tagset) { |
| int column = 0; |
| for (String pred : tagset) { |
| matrix[line][column] = data.get(ref) != null ? data |
| .get(ref).getValue(pred) : 0; |
| column++; |
| } |
| // set accuracy |
| matrix[line][column] = data.get(ref) != null ? data.get(ref).getAccuracy() : 0; |
| line++; |
| } |
| |
| return matrix; |
| } |
| |
| private SortedSet<String> getConfusionMatrixTagset( |
| Map<String, ConfusionMatrixLine> data) { |
| SortedSet<String> tags = new TreeSet<>(getMatrixLabelComparator(data)); |
| tags.addAll(data.keySet()); |
| List<String> col = new LinkedList<>(); |
| for (String t : tags) { |
| col.addAll(data.get(t).line.keySet()); |
| } |
| tags.addAll(col); |
| return Collections.unmodifiableSortedSet(tags); |
| } |
| } |
| } |