| package org.apache.samoa.evaluation; |
| |
| /* |
| * #%L |
| * SAMOA |
| * %% |
| * Copyright (C) 2014 - 2016 Apache Software Foundation |
| * %% |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| * #L% |
| */ |
| |
| |
import org.apache.samoa.instances.Instance;
import org.apache.samoa.instances.Utils;
import org.apache.samoa.moa.AbstractMOAObject;
import org.apache.samoa.moa.core.Measurement;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Vector;
| |
| /** |
| * Created by Edi Bice (edi.bice gmail com) on 2/22/2016. |
| */ |
| public class F1ClassificationPerformanceEvaluator extends AbstractMOAObject implements |
| ClassificationPerformanceEvaluator { |
| |
| private static final long serialVersionUID = 1L; |
| protected int numClasses = -1; |
| |
| protected long[] support; |
| protected long[] truePos; |
| protected long[] falsePos; |
| protected long[] trueNeg; |
| protected long[] falseNeg; |
| |
| @Override |
| public void reset() { |
| reset(this.numClasses); |
| } |
| |
| public void reset(int numClasses) { |
| this.numClasses = numClasses; |
| this.support = new long[numClasses]; |
| this.truePos = new long[numClasses]; |
| this.falsePos = new long[numClasses]; |
| this.trueNeg = new long[numClasses]; |
| this.falseNeg = new long[numClasses]; |
| for (int i = 0; i < this.numClasses; i++) { |
| this.support[i] = 0; |
| this.truePos[i] = 0; |
| this.falsePos[i] = 0; |
| this.trueNeg[i] = 0; |
| this.falseNeg[i] = 0; |
| } |
| } |
| |
| @Override |
| public void addResult(Instance inst, double[] classVotes) { |
| if (numClasses==-1) reset(inst.numClasses()); |
| int trueClass = (int) inst.classValue(); |
| this.support[trueClass] += 1; |
| int predictedClass = Utils.maxIndex(classVotes); |
| if (predictedClass == trueClass) { |
| this.truePos[trueClass] += 1; |
| for (int i = 0; i < this.numClasses; i++) { |
| if (i!=predictedClass) this.trueNeg[i] += 1; |
| } |
| } else { |
| this.falsePos[predictedClass] += 1; |
| this.falseNeg[trueClass] += 1; |
| for (int i = 0; i < this.numClasses; i++) { |
| if (!(i==predictedClass || i==trueClass)) this.trueNeg[i] += 1; |
| } |
| } |
| } |
| |
| @Override |
| public Measurement[] getPerformanceMeasurements() { |
| List<Measurement> measurements = new Vector<>(); |
| Collections.addAll(measurements, getSupportMeasurements()); |
| Collections.addAll(measurements, getPrecisionMeasurements()); |
| Collections.addAll(measurements, getRecallMeasurements()); |
| Collections.addAll(measurements, getF1Measurements()); |
| return measurements.toArray(new Measurement[measurements.size()]); |
| } |
| |
| private Measurement[] getSupportMeasurements() { |
| Measurement[] measurements = new Measurement[this.numClasses]; |
| for (int i = 0; i < this.numClasses; i++) { |
| String ml = String.format("class %s support", i); |
| measurements[i] = new Measurement(ml, this.support[i]); |
| } |
| return measurements; |
| } |
| |
| private Measurement[] getPrecisionMeasurements() { |
| Measurement[] measurements = new Measurement[this.numClasses]; |
| for (int i = 0; i < this.numClasses; i++) { |
| String ml = String.format("class %s precision", i); |
| measurements[i] = new Measurement(ml, getPrecision(i), 10); |
| } |
| return measurements; |
| } |
| |
| private Measurement[] getRecallMeasurements() { |
| Measurement[] measurements = new Measurement[this.numClasses]; |
| for (int i = 0; i < this.numClasses; i++) { |
| String ml = String.format("class %s recall", i); |
| measurements[i] = new Measurement(ml, getRecall(i), 10); |
| } |
| return measurements; |
| } |
| |
| private Measurement[] getF1Measurements() { |
| Measurement[] measurements = new Measurement[this.numClasses]; |
| for (int i = 0; i < this.numClasses; i++) { |
| String ml = String.format("class %s f1-score", i); |
| measurements[i] = new Measurement(ml, getF1Score(i), 10); |
| } |
| return measurements; |
| } |
| |
| @Override |
| public void getDescription(StringBuilder sb, int indent) { |
| Measurement.getMeasurementsDescription(getSupportMeasurements(), sb, indent); |
| Measurement.getMeasurementsDescription(getPrecisionMeasurements(), sb, indent); |
| Measurement.getMeasurementsDescription(getRecallMeasurements(), sb, indent); |
| Measurement.getMeasurementsDescription(getF1Measurements(), sb, indent); |
| } |
| |
| private double getPrecision(int classIndex) { |
| return (double) this.truePos[classIndex] / (this.truePos[classIndex] + this.falsePos[classIndex]); |
| } |
| |
| private double getRecall(int classIndex) { |
| return (double) this.truePos[classIndex] / (this.truePos[classIndex] + this.falseNeg[classIndex]); |
| } |
| |
| private double getF1Score(int classIndex) { |
| double precision = getPrecision(classIndex); |
| double recall = getRecall(classIndex); |
| return 2 * (precision * recall) / (precision + recall); |
| } |
| |
| } |