// blob: 24abe3e078687451fc093873053e3f85842193e6 [file] [log] [blame]
package org.apache.samoa.evaluation;
/*
* #%L
* SAMOA
* %%
* Copyright (C) 2014 - 2015 Apache Software Foundation
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
import org.apache.samoa.instances.Instance;
import org.apache.samoa.instances.Utils;
import org.apache.samoa.moa.AbstractMOAObject;
import org.apache.samoa.moa.core.Measurement;
/**
 * Classification evaluator that performs basic incremental evaluation.
 *
 * <p>
 * For every instance with a positive weight the evaluator accumulates the total observed weight, the weight
 * of correct predictions, the per-class weight of predictions and of true labels, and the weight a "no-change"
 * baseline (which always predicts the previously seen true class) would have classified correctly. From these
 * it derives accuracy, the Kappa statistic and the Kappa Temporal statistic.
 *
 * <p>
 * Statistics that are mathematically undefined (no weight observed yet, or a chance agreement of exactly 1)
 * are reported as {@code 0.0} rather than {@code NaN} or an infinity.
 *
 * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
 * @author Albert Bifet (abifet at cs dot waikato dot ac dot nz)
 * @version $Revision: 7 $
 */
public class BasicClassificationPerformanceEvaluator extends AbstractMOAObject implements
    ClassificationPerformanceEvaluator {

  private static final long serialVersionUID = 1L;

  /** Total weight of the instances observed since the last reset. */
  protected double weightObserved;

  /** Total weight of the instances whose predicted class equalled the true class. */
  protected double weightCorrect;

  /** Observed weight per true class, indexed by true class. */
  protected double[] columnKappa;

  /** Observed weight per predicted class, indexed by predicted class. */
  protected double[] rowKappa;

  /** Number of classes the per-class arrays are currently sized for. */
  protected int numClasses;

  /** Weight the no-change baseline (predict the previous true class) would have classified correctly. */
  private double weightCorrectNoChangeClassifier;

  /** True class of the most recently added instance; 0 right after a reset. */
  private int lastSeenClass;

  /**
   * Clears all accumulated statistics while keeping the current number of classes.
   */
  @Override
  public void reset() {
    reset(this.numClasses);
  }

  /**
   * Clears all accumulated statistics and re-sizes the per-class arrays.
   *
   * @param numClasses
   *          number of classes to track; used as the length of the per-class arrays
   */
  public void reset(int numClasses) {
    this.numClasses = numClasses;
    // Freshly allocated double arrays are already zero-filled, so no explicit clearing loop is needed.
    this.rowKappa = new double[numClasses];
    this.columnKappa = new double[numClasses];
    this.weightObserved = 0.0;
    this.weightCorrect = 0.0;
    this.weightCorrectNoChangeClassifier = 0.0;
    this.lastSeenClass = 0;
  }

  /**
   * Records the outcome of one prediction.
   *
   * <p>
   * Only instances with a strictly positive weight contribute to the accumulated statistics. The true class of
   * every instance, whatever its weight, is remembered as the "last seen" class used by the no-change baseline.
   *
   * @param inst
   *          the instance that was classified; supplies the weight, the true class and the number of classes
   * @param classVotes
   *          the votes produced for {@code inst}; the predicted class is the index returned by
   *          {@code Utils.maxIndex(classVotes)} (presumably the highest-voted class)
   */
  @Override
  public void addResult(Instance inst, double[] classVotes) {
    double weight = inst.weight();
    int trueClass = (int) inst.classValue();
    if (weight > 0.0) {
      if (this.weightObserved == 0) {
        // First weighted instance since the last reset: size the per-class arrays from the instance.
        reset(inst.numClasses());
      }
      this.weightObserved += weight;
      int predictedClass = Utils.maxIndex(classVotes);
      if (predictedClass == trueClass) {
        this.weightCorrect += weight;
      }
      // Skip (instead of throwing) when an index does not fit the per-class arrays, e.g. when there are
      // more votes than declared classes or the arrays are empty.
      if (isValidIndex(predictedClass, this.rowKappa)) {
        this.rowKappa[predictedClass] += weight;
      }
      if (isValidIndex(trueClass, this.columnKappa)) {
        this.columnKappa[trueClass] += weight;
      }
      // Kept inside the positive-weight guard so that negative or NaN weights cannot corrupt the
      // no-change baseline while being ignored by every other accumulator.
      if (this.lastSeenClass == trueClass) {
        this.weightCorrectNoChangeClassifier += weight;
      }
    }
    this.lastSeenClass = trueClass;
  }

  /**
   * Returns the current measurements: observed weight, accuracy, Kappa and Kappa Temporal (the last three as
   * percentages).
   *
   * @return a newly allocated array with the four measurements
   */
  @Override
  public Measurement[] getPerformanceMeasurements() {
    return new Measurement[] {
        new Measurement("classified instances",
            getTotalWeightObserved()),
        new Measurement("classifications correct (percent)",
            getFractionCorrectlyClassified() * 100.0),
        new Measurement("Kappa Statistic (percent)",
            getKappaStatistic() * 100.0),
        new Measurement("Kappa Temporal Statistic (percent)",
            getKappaTemporalStatistic() * 100.0)
    };
  }

  /**
   * @return the total weight observed since the last reset
   */
  public double getTotalWeightObserved() {
    return this.weightObserved;
  }

  /**
   * @return the correctly classified weight divided by the observed weight, or {@code 0.0} if nothing has been
   *         observed
   */
  public double getFractionCorrectlyClassified() {
    return this.weightObserved > 0.0 ? this.weightCorrect
        / this.weightObserved : 0.0;
  }

  /**
   * @return {@code 1.0 - getFractionCorrectlyClassified()}
   */
  public double getFractionIncorrectlyClassified() {
    return 1.0 - getFractionCorrectlyClassified();
  }

  /**
   * Computes the Kappa statistic {@code (p0 - pc) / (1 - pc)}, where {@code p0} is the fraction classified
   * correctly and {@code pc} is the chance agreement derived from the per-class prediction and true-label
   * frequencies.
   *
   * @return the Kappa statistic, or {@code 0.0} if nothing has been observed or the chance agreement is
   *         exactly 1 (statistic undefined)
   */
  public double getKappaStatistic() {
    if (this.weightObserved > 0.0) {
      double p0 = getFractionCorrectlyClassified();
      double pc = 0.0;
      for (int i = 0; i < this.numClasses; i++) {
        pc += (this.rowKappa[i] / this.weightObserved)
            * (this.columnKappa[i] / this.weightObserved);
      }
      return chanceCorrected(p0, pc);
    } else {
      return 0;
    }
  }

  /**
   * Computes the Kappa Temporal statistic {@code (p0 - pc) / (1 - pc)}, where {@code p0} is the fraction
   * classified correctly and {@code pc} is the fraction the no-change baseline would have classified correctly.
   *
   * @return the Kappa Temporal statistic, or {@code 0.0} if nothing has been observed or the baseline was
   *         always right (statistic undefined)
   */
  public double getKappaTemporalStatistic() {
    if (this.weightObserved > 0.0) {
      double p0 = this.weightCorrect / this.weightObserved;
      double pc = this.weightCorrectNoChangeClassifier / this.weightObserved;
      return chanceCorrected(p0, pc);
    } else {
      return 0;
    }
  }

  /**
   * Appends a description of the current performance measurements to {@code sb}.
   *
   * @param sb
   *          the builder receiving the description
   * @param indent
   *          indentation passed through to {@code Measurement.getMeasurementsDescription}
   */
  @Override
  public void getDescription(StringBuilder sb, int indent) {
    Measurement.getMeasurementsDescription(getPerformanceMeasurements(),
        sb, indent);
  }

  /** Tells whether {@code index} is a legal position in {@code counts}. */
  private static boolean isValidIndex(int index, double[] counts) {
    return index >= 0 && index < counts.length;
  }

  /**
   * Evaluates {@code (p0 - pc) / (1 - pc)}, returning {@code 0.0} instead of NaN or an infinity when the
   * denominator is zero.
   */
  private static double chanceCorrected(double p0, double pc) {
    double denominator = 1.0 - pc;
    if (denominator == 0.0) {
      return 0.0;
    }
    return (p0 - pc) / denominator;
  }
}