SAMOA-68: Saving true and predicted labels to file
Fix #61
diff --git a/samoa-apex/src/test/java/org/apache/samoa/apex/AlgosTestApex.java b/samoa-apex/src/test/java/org/apache/samoa/apex/AlgosTestApex.java
index 80d9449..7e0ca48 100644
--- a/samoa-apex/src/test/java/org/apache/samoa/apex/AlgosTestApex.java
+++ b/samoa-apex/src/test/java/org/apache/samoa/apex/AlgosTestApex.java
@@ -35,6 +35,7 @@
.samplingSize(20_000)
.evaluationInstances(200_000)
.classifiedInstances(200_000)
+ .labelSamplingSize(10l)
.classificationsCorrect(55f)
.kappaStat(0f)
.kappaTempStat(0f)
@@ -54,6 +55,7 @@
.samplingSize(20_000)
.evaluationInstances(180_000)
.classifiedInstances(190_000)
+ .labelSamplingSize(10l)
.classificationsCorrect(60f)
.kappaStat(0f)
.kappaTempStat(0f)
diff --git a/samoa-api/src/main/java/org/apache/samoa/evaluation/BasicClassificationPerformanceEvaluator.java b/samoa-api/src/main/java/org/apache/samoa/evaluation/BasicClassificationPerformanceEvaluator.java
index 24abe3e..a77831a 100644
--- a/samoa-api/src/main/java/org/apache/samoa/evaluation/BasicClassificationPerformanceEvaluator.java
+++ b/samoa-api/src/main/java/org/apache/samoa/evaluation/BasicClassificationPerformanceEvaluator.java
@@ -1,5 +1,10 @@
package org.apache.samoa.evaluation;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.samoa.instances.Attribute;
+
/*
* #%L
* SAMOA
@@ -24,6 +29,7 @@
import org.apache.samoa.instances.Utils;
import org.apache.samoa.moa.AbstractMOAObject;
import org.apache.samoa.moa.core.Measurement;
+import org.apache.samoa.moa.core.Vote;
/**
* Classification evaluator that performs basic incremental evaluation.
@@ -32,11 +38,23 @@
* @author Albert Bifet (abifet at cs dot waikato dot ac dot nz)
* @version $Revision: 7 $
*/
-public class BasicClassificationPerformanceEvaluator extends AbstractMOAObject implements
- ClassificationPerformanceEvaluator {
+public class BasicClassificationPerformanceEvaluator extends AbstractMOAObject
+ implements ClassificationPerformanceEvaluator {
private static final long serialVersionUID = 1L;
+ // the number of decimal places placed for double values in prediction file
+ // the value of 10 is used since some votes can be relatively small
+ public static final int DECIMAL_PLACES = 10;
+
+ // the vote value to be used when a classifier made no vote for the class at
+ // all
+ public static final int NO_VOTE_FOR_CLASS = 0;
+
+ // recent vote objects i.e. predicted, true classes and votes for individual
+ // classes
+ protected Vote[] votes;
+
protected double weightObserved;
protected double weightCorrect;
@@ -49,11 +67,17 @@
private double weightCorrectNoChangeClassifier;
+ protected double[] classVotes;
+
private int lastSeenClass;
+ private String instanceIdentifier;
+
+ private Instance lastSeenInstance;
@Override
public void reset() {
reset(this.numClasses);
+ votes = null;
}
public void reset(int numClasses) {
@@ -68,10 +92,11 @@
this.weightCorrect = 0.0;
this.weightCorrectNoChangeClassifier = 0.0;
this.lastSeenClass = 0;
+ votes = null;
}
@Override
- public void addResult(Instance inst, double[] classVotes) {
+ public void addResult(Instance inst, double[] classVotes, String instanceIdentifier) {
double weight = inst.weight();
int trueClass = (int) inst.classValue();
if (weight > 0.0) {
@@ -94,20 +119,60 @@
this.weightCorrectNoChangeClassifier += weight;
}
this.lastSeenClass = trueClass;
+ this.lastSeenInstance = inst;
+ this.instanceIdentifier = instanceIdentifier;
+ this.classVotes = classVotes;
}
@Override
public Measurement[] getPerformanceMeasurements() {
- return new Measurement[] {
- new Measurement("classified instances",
- getTotalWeightObserved()),
- new Measurement("classifications correct (percent)",
- getFractionCorrectlyClassified() * 100.0),
- new Measurement("Kappa Statistic (percent)",
- getKappaStatistic() * 100.0),
- new Measurement("Kappa Temporal Statistic (percent)",
- getKappaTemporalStatistic() * 100.0)
- };
+ return new Measurement[] { new Measurement("classified instances", getTotalWeightObserved()),
+ new Measurement("classifications correct (percent)", getFractionCorrectlyClassified() * 100.0),
+ new Measurement("Kappa Statistic (percent)", getKappaStatistic() * 100.0),
+ new Measurement("Kappa Temporal Statistic (percent)", getKappaTemporalStatistic() * 100.0) };
+
+ }
+
+ /**
+ * Retrieves the prediction and the per-class votes recorded for the most recently added result.
+ *
+ * @return the instance identifier, the true class label, the predicted class label and one vote per class label;
+ * the returned array and its elements are re-used and overwritten by later calls
+ */
+ @Override
+ public Vote[] getPredictionVotes() {
+ Attribute classAttribute = this.lastSeenInstance.dataset().classAttribute();
+ List<String> classAttributeValues = classAttribute.getAttributeValues();
+
+ int trueNominalIndex = (int) this.lastSeenInstance.classValue();
+ String trueNominalValue = classAttributeValues.get(trueNominalIndex);
+
+ // initialise votes first time they are supposed to be used
+ if (votes == null) {
+ this.votes = new Vote[classAttributeValues.size() + 3];
+ votes[0] = new Vote("instance number");
+ votes[1] = new Vote("true class value");
+ votes[2] = new Vote("predicted class value");
+
+ // create as many objects as the number of classes
+ for (int i = 0; i < classAttributeValues.size(); i++) {
+ votes[3 + i] = new Vote("votes_" + classAttributeValues.get(i));
+ }
+ }
+
+ // use/(re-use existing) vote objects
+ votes[0].setValue(this.instanceIdentifier);
+ votes[1].setValue(trueNominalValue);
+ votes[2].setValue(classAttributeValues.get(Utils.maxIndex(classVotes)));
+ for (int i = 0; i < classAttributeValues.size(); i++) {
+ if (i < classVotes.length) {
+ votes[3 + i].setValue(classVotes[i], DECIMAL_PLACES);
+ } else {
+ votes[3 + i].setValue(NO_VOTE_FOR_CLASS, 0);
+ }
+ }
+
+ return votes;
}
@@ -116,8 +181,7 @@
}
public double getFractionCorrectlyClassified() {
- return this.weightObserved > 0.0 ? this.weightCorrect
- / this.weightObserved : 0.0;
+ return this.weightObserved > 0.0 ? this.weightCorrect / this.weightObserved : 0.0;
}
public double getFractionIncorrectlyClassified() {
@@ -129,8 +193,7 @@
double p0 = getFractionCorrectlyClassified();
double pc = 0.0;
for (int i = 0; i < this.numClasses; i++) {
- pc += (this.rowKappa[i] / this.weightObserved)
- * (this.columnKappa[i] / this.weightObserved);
+ pc += (this.rowKappa[i] / this.weightObserved) * (this.columnKappa[i] / this.weightObserved);
}
return (p0 - pc) / (1.0 - pc);
} else {
@@ -151,7 +214,6 @@
@Override
public void getDescription(StringBuilder sb, int indent) {
- Measurement.getMeasurementsDescription(getPerformanceMeasurements(),
- sb, indent);
+ Measurement.getMeasurementsDescription(getPerformanceMeasurements(), sb, indent);
}
}
diff --git a/samoa-api/src/main/java/org/apache/samoa/evaluation/BasicRegressionPerformanceEvaluator.java b/samoa-api/src/main/java/org/apache/samoa/evaluation/BasicRegressionPerformanceEvaluator.java
index ec48156..ab16904 100644
--- a/samoa-api/src/main/java/org/apache/samoa/evaluation/BasicRegressionPerformanceEvaluator.java
+++ b/samoa-api/src/main/java/org/apache/samoa/evaluation/BasicRegressionPerformanceEvaluator.java
@@ -1,5 +1,9 @@
package org.apache.samoa.evaluation;
+import java.util.List;
+
+import org.apache.samoa.instances.Attribute;
+
/*
* #%L
* SAMOA
@@ -21,8 +25,10 @@
*/
import org.apache.samoa.instances.Instance;
+import org.apache.samoa.instances.Utils;
import org.apache.samoa.moa.AbstractMOAObject;
import org.apache.samoa.moa.core.Measurement;
+import org.apache.samoa.moa.core.Vote;
/**
* Regression evaluator that performs basic incremental evaluation.
@@ -35,6 +41,10 @@
private static final long serialVersionUID = 1L;
+ // the number of decimal places placed for double values in prediction file
+ // the value of 10 is used since some predicted values can be relatively small
+ public static final int DECIMAL_PLACES = 10;
+
protected double weightObserved;
protected double squareError;
@@ -47,6 +57,10 @@
protected double averageTargetError;
+ private String instanceIdentifier;
+ private Instance lastSeenInstance;
+ private double lastPredictedValue;
+
@Override
public void reset() {
this.weightObserved = 0.0;
@@ -59,19 +73,25 @@
}
@Override
- public void addResult(Instance inst, double[] prediction) {
+ public void addResult(Instance inst, double[] prediction, String instanceIdentifier) {
double weight = inst.weight();
double classValue = inst.classValue();
if (weight > 0.0) {
+ // record the instance even when no prediction was made, so that getPredictionVotes() never pairs a
+ // prediction (or NaN) with the identifier and true value of an older instance
+ this.lastSeenInstance = inst;
+ this.instanceIdentifier = instanceIdentifier;
if (prediction.length > 0) {
- double meanTarget = this.weightObserved != 0 ?
- this.sumTarget / this.weightObserved : 0.0;
+ double meanTarget = this.weightObserved != 0 ? this.sumTarget / this.weightObserved : 0.0;
this.squareError += (classValue - prediction[0]) * (classValue - prediction[0]);
this.averageError += Math.abs(classValue - prediction[0]);
this.squareTargetError += (classValue - meanTarget) * (classValue - meanTarget);
this.averageTargetError += Math.abs(classValue - meanTarget);
this.sumTarget += classValue;
this.weightObserved += weight;
+ this.lastPredictedValue = prediction[0];
+ } else {
+ this.lastPredictedValue = Double.NaN;
}
}
}
@@ -92,6 +110,22 @@
};
}
+ /**
+ * Retrieves the true and the predicted target value last recorded by addResult.
+ *
+ * @return the instance identifier, the true value and the predicted value, in that order
+ */
+ @Override
+ public Vote[] getPredictionVotes() {
+ double trueValue = this.lastSeenInstance.classValue();
+ return new Vote[] {
+ new Vote("instance number",
+ this.instanceIdentifier),
+ new Vote("true value", trueValue, DECIMAL_PLACES),
+ new Vote("predicted value", this.lastPredictedValue, DECIMAL_PLACES)
+ };
+ }
+
public double getTotalWeightObserved() {
return this.weightObserved;
}
@@ -123,12 +157,10 @@
}
private double getRelativeMeanError() {
- return this.averageTargetError > 0 ?
- this.averageError / this.averageTargetError : 0.0;
+ return this.averageTargetError > 0 ? this.averageError / this.averageTargetError : 0.0;
}
private double getRelativeSquareError() {
- return Math.sqrt(this.squareTargetError > 0 ?
- this.squareError / this.squareTargetError : 0.0);
+ return Math.sqrt(this.squareTargetError > 0 ? this.squareError / this.squareTargetError : 0.0);
}
}
diff --git a/samoa-api/src/main/java/org/apache/samoa/evaluation/EvaluatorCVProcessor.java b/samoa-api/src/main/java/org/apache/samoa/evaluation/EvaluatorCVProcessor.java
index f282f0d..05d0a27 100644
--- a/samoa-api/src/main/java/org/apache/samoa/evaluation/EvaluatorCVProcessor.java
+++ b/samoa-api/src/main/java/org/apache/samoa/evaluation/EvaluatorCVProcessor.java
@@ -39,12 +39,11 @@
public class EvaluatorCVProcessor implements Processor {
/**
- *
- */
+ *
+ */
private static final long serialVersionUID = -2778051819116753612L;
- private static final Logger logger =
- LoggerFactory.getLogger(EvaluatorCVProcessor.class);
+ private static final Logger logger = LoggerFactory.getLogger(EvaluatorCVProcessor.class);
private static final String ORDERING_MEASUREMENT_NAME = "evaluation instances";
@@ -90,7 +89,8 @@
addStatisticsForInstanceReceived(instanceIndex, result.getEvaluationIndex(), 1);
- evaluators[result.getEvaluationIndex()].addResult(result.getInstance(), result.getClassVotes());
+ evaluators[result.getEvaluationIndex()].addResult(result.getInstance(), result.getClassVotes(),
+ String.valueOf(instanceIndex));
if (hasAllVotesArrivedInstance(instanceIndex)) {
totalCount += 1;
@@ -110,8 +110,6 @@
}
}
-
-
return false;
}
@@ -122,6 +120,7 @@
int count = map.get(instanceIndex);
return (count == this.foldNumber);
}
+
protected void addStatisticsForInstanceReceived(int instanceIndex, int evaluationIndex, int add) {
if (this.mapCountsforInstanceReceived == null) {
this.mapCountsforInstanceReceived = new HashMap<>();
@@ -190,10 +189,10 @@
private void addMeasurement() {
List<Measurement> measurements = new Vector<>();
- measurements.add(new Measurement(ORDERING_MEASUREMENT_NAME, totalCount ));
+ measurements.add(new Measurement(ORDERING_MEASUREMENT_NAME, totalCount));
Measurement[] finalMeasurements = getEvaluationMeasurements(
- measurements.toArray(new Measurement[measurements.size()]), evaluators);
+ measurements.toArray(new Measurement[measurements.size()]), evaluators);
LearningEvaluation learningEvaluation = new LearningEvaluation(finalMeasurements);
learningCurve.insertEntry(learningEvaluation);
@@ -220,7 +219,7 @@
long experimentEnd = System.nanoTime();
long totalExperimentTime = TimeUnit.SECONDS.convert(experimentEnd - experimentStart, TimeUnit.NANOSECONDS);
- logger.info("total evaluation time: {} seconds for {} instances", totalExperimentTime, totalCount );
+ logger.info("total evaluation time: {} seconds for {} instances", totalExperimentTime, totalCount);
if (immediateResultStream != null) {
immediateResultStream.println("# COMPLETED");
@@ -257,7 +256,7 @@
return this;
}
- public Builder foldNumber(int foldNumber){
+ public Builder foldNumber(int foldNumber) {
this.foldNumber = foldNumber;
return this;
}
@@ -267,7 +266,8 @@
}
}
- public Measurement[] getEvaluationMeasurements(Measurement[] modelMeasurements, PerformanceEvaluator[] subEvaluators) {
+ public Measurement[] getEvaluationMeasurements(Measurement[] modelMeasurements,
+ PerformanceEvaluator[] subEvaluators) {
List<Measurement> measurementList = new LinkedList<Measurement>();
if (modelMeasurements != null) {
measurementList.addAll(Arrays.asList(modelMeasurements));
@@ -280,7 +280,8 @@
subMeasurements.add(subEvaluator.getPerformanceMeasurements());
}
}
- Measurement[] avgMeasurements = Measurement.averageMeasurements(subMeasurements.toArray(new Measurement[subMeasurements.size()][]));
+ Measurement[] avgMeasurements = Measurement
+ .averageMeasurements(subMeasurements.toArray(new Measurement[subMeasurements.size()][]));
measurementList.addAll(Arrays.asList(avgMeasurements));
}
return measurementList.toArray(new Measurement[measurementList.size()]);
diff --git a/samoa-api/src/main/java/org/apache/samoa/evaluation/EvaluatorProcessor.java b/samoa-api/src/main/java/org/apache/samoa/evaluation/EvaluatorProcessor.java
index 6ec50dc..e78395a 100644
--- a/samoa-api/src/main/java/org/apache/samoa/evaluation/EvaluatorProcessor.java
+++ b/samoa-api/src/main/java/org/apache/samoa/evaluation/EvaluatorProcessor.java
@@ -33,6 +33,7 @@
import org.apache.samoa.core.Processor;
import org.apache.samoa.learners.ResultContentEvent;
import org.apache.samoa.moa.core.Measurement;
+import org.apache.samoa.moa.core.Vote;
import org.apache.samoa.moa.evaluation.LearningCurve;
import org.apache.samoa.moa.evaluation.LearningEvaluation;
import org.slf4j.Logger;
@@ -41,20 +42,23 @@
public class EvaluatorProcessor implements Processor {
/**
- *
- */
+ *
+ */
private static final long serialVersionUID = -2778051819116753612L;
- private static final Logger logger =
- LoggerFactory.getLogger(EvaluatorProcessor.class);
+ private static final Logger logger = LoggerFactory.getLogger(EvaluatorProcessor.class);
private static final String ORDERING_MEASUREMENT_NAME = "evaluation instances";
private final PerformanceEvaluator evaluator;
private final int samplingFrequency;
private final File dumpFile;
+ private final File predictionFile;
+ private final int labelSamplingFrequency;
private transient PrintStream immediateResultStream = null;
+ private transient PrintStream immediatePredictionStream = null;
private transient boolean firstDump = true;
+ private transient boolean firstVoteDump = true;
private long totalCount = 0;
private long experimentStart = 0;
@@ -68,6 +72,8 @@
this.evaluator = builder.evaluator;
this.samplingFrequency = builder.samplingFrequency;
this.dumpFile = builder.dumpFile;
+ this.predictionFile = builder.predictionFile;
+ this.labelSamplingFrequency = builder.labelSamplingFrequency;
}
@Override
@@ -84,12 +90,18 @@
this.addMeasurement();
}
+ // add a vote entry (true value, predicted value and class votes); a label sampling frequency of 0 disables it
+ if (immediatePredictionStream != null && labelSamplingFrequency > 0 && totalCount > 0 && totalCount % labelSamplingFrequency == 0) {
+ this.addVote();
+ }
+
if (result.isLastEvent()) {
this.concludeMeasurement();
return true;
}
- evaluator.addResult(result.getInstance(), result.getClassVotes());
+ String instanceIndex = String.valueOf(result.getInstanceIndex());
+ evaluator.addResult(result.getInstance(), result.getClassVotes(), instanceIndex);
totalCount += 1;
if (totalCount == 1) {
@@ -125,7 +137,20 @@
}
}
+ if (this.predictionFile != null) {
+ try {
+ this.immediatePredictionStream = new PrintStream(new FileOutputStream(predictionFile), true);
+ } catch (FileNotFoundException e) {
+ this.immediatePredictionStream = null;
+ logger.error("File not found exception for {}:{}", this.predictionFile.getAbsolutePath(), e.toString());
+ } catch (Exception e) {
+ this.immediatePredictionStream = null;
+ logger.error("Exception when creating {}:{}", this.predictionFile.getAbsolutePath(), e.toString());
+ }
+ }
+
this.firstDump = true;
+ this.firstVoteDump = true;
}
@Override
@@ -179,6 +204,26 @@
}
}
+ /**
+ * Writes the evaluator's latest prediction (and, for classification, the per-class votes) as one line of the
+ * prediction file. On the first dump a header line with the vote names is written before the entry.
+ */
+ private void addVote() {
+ Vote[] latestVotes = evaluator.getPredictionVotes();
+ learningCurve.setVote(latestVotes);
+ logger.debug("evaluator id = {}", this.id);
+ // nothing to write when no prediction stream is open
+ if (immediatePredictionStream == null) {
+ return;
+ }
+ if (firstVoteDump) {
+ immediatePredictionStream.println(learningCurve.voteHeaderToString());
+ firstVoteDump = false;
+ }
+ immediatePredictionStream.println(learningCurve.voteEntryToString());
+ immediatePredictionStream.flush();
+ }
+
private void concludeMeasurement() {
logger.info("last event is received!");
logger.info("total count: {}", this.totalCount);
@@ -192,6 +237,9 @@
if (immediateResultStream != null) {
immediateResultStream.println("# COMPLETED");
+ // also report the total evaluation time in the result file
+ immediateResultStream
+ .println("# Total evaluation time: " + totalExperimentTime + " seconds for " + totalCount + " instances");
immediateResultStream.flush();
}
// logger.info("average throughput rate: {} instances/seconds",
@@ -203,6 +251,8 @@
private final PerformanceEvaluator evaluator;
private int samplingFrequency = 100000;
private File dumpFile = null;
+ private File predictionFile = null;
+ private int labelSamplingFrequency = 1;
public Builder(PerformanceEvaluator evaluator) {
this.evaluator = evaluator;
@@ -212,6 +262,8 @@
this.evaluator = oldProcessor.evaluator;
this.samplingFrequency = oldProcessor.samplingFrequency;
this.dumpFile = oldProcessor.dumpFile;
+ this.predictionFile = oldProcessor.predictionFile;
+ this.labelSamplingFrequency = oldProcessor.labelSamplingFrequency;
}
public Builder samplingFrequency(int samplingFrequency) {
@@ -224,6 +276,16 @@
return this;
}
+ public Builder predictionFile(File file) {
+ this.predictionFile = file;
+ return this;
+ }
+
+ public Builder labelSamplingFrequency(int samplingFrequency) {
+ this.labelSamplingFrequency = samplingFrequency;
+ return this;
+ }
+
public EvaluatorProcessor build() {
return new EvaluatorProcessor(this);
}
diff --git a/samoa-api/src/main/java/org/apache/samoa/evaluation/F1ClassificationPerformanceEvaluator.java b/samoa-api/src/main/java/org/apache/samoa/evaluation/F1ClassificationPerformanceEvaluator.java
index 89e74be..d54296d 100644
--- a/samoa-api/src/main/java/org/apache/samoa/evaluation/F1ClassificationPerformanceEvaluator.java
+++ b/samoa-api/src/main/java/org/apache/samoa/evaluation/F1ClassificationPerformanceEvaluator.java
@@ -1,5 +1,7 @@
package org.apache.samoa.evaluation;
+import org.apache.samoa.instances.Attribute;
+
/*
* #%L
* SAMOA
@@ -25,6 +27,7 @@
import org.apache.samoa.instances.Utils;
import org.apache.samoa.moa.AbstractMOAObject;
import org.apache.samoa.moa.core.Measurement;
+import org.apache.samoa.moa.core.Vote;
import java.util.Collections;
import java.util.List;
@@ -44,7 +47,10 @@
protected long[] falsePos;
protected long[] trueNeg;
protected long[] falseNeg;
-
+ private String instanceIdentifier;
+ private Instance lastSeenInstance;
+ protected double[] classVotes;
+
@Override
public void reset() {
reset(this.numClasses);
@@ -67,7 +73,11 @@
}
@Override
- public void addResult(Instance inst, double[] classVotes) {
+ public void addResult(Instance inst, double[] classVotes, String instanceIndex) {
+ // remember the latest result; getPredictionVotes() builds its output from these fields
+ this.lastSeenInstance = inst;
+ this.instanceIdentifier = instanceIndex;
+ this.classVotes = classVotes;
if (numClasses==-1) reset(inst.numClasses());
int trueClass = (int) inst.classValue();
this.support[trueClass] += 1;
@@ -95,6 +101,38 @@
Collections.addAll(measurements, getF1Measurements());
return measurements.toArray(new Measurement[measurements.size()]);
}
+
+ /**
+ * Builds the prediction entry from the lastSeenInstance, instanceIdentifier and classVotes fields.
+ *
+ * @return the instance identifier, the true class label, the predicted class label and one vote per class label
+ */
+ @Override
+ public Vote[] getPredictionVotes() {
+ Attribute classAttribute = this.lastSeenInstance.dataset().classAttribute();
+ double trueValue = this.lastSeenInstance.classValue();
+ List<String> classAttributeValues = classAttribute.getAttributeValues();
+
+ int trueNominalIndex = (int) trueValue;
+ String trueNominalValue = classAttributeValues.get(trueNominalIndex);
+
+ // one entry per class label (not per received vote) after the three leading entries
+ Vote[] votes = new Vote[classAttributeValues.size() + 3];
+ votes[0] = new Vote("instance number", this.instanceIdentifier);
+ votes[1] = new Vote("true class value", trueNominalValue);
+ votes[2] = new Vote("predicted class value",
+ classAttributeValues.get(Utils.maxIndex(classVotes)));
+
+ // class votes start at index 3 so that they do not overwrite the predicted class value
+ for (int i = 0; i < classAttributeValues.size(); i++) {
+ if (i < classVotes.length) {
+ votes[3 + i] = new Vote("votes_" + classAttributeValues.get(i), classVotes[i]);
+ } else {
+ votes[3 + i] = new Vote("votes_" + classAttributeValues.get(i), 0);
+ }
+ }
+ return votes;
+ }
private Measurement[] getSupportMeasurements() {
Measurement[] measurements = new Measurement[this.numClasses];
diff --git a/samoa-api/src/main/java/org/apache/samoa/evaluation/PerformanceEvaluator.java b/samoa-api/src/main/java/org/apache/samoa/evaluation/PerformanceEvaluator.java
index 0bd2450..c4c4a0b 100644
--- a/samoa-api/src/main/java/org/apache/samoa/evaluation/PerformanceEvaluator.java
+++ b/samoa-api/src/main/java/org/apache/samoa/evaluation/PerformanceEvaluator.java
@@ -23,6 +23,7 @@
import org.apache.samoa.instances.Instance;
import org.apache.samoa.moa.MOAObject;
import org.apache.samoa.moa.core.Measurement;
+import org.apache.samoa.moa.core.Vote;
/**
* Interface implemented by learner evaluators to monitor the results of the learning process.
@@ -47,7 +48,7 @@
* an array containing the estimated membership probabilities of the test instance in each class
* @return an array of measurements monitored in this evaluator
*/
- public void addResult(Instance inst, double[] classVotes);
+ public void addResult(Instance inst, double[] classVotes, String instanceIdentifier);
/**
* Gets the current measurements monitored by this evaluator.
@@ -55,4 +56,11 @@
* @return an array of measurements monitored by this evaluator
*/
public Measurement[] getPerformanceMeasurements();
+
+ /**
+ * Gets the prediction details (and, for classifiers, the per-class votes) of the latest result of this evaluator.
+ *
+ * @return an array of votes describing the most recently recorded result
+ */
+ public Vote[] getPredictionVotes();
}
diff --git a/samoa-api/src/main/java/org/apache/samoa/evaluation/WindowClassificationPerformanceEvaluator.java b/samoa-api/src/main/java/org/apache/samoa/evaluation/WindowClassificationPerformanceEvaluator.java
index c428a7f..6ea40ed 100644
--- a/samoa-api/src/main/java/org/apache/samoa/evaluation/WindowClassificationPerformanceEvaluator.java
+++ b/samoa-api/src/main/java/org/apache/samoa/evaluation/WindowClassificationPerformanceEvaluator.java
@@ -1,5 +1,9 @@
package org.apache.samoa.evaluation;
+import java.util.List;
+
+import org.apache.samoa.instances.Attribute;
+
/*
* #%L
* SAMOA
@@ -24,6 +28,7 @@
import org.apache.samoa.instances.Utils;
import org.apache.samoa.moa.AbstractMOAObject;
import org.apache.samoa.moa.core.Measurement;
+import org.apache.samoa.moa.core.Vote;
import com.github.javacliparser.IntOption;
@@ -59,6 +64,10 @@
protected int numClasses;
+ private String instanceIdentifier;
+ private Instance lastSeenInstance;
+ protected double[] classVotes;
+
public class Estimator {
protected double[] window;
@@ -127,7 +136,11 @@
}
@Override
- public void addResult(Instance inst, double[] classVotes) {
+ public void addResult(Instance inst, double[] classVotes, String instanceIndex) {
+ // remember the latest result; getPredictionVotes() builds its output from these fields
+ this.lastSeenInstance = inst;
+ this.instanceIdentifier = instanceIndex;
+ this.classVotes = classVotes;
double weight = inst.weight();
int trueClass = (int) inst.classValue();
if (weight > 0.0) {
@@ -172,6 +181,38 @@
}
+ /**
+ * Builds the prediction entry from the lastSeenInstance, instanceIdentifier and classVotes fields.
+ *
+ * @return the instance identifier, the true class label, the predicted class label and one vote per class label
+ */
+ @Override
+ public Vote[] getPredictionVotes() {
+ Attribute classAttribute = this.lastSeenInstance.dataset().classAttribute();
+ double trueValue = this.lastSeenInstance.classValue();
+ List<String> classAttributeValues = classAttribute.getAttributeValues();
+
+ int trueNominalIndex = (int) trueValue;
+ String trueNominalValue = classAttributeValues.get(trueNominalIndex);
+
+ // one entry per class label (not per received vote) after the three leading entries
+ Vote[] votes = new Vote[classAttributeValues.size() + 3];
+ votes[0] = new Vote("instance number", this.instanceIdentifier);
+ votes[1] = new Vote("true class value", trueNominalValue);
+ votes[2] = new Vote("predicted class value",
+ classAttributeValues.get(Utils.maxIndex(classVotes)));
+
+ // class votes start at index 3 so that they do not overwrite the predicted class value
+ for (int i = 0; i < classAttributeValues.size(); i++) {
+ if (i < classVotes.length) {
+ votes[3 + i] = new Vote("votes_" + classAttributeValues.get(i), classVotes[i]);
+ } else {
+ votes[3 + i] = new Vote("votes_" + classAttributeValues.get(i), 0);
+ }
+ }
+ return votes;
+ }
+
public double getTotalWeightObserved() {
return this.weightObserved.total();
}
diff --git a/samoa-api/src/main/java/org/apache/samoa/moa/core/Vote.java b/samoa-api/src/main/java/org/apache/samoa/moa/core/Vote.java
new file mode 100644
index 0000000..24ea3f3
--- /dev/null
+++ b/samoa-api/src/main/java/org/apache/samoa/moa/core/Vote.java
@@ -0,0 +1,111 @@
+package org.apache.samoa.moa.core;
+
+import java.io.Serializable;
+
+/*
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
+ * either express or implied. See the License for the specific
+ * language governing permissions and limitations under the
+ * License.
+ */
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Locale;
+
+/**
+ * A named value kept as text, e.g. the true class, the predicted class or the vote cast for one class. Numeric
+ * values are converted to text when they are set.
+ */
+public class Vote implements Serializable {
+
+ private static final long serialVersionUID = 1L;
+
+ protected String name;
+ protected String value;
+
+ /** Creates a vote with the given name and no value yet. */
+ public Vote(String name) {
+ this.name = name;
+ }
+
+ /** Creates a vote holding a textual value. */
+ public Vote(String name, String value) {
+ this.name = name;
+ this.value = value;
+ }
+
+ /** Creates a vote holding a numeric value formatted with three fraction digits. */
+ public Vote(String name, double value) {
+ this(name, value, 3);
+ }
+
+ /** Creates a vote holding a numeric value formatted with the given number of fraction digits. */
+ public Vote(String name, double value, int fractionDigits) {
+ this(name);
+ setValue(value, fractionDigits);
+ }
+
+ public String getName() {
+ return this.name;
+ }
+
+ public String getValue() {
+ return this.value;
+ }
+
+ public void setValue(String value) {
+ this.value = value;
+ }
+
+ /**
+ * Stores a numeric value as fixed-point text.
+ *
+ * @param value the number to store
+ * @param fractionDigits the number of digits written after the decimal point
+ */
+ public void setValue(double value, int fractionDigits) {
+ // rely on dot as a decimal separator not to confuse CSV parsers
+ this.value = String.format(Locale.US, "%." + fractionDigits + "f", value);
+ }
+
+ /**
+ * Appends the textual form of each vote to {@code out} through StringUtils: the first vote with appendIndented,
+ * every following vote with appendNewlineIndented.
+ */
+ public static void getVotesDescription(Vote[] votes, StringBuilder out, int indent) {
+ if (votes.length > 0) {
+ StringUtils.appendIndented(out, indent, votes[0].toString());
+ for (int i = 1; i < votes.length; i++) {
+ StringUtils.appendNewlineIndented(out, indent, votes[i].toString());
+ }
+ }
+ }
+
+ /** Appends {@code name = value} to {@code sb}; {@code indent} is not used. */
+ public void getDescription(StringBuilder sb, int indent) {
+ sb.append(getName());
+ sb.append(" = ");
+ sb.append(this.value);
+ }
+
+ /**
+ * Returns the same {@code name = value} text that getDescription appends, so that getVotesDescription prints
+ * readable entries instead of the default object identity.
+ */
+ @Override
+ public String toString() {
+ StringBuilder sb = new StringBuilder();
+ getDescription(sb, 0);
+ return sb.toString();
+ }
+}
diff --git a/samoa-api/src/main/java/org/apache/samoa/moa/evaluation/LearningCurve.java b/samoa-api/src/main/java/org/apache/samoa/moa/evaluation/LearningCurve.java
index 427e01d..dcc4f50 100644
--- a/samoa-api/src/main/java/org/apache/samoa/moa/evaluation/LearningCurve.java
+++ b/samoa-api/src/main/java/org/apache/samoa/moa/evaluation/LearningCurve.java
@@ -27,6 +27,7 @@
import org.apache.samoa.moa.core.DoubleVector;
import org.apache.samoa.moa.core.Measurement;
import org.apache.samoa.moa.core.StringUtils;
+import org.apache.samoa.moa.core.Vote;
/**
* Class that stores and keeps the history of evaluation measurements.
@@ -40,8 +41,12 @@
protected List<String> measurementNames = new ArrayList<String>();
+ protected List<String> voteNames = new ArrayList<String>();
+
protected List<double[]> measurementValues = new ArrayList<double[]>();
+ protected List<String> voteValues = new ArrayList<String>();
+
public LearningCurve(String orderingMeasurementName) {
this.measurementNames.add(orderingMeasurementName);
}
@@ -129,4 +134,73 @@
public String getMeasurementName(int measurementIndex) {
return this.measurementNames.get(measurementIndex);
}
+
+ /** Returns the column index of the given vote name, registering the name as a new column if necessary. */
+ protected int addVoteName(String name) {
+ int index = this.voteNames.indexOf(name);
+ if (index < 0) {
+ index = this.voteNames.size();
+ this.voteNames.add(name);
+ }
+ return index;
+ }
+
+ /**
+ * Replaces the current vote entry with the given votes.
+ *
+ * <p>
+ * Each value is stored in the column of its vote name, so the entry stays aligned with the header even when votes
+ * arrive in a different order, repeat a name or introduce a new one; columns without a value are marked with "?".
+ */
+ public void setVote(Vote[] votes) {
+ this.voteValues.clear();
+ for (Vote vote : votes) {
+ int column = addVoteName(vote.getName());
+ while (this.voteValues.size() <= column) {
+ this.voteValues.add("?");
+ }
+ this.voteValues.set(column, vote.getValue());
+ }
+ }
+
+ /**
+ * Builds the header line of the text file containing predictions and votes.
+ *
+ * <p>
+ * The vote names registered so far are joined with commas, in the order in which they were registered, and are
+ * written as they are, without quoting or escaping.
+ *
+ * @return the comma-separated vote names; an empty string when no vote name has been registered
+ */
+ public String voteHeaderToString() {
+ StringBuilder header = new StringBuilder();
+ for (int i = 0; i < this.voteNames.size(); i++) {
+ if (i > 0) {
+ header.append(',');
+ }
+ header.append(this.voteNames.get(i));
+ }
+ return header.toString();
+ }
+
+ /**
+ * Builds one body line of the text file containing predictions and votes.
+ *
+ * @return the current vote values joined with commas, one per registered vote name; "?" stands for a missing value
+ */
+ public String voteEntryToString() {
+ StringBuilder line = new StringBuilder();
+ for (int column = 0; column < this.voteNames.size(); column++) {
+ if (column > 0) {
+ line.append(',');
+ }
+ // a column that has no value in the current entry is marked as missing
+ if (column < voteValues.size()) {
+ line.append(voteValues.get(column));
+ } else {
+ line.append('?');
+ }
+ }
+ return line.toString();
+ }
}
diff --git a/samoa-api/src/main/java/org/apache/samoa/tasks/PrequentialEvaluation.java b/samoa-api/src/main/java/org/apache/samoa/tasks/PrequentialEvaluation.java
index 001622b..dab505b 100644
--- a/samoa-api/src/main/java/org/apache/samoa/tasks/PrequentialEvaluation.java
+++ b/samoa-api/src/main/java/org/apache/samoa/tasks/PrequentialEvaluation.java
@@ -84,6 +84,12 @@
"How many instances between samples of the learning performance.", 100000,
0, Integer.MAX_VALUE);
+ // The frequency of saving model output e.g. predicted class and votes made for individual classes to a file
+ // The name of the actual file to which model output will be saved is defined through resultFileOption
+ public IntOption labelSampleFrequencyOption = new IntOption("labelSampleFrequency", 'h',
+ "How many instances between samples of predicted labels and votes.", 1,
+ 0, Integer.MAX_VALUE);
+
public StringOption evaluationNameOption = new StringOption("evaluationName", 'n', "Identifier of the evaluation",
"Prequential_"
+ new SimpleDateFormat("yyyyMMddHHmmss").format(new Date()));
@@ -91,6 +97,11 @@
public FileOption dumpFileOption = new FileOption("dumpFile", 'd', "File to append intermediate csv results to",
null, "csv", true);
+ // The name of the CSV file in which model output (and in the case of classification also votes for individual classes)
+ // will be saved
+ public FileOption resultFileOption = new FileOption("resultFile", 'g', "File to append intermediate model output to",
+ null, "csv", true);
+
// Default=0: no delay/waiting
public IntOption sourceDelayOption = new IntOption("sourceDelay", 'w',
"How many microseconds between injections of two instances.", 0, 0, Integer.MAX_VALUE);
@@ -167,7 +178,9 @@
evaluatorOptionValue = getDefaultPerformanceEvaluatorForLearner(classifier);
}
evaluator = new EvaluatorProcessor.Builder(evaluatorOptionValue)
- .samplingFrequency(sampleFrequencyOption.getValue()).dumpFile(dumpFileOption.getFile()).build();
+ .samplingFrequency(sampleFrequencyOption.getValue()).dumpFile(dumpFileOption.getFile())
+ .predictionFile(resultFileOption.getFile()).labelSamplingFrequency(labelSampleFrequencyOption.getValue())
+ .build();
// evaluatorPi = builder.createPi(evaluator);
// evaluatorPi.connectInputShuffleStream(evaluatorPiInputStream);
diff --git a/samoa-local/src/test/java/org/apache/samoa/AlgosTest.java b/samoa-local/src/test/java/org/apache/samoa/AlgosTest.java
index 52331c5..f621aba 100644
--- a/samoa-local/src/test/java/org/apache/samoa/AlgosTest.java
+++ b/samoa-local/src/test/java/org/apache/samoa/AlgosTest.java
@@ -22,6 +22,7 @@
import org.apache.samoa.LocalDoTask;
import org.junit.Test;
+import org.apache.samoa.TestParams;
public class AlgosTest {
@@ -32,6 +33,7 @@
.samplingSize(20_000)
.evaluationInstances(200_000)
.classifiedInstances(200_000)
+ .labelSamplingSize(10l)
.classificationsCorrect(75f)
.kappaStat(0f)
.kappaTempStat(0f)
@@ -50,6 +52,7 @@
.samplingSize(20_000)
.evaluationInstances(200_000)
.classifiedInstances(200_000)
+ .labelSamplingSize(1l)
.classificationsCorrect(60f)
.kappaStat(0f)
.kappaTempStat(0f)
@@ -68,6 +71,7 @@
.samplingSize(20_000)
.evaluationInstances(200_000)
.classifiedInstances(200_000)
+ .labelSamplingSize(10l)
.classificationsCorrect(65f)
.kappaStat(0f)
.kappaTempStat(0f)
@@ -93,6 +97,7 @@
.resultFilePollTimeout(10)
.prePollWait(10)
.taskClassName(LocalDoTask.class.getName())
+ .labelFileCreated(false)
.build();
TestUtils.test(vhtConfig);
}
diff --git a/samoa-storm/src/test/java/org/apache/samoa/AlgosTest.java b/samoa-storm/src/test/java/org/apache/samoa/AlgosTest.java
index d874e51..1c18eaf 100644
--- a/samoa-storm/src/test/java/org/apache/samoa/AlgosTest.java
+++ b/samoa-storm/src/test/java/org/apache/samoa/AlgosTest.java
@@ -35,9 +35,10 @@
.samplingSize(20_000)
.evaluationInstances(200_000)
.classifiedInstances(200_000)
+ .labelSamplingSize(10l)
.classificationsCorrect(55f)
- .kappaStat(0f)
- .kappaTempStat(0f)
+ .kappaStat(-0.1f)
+ .kappaTempStat(-0.1f)
.cliStringTemplate(TestParams.Templates.PREQEVAL_VHT_RANDOMTREE)
.resultFilePollTimeout(30)
.prePollWait(15)
@@ -54,6 +55,7 @@
.samplingSize(20_000)
.evaluationInstances(180_000)
.classifiedInstances(190_000)
+ .labelSamplingSize(10l)
.classificationsCorrect(60f)
.kappaStat(0f)
.kappaTempStat(0f)
@@ -70,18 +72,19 @@
public void testCVPReqVHTWithStorm() throws Exception {
TestParams vhtConfig = new TestParams.Builder()
- .inputInstances(200_000)
- .samplingSize(20_000)
- .evaluationInstances(200_000)
- .classifiedInstances(200_000)
- .classificationsCorrect(55f)
- .kappaStat(0f)
- .kappaTempStat(0f)
- .cliStringTemplate(TestParams.Templates.PREQCVEVAL_VHT_RANDOMTREE)
- .resultFilePollTimeout(30)
- .prePollWait(15)
- .taskClassName(LocalStormDoTask.class.getName())
- .build();
+ .inputInstances(200_000)
+ .samplingSize(20_000)
+ .evaluationInstances(200_000)
+ .classifiedInstances(200_000)
+ .classificationsCorrect(55f)
+ .kappaStat(0f)
+ .kappaTempStat(0f)
+ .cliStringTemplate(TestParams.Templates.PREQCVEVAL_VHT_RANDOMTREE)
+ .resultFilePollTimeout(30)
+ .prePollWait(15)
+ .taskClassName(LocalStormDoTask.class.getName())
+ .labelFileCreated(false)
+ .build();
TestUtils.test(vhtConfig);
}
diff --git a/samoa-test/src/test/java/org/apache/samoa/TestParams.java b/samoa-test/src/test/java/org/apache/samoa/TestParams.java
index b066959..eb7e123 100644
--- a/samoa-test/src/test/java/org/apache/samoa/TestParams.java
+++ b/samoa-test/src/test/java/org/apache/samoa/TestParams.java
@@ -1,5 +1,7 @@
package org.apache.samoa;
+import org.apache.samoa.TestParams.Builder;
+
/*
* #%L
* SAMOA
@@ -32,20 +34,19 @@
* </ul>
* as well as the maximum number of instances for testing/training (-i) and the sampling size (-f)
*/
- public static class Templates {
-
- public final static String PREQEVAL_VHT_RANDOMTREE = "PrequentialEvaluation -d %s -i %d -f %d -w %d "
+ public static class Templates {
+ public final static String PREQEVAL_VHT_RANDOMTREE = "PrequentialEvaluation -d %s -i %d -f %d -w %d -g %s -h %d "
+ "-l (org.apache.samoa.learners.classifiers.trees.VerticalHoeffdingTree -p 4) " +
"-s (org.apache.samoa.streams.generators.RandomTreeGenerator -c 2 -o 10 -u 10)";
- public final static String PREQEVAL_NAIVEBAYES_HYPERPLANE = "PrequentialEvaluation -d %s -i %d -f %d -w %d "
+ public final static String PREQEVAL_NAIVEBAYES_HYPERPLANE = "PrequentialEvaluation -d %s -i %d -f %d -w %d -g %s -h %d "
+ "-l (classifiers.SingleClassifier -l org.apache.samoa.learners.classifiers.NaiveBayes) " +
"-s (org.apache.samoa.streams.generators.HyperplaneGenerator -c 2)";
// setting the number of nominal attributes to zero significantly reduces
// the processing time,
// so that it's acceptable in a test case
- public final static String PREQEVAL_BAGGING_RANDOMTREE = "PrequentialEvaluation -d %s -i %d -f %d -w %d "
+ public final static String PREQEVAL_BAGGING_RANDOMTREE = "PrequentialEvaluation -d %s -i %d -f %d -w %d -g %s -h %d "
+ "-l (org.apache.samoa.learners.classifiers.ensemble.Bagging) " +
"-s (org.apache.samoa.streams.generators.RandomTreeGenerator -c 2 -o 0 -u 10)";
@@ -60,6 +61,11 @@
public static final String CLASSIFICATIONS_CORRECT = "classifications correct (percent)";
public static final String KAPPA_STAT = "Kappa Statistic (percent)";
public static final String KAPPA_TEMP_STAT = "Kappa Temporal Statistic (percent)";
+
+ public static final String INSTANCE_ID = "instance number";
+ public static final String TRUE_CLASS_VALUE = "true class value";
+ public static final String PREDICTED_CLASS_VALUE = "predicted class value";
+ public static final String VOTES = "votes";
private long inputInstances;
private long samplingSize;
@@ -73,6 +79,8 @@
private final int prePollWait;
private int inputDelayMicroSec;
private String taskClassName;
+ private boolean labelFileCreated;
+ private long labelSamplingSize;
private TestParams(String taskClassName,
long inputInstances,
@@ -85,7 +93,9 @@
String cliStringTemplate,
int pollTimeoutSeconds,
int prePollWait,
- int inputDelayMicroSec) {
+ int inputDelayMicroSec,
+ boolean labelFileCreated,
+ long labelSamplingSize) {
this.taskClassName = taskClassName;
this.inputInstances = inputInstances;
this.samplingSize = samplingSize;
@@ -98,6 +108,12 @@
this.pollTimeoutSeconds = pollTimeoutSeconds;
this.prePollWait = prePollWait;
this.inputDelayMicroSec = inputDelayMicroSec;
+ this.labelFileCreated = labelFileCreated;
+ this.labelSamplingSize = labelSamplingSize;
+ }
+
+ /**
+  * Returns the flag that tells whether a labels file is expected for this test; TestUtils only
+  * checks the labels file when this is {@code true}.
+  *
+  * @return the labelFileCreated flag passed to the constructor
+  */
+ public boolean getLabelFileCreated() {
+   return this.labelFileCreated;
}
public String getTaskClassName() {
@@ -147,6 +163,10 @@
public int getInputDelayMicroSec() {
return inputDelayMicroSec;
}
+
+ /**
+  * Returns the label sampling size, i.e. the value TestUtils substitutes for the {@code -h}
+  * argument of the CLI template.
+  *
+  * @return the labelSamplingSize passed to the constructor
+  */
+ public long getLabelSamplingSize() {
+   return this.labelSamplingSize;
+ }
@Override
public String toString() {
@@ -163,6 +183,8 @@
"prePollWait=" + prePollWait + "\n" +
"taskClassName='" + taskClassName + '\'' + "\n" +
"inputDelayMicroSec=" + inputDelayMicroSec + "\n" +
+ "labelFileCreated=" + labelFileCreated + "\n" +
+ "labelSamplingSize=" + labelSamplingSize + "\n" +
'}';
}
@@ -179,6 +201,8 @@
private int prePollWaitSeconds = 10;
private String taskClassName;
private int inputDelayMicroSec = 0;
+ private boolean labelFileCreated = true;
+ private long labelSamplingSize = 0l;
public Builder taskClassName(String taskClassName) {
this.taskClassName = taskClassName;
@@ -239,6 +263,16 @@
this.prePollWaitSeconds = prePollWaitSeconds;
return this;
}
+
+ /**
+  * Sets whether a labels file is expected for the test. The builder default is {@code true};
+  * TestUtils skips the labels-file check when the built value is {@code false}.
+  *
+  * @param labelFileCreated {@code true} if the labels file should be checked after the run
+  * @return this builder, for call chaining
+  */
+ public Builder labelFileCreated(boolean labelFileCreated) {
+ this.labelFileCreated = labelFileCreated;
+ return this;
+ }
+
+ /**
+  * Sets the label sampling size that TestUtils passes as the {@code -h} argument of the CLI template.
+  * The builder default is 0; TestUtils.assertLabels uses this value as a divisor when computing the
+  * expected number of lines, so set a positive value whenever the labels file is checked.
+  *
+  * @param labelSamplingSize number of instances between two sampled label lines
+  * @return this builder, for call chaining
+  */
+ public Builder labelSamplingSize(long labelSamplingSize) {
+ this.labelSamplingSize = labelSamplingSize;
+ return this;
+ }
public TestParams build() {
return new TestParams(taskClassName,
@@ -252,7 +286,9 @@
cliStringTemplate,
pollTimeoutSeconds,
prePollWaitSeconds,
- inputDelayMicroSec);
+ inputDelayMicroSec,
+ labelFileCreated,
+ labelSamplingSize);
}
}
}
diff --git a/samoa-test/src/test/java/org/apache/samoa/TestUtils.java b/samoa-test/src/test/java/org/apache/samoa/TestUtils.java
index 331f900..b5fef17 100644
--- a/samoa-test/src/test/java/org/apache/samoa/TestUtils.java
+++ b/samoa-test/src/test/java/org/apache/samoa/TestUtils.java
@@ -49,29 +49,29 @@
NoSuchMethodException, InvocationTargetException, IllegalAccessException, InterruptedException {
final File tempFile = File.createTempFile("test", "test");
-
- LOG.info("Starting test, output file is {}, test config is \n{}", tempFile.getAbsolutePath(), testParams.toString());
-
+ final File labelFile = File.createTempFile("result", "result");
+ LOG.info("Starting test, output file is {}, test config is \n{}", tempFile.getAbsolutePath(), testParams.toString());
Executors.newSingleThreadExecutor().submit(new Callable<Void>() {
-
@Override
public Void call() throws Exception {
- try {
+ try {
Class.forName(testParams.getTaskClassName())
- .getMethod("main", String[].class)
- .invoke(null, (Object) String.format(
- testParams.getCliStringTemplate(),
- tempFile.getAbsolutePath(),
- testParams.getInputInstances(),
- testParams.getSamplingSize(),
- testParams.getInputDelayMicroSec()
- ).split("[ ]"));
- } catch (Exception e) {
- LOG.error("Cannot execute test {} {}", e.getMessage(), e.getCause().getMessage());
+ .getMethod("main", String[].class)
+ .invoke(null, (Object) String.format(
+ testParams.getCliStringTemplate(),
+ tempFile.getAbsolutePath(),
+ testParams.getInputInstances(),
+ testParams.getSamplingSize(),
+ testParams.getInputDelayMicroSec(),
+ labelFile.getAbsolutePath(),
+ testParams.getLabelSamplingSize()
+ ).split("[ ]"));
+ } catch (Exception e) {
+ LOG.error("Cannot execute test {} {}", e.getMessage(), e.getCause().getMessage());
+ }
+ return null;
}
- return null;
- }
- });
+ });
Thread.sleep(TimeUnit.SECONDS.toMillis(testParams.getPrePollWaitSeconds()));
@@ -89,6 +89,8 @@
tailer.stop();
assertResults(tempFile, testParams);
+ if (testParams.getLabelFileCreated())
+ assertLabels(labelFile, testParams);
}
public static void assertResults(File outputFile, org.apache.samoa.TestParams testParams) throws IOException {
@@ -136,6 +138,32 @@
testParams.getKappaTempStat() <= Float.parseFloat(last.get(4 + 3 * cvEvaluation)));
}
+
+ /**
+  * Verifies the labels CSV file (predictions and votes) written during a test run.
+  * <p>
+  * The first record must be a header whose first three columns are the instance number, the true
+  * class value and the predicted class value; every further column name must start with
+  * {@code votes}. The number of records after the header must equal
+  * {@code inputInstances / labelSamplingSize}.
+  *
+  * @param labelFile  CSV file the task wrote its predictions to
+  * @param testParams parameters of the test run; {@code labelSamplingSize} must be positive
+  * @throws IOException if the labels file cannot be opened or read
+  */
+ public static void assertLabels(File labelFile, org.apache.samoa.TestParams testParams) throws IOException {
+   LOG.info("Checking labels file {}", labelFile.getAbsolutePath());
+
+   // The TestParams builder defaults labelSamplingSize to 0; fail with a clear message
+   // instead of an ArithmeticException from the division below.
+   long labelSamplingSize = testParams.getLabelSamplingSize();
+   Assert.assertTrue("labelSamplingSize must be positive to check the labels file", labelSamplingSize > 0);
+   long expectedLineCount = testParams.getInputInstances() / labelSamplingSize;
+
+   // try-with-resources: the file handle is released even when an assertion fails
+   try (Reader in = new FileReader(labelFile)) {
+     // 1. parse result file with csv parser
+     Iterable<CSVRecord> records = CSVFormat.EXCEL.withSkipHeaderRecord(false)
+         .withIgnoreEmptyLines(true).withDelimiter(',').withCommentMarker('#').parse(in);
+     Iterator<CSVRecord> iterator = records.iterator();
+
+     // 2. check the header record
+     Assert.assertTrue("Labels file is empty, expected a header line", iterator.hasNext());
+     CSVRecord header = iterator.next();
+     Assert.assertTrue("Labels file header has too few columns: " + header.size(), header.size() >= 3);
+     Assert.assertEquals("Unexpected column", org.apache.samoa.TestParams.INSTANCE_ID, header.get(0).trim());
+     Assert.assertEquals("Unexpected column", org.apache.samoa.TestParams.TRUE_CLASS_VALUE, header.get(1).trim());
+     Assert.assertEquals("Unexpected column", org.apache.samoa.TestParams.PREDICTED_CLASS_VALUE, header.get(2).trim());
+     for (int i = 3; i < header.size(); i++) {
+       // startsWith cannot throw on a column name shorter than the expected prefix
+       String column = header.get(i).trim();
+       Assert.assertTrue("Unexpected column " + column, column.startsWith(org.apache.samoa.TestParams.VOTES));
+     }
+
+     // 3. count the body records
+     long lineCount = 0;
+     while (iterator.hasNext()) {
+       iterator.next();
+       lineCount++;
+     }
+     Assert.assertEquals("Wrong number of lines in prediction file", expectedLineCount, lineCount);
+   }
+ }
private static class TestResultsTailerAdapter extends TailerListenerAdapter {
diff --git a/samoa-threads/src/test/java/org/apache/samoa/AlgosTest.java b/samoa-threads/src/test/java/org/apache/samoa/AlgosTest.java
index 031d98d..f43d667 100644
--- a/samoa-threads/src/test/java/org/apache/samoa/AlgosTest.java
+++ b/samoa-threads/src/test/java/org/apache/samoa/AlgosTest.java
@@ -35,6 +35,7 @@
.samplingSize(20_000)
.evaluationInstances(200_000)
.classifiedInstances(200_000)
+ .labelSamplingSize(10l)
.classificationsCorrect(55f)
.kappaStat(-0.1f)
.kappaTempStat(-0.1f)
@@ -55,6 +56,7 @@
.inputDelayMicroSec(100) // prevents saturating the system due to unbounded queues
.evaluationInstances(90_000)
.classifiedInstances(100_000)
+ .labelSamplingSize(10l)
.classificationsCorrect(55f)
.kappaStat(0f)
.kappaTempStat(0f)