SAMOA-34: Fix Bagging by giving each ensemble member its own learner instance and input stream
diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java
index 7355b1a..43bc07c 100644
--- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java
+++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java
@@ -24,8 +24,6 @@
* License
*/
-import com.google.common.collect.ImmutableSet;
-
import java.util.Set;
import org.apache.samoa.core.Processor;
@@ -34,10 +32,13 @@
import org.apache.samoa.learners.classifiers.trees.VerticalHoeffdingTree;
import org.apache.samoa.topology.Stream;
import org.apache.samoa.topology.TopologyBuilder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import com.github.javacliparser.ClassOption;
import com.github.javacliparser.Configurable;
import com.github.javacliparser.IntOption;
+import com.google.common.collect.ImmutableSet;
/**
* The Bagging Classifier by Oza and Russell.
@@ -46,6 +47,7 @@
/** The Constant serialVersionUID. */
private static final long serialVersionUID = -2971850264864952099L;
+ private static final Logger logger = LoggerFactory.getLogger(Bagging.class);
/** The base learner option. */
public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l',
@@ -58,11 +60,8 @@
/** The distributor processor. */
private BaggingDistributorProcessor distributorP;
- /** The training stream. */
- private Stream testingStream;
-
- /** The prediction stream. */
- private Stream predictionStream;
+ /** The input streams for the ensemble, one per member. */
+ private Stream[] ensembleStreams;
/** The result stream. */
protected Stream resultStream;
@@ -70,45 +69,57 @@
/** The dataset. */
private Instances dataset;
- protected Learner classifier;
+ protected Learner[] ensemble;
protected int parallelism;
/**
* Sets the layout.
+ *
+ * @throws IllegalArgumentException if a member of the ensemble cannot be created from the base learner CLI string
*/
protected void setLayout() {
-
- int sizeEnsemble = this.ensembleSizeOption.getValue();
+ int ensembleSize = this.ensembleSizeOption.getValue();
distributorP = new BaggingDistributorProcessor();
- distributorP.setSizeEnsemble(sizeEnsemble);
- this.builder.addProcessor(distributorP, 1);
+ distributorP.setEnsembleSize(ensembleSize);
+ builder.addProcessor(distributorP, 1);
// instantiate classifier
- classifier = (Learner) this.baseLearnerOption.getValue();
- classifier.init(builder, this.dataset, sizeEnsemble);
+ ensemble = new Learner[ensembleSize];
+ for (int i = 0; i < ensembleSize; i++) {
+ try {
+ ensemble[i] = (Learner) ClassOption.createObject(baseLearnerOption.getValueAsCLIString(),
+ baseLearnerOption.getRequiredType());
+ } catch (Exception e) {
+ logger.error("Unable to create members of the ensemble. Please check your CLI parameters");
+ e.printStackTrace();
+ throw new IllegalArgumentException(e);
+ }
+ ensemble[i].init(builder, this.dataset, 1); // sequential
+ }
PredictionCombinerProcessor predictionCombinerP = new PredictionCombinerProcessor();
- predictionCombinerP.setSizeEnsemble(sizeEnsemble);
+ predictionCombinerP.setEnsembleSize(ensembleSize);
this.builder.addProcessor(predictionCombinerP, 1);
// Streams
- resultStream = this.builder.createStream(predictionCombinerP);
+ resultStream = builder.createStream(predictionCombinerP);
predictionCombinerP.setOutputStream(resultStream);
- for (Stream subResultStream : classifier.getResultStreams()) {
- this.builder.connectInputKeyStream(subResultStream, predictionCombinerP);
+ for (Learner member : ensemble) {
+ for (Stream subResultStream : member.getResultStreams()) { // a learner can have multiple output streams
+ this.builder.connectInputKeyStream(subResultStream, predictionCombinerP); // the key is the instance id to combine predictions
+ }
}
- testingStream = this.builder.createStream(distributorP);
- this.builder.connectInputKeyStream(testingStream, classifier.getInputProcessor());
+ ensembleStreams = new Stream[ensembleSize];
+ for (int i = 0; i < ensembleSize; i++) {
+ ensembleStreams[i] = builder.createStream(distributorP);
+ builder.connectInputShuffleStream(ensembleStreams[i], ensemble[i].getInputProcessor()); // connect streams one-to-one with ensemble members (the type of connection does not matter)
+ }
- predictionStream = this.builder.createStream(distributorP);
- this.builder.connectInputKeyStream(predictionStream, classifier.getInputProcessor());
-
- distributorP.setOutputStream(testingStream);
- distributorP.setPredictionStream(predictionStream);
+ distributorP.setOutputStreams(ensembleStreams);
}
/** The builder. */
diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/BaggingDistributorProcessor.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/BaggingDistributorProcessor.java
index 33615db..6c88d94 100644
--- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/BaggingDistributorProcessor.java
+++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/BaggingDistributorProcessor.java
@@ -24,6 +24,7 @@
* License
*/
+import java.util.Arrays;
import java.util.Random;
import org.apache.samoa.core.ContentEvent;
@@ -38,19 +39,16 @@
*/
public class BaggingDistributorProcessor implements Processor {
- /**
- *
- */
private static final long serialVersionUID = -1550901409625192730L;
- /** The size ensemble. */
- private int sizeEnsemble;
+ /** The ensemble size. */
+ private int ensembleSize;
- /** The training stream. */
- private Stream trainingStream;
+ /** The output streams, one per ensemble member. */
+ private Stream[] ensembleStreams;
- /** The prediction stream. */
- private Stream predictionStream;
+ /** Random number generator. */
+ protected Random random = new Random(); //TODO make random seed configurable
/**
* On event.
@@ -60,38 +58,34 @@
* @return true, if successful
*/
public boolean process(ContentEvent event) {
- InstanceContentEvent inEvent = (InstanceContentEvent) event; // ((s4Event)event).getContentEvent();
- // InstanceEvent inEvent = (InstanceEvent) event;
+ InstanceContentEvent inEvent = (InstanceContentEvent) event;
if (inEvent.getInstanceIndex() < 0) {
- // End learning
- predictionStream.put(event);
+ // end learning
+ for (Stream stream : ensembleStreams)
+ stream.put(event);
return false;
}
if (inEvent.isTesting()) {
- Instance trainInst = inEvent.getInstance();
- for (int i = 0; i < sizeEnsemble; i++) {
- Instance weightedInst = trainInst.copy();
- // weightedInst.setWeight(trainInst.weight() * k);
- InstanceContentEvent instanceContentEvent = new InstanceContentEvent(
- inEvent.getInstanceIndex(), weightedInst, false, true);
- instanceContentEvent.setClassifierIndex(i);
- instanceContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex());
- predictionStream.put(instanceContentEvent);
+ Instance testInstance = inEvent.getInstance();
+ for (int i = 0; i < ensembleSize; i++) {
+ Instance instanceCopy = testInstance.copy();
+ InstanceContentEvent instanceContentEvent = new InstanceContentEvent(inEvent.getInstanceIndex(), instanceCopy,
+ false, true);
+ instanceContentEvent.setClassifierIndex(i); //TODO probably not needed anymore
+ instanceContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex()); //TODO probably not needed anymore
+ ensembleStreams[i].put(instanceContentEvent);
}
}
- /* Estimate model parameters using the training data. */
+ // estimate model parameters using the training data
if (inEvent.isTraining()) {
train(inEvent);
}
- return false;
+ return true;
}
- /** The random. */
- protected Random random = new Random();
-
/**
* Train.
*
@@ -99,104 +93,51 @@
* the in event
*/
protected void train(InstanceContentEvent inEvent) {
- Instance trainInst = inEvent.getInstance();
- for (int i = 0; i < sizeEnsemble; i++) {
+ Instance trainInstance = inEvent.getInstance();
+ for (int i = 0; i < ensembleSize; i++) {
int k = MiscUtils.poisson(1.0, this.random);
if (k > 0) {
- Instance weightedInst = trainInst.copy();
- weightedInst.setWeight(trainInst.weight() * k);
- InstanceContentEvent instanceContentEvent = new InstanceContentEvent(
- inEvent.getInstanceIndex(), weightedInst, true, false);
+ Instance weightedInstance = trainInstance.copy();
+ weightedInstance.setWeight(trainInstance.weight() * k);
+ InstanceContentEvent instanceContentEvent = new InstanceContentEvent(inEvent.getInstanceIndex(),
+ weightedInstance, true, false);
instanceContentEvent.setClassifierIndex(i);
instanceContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex());
- trainingStream.put(instanceContentEvent);
+ ensembleStreams[i].put(instanceContentEvent);
}
}
}
- /*
- * (non-Javadoc)
- *
- * @see org.apache.s4.core.ProcessingElement#onCreate()
- */
@Override
public void onCreate(int id) {
// do nothing
}
- /**
- * Gets the training stream.
- *
- * @return the training stream
- */
- public Stream getTrainingStream() {
- return trainingStream;
+ public Stream[] getOutputStreams() {
+ return ensembleStreams;
}
- /**
- * Sets the training stream.
- *
- * @param trainingStream
- * the new training stream
- */
- public void setOutputStream(Stream trainingStream) {
- this.trainingStream = trainingStream;
+ public void setOutputStreams(Stream[] ensembleStreams) {
+ this.ensembleStreams = ensembleStreams;
}
- /**
- * Gets the prediction stream.
- *
- * @return the prediction stream
- */
- public Stream getPredictionStream() {
- return predictionStream;
+ public int getEnsembleSize() {
+ return ensembleSize;
}
- /**
- * Sets the prediction stream.
- *
- * @param predictionStream
- * the new prediction stream
- */
- public void setPredictionStream(Stream predictionStream) {
- this.predictionStream = predictionStream;
+ public void setEnsembleSize(int ensembleSize) {
+ this.ensembleSize = ensembleSize;
}
- /**
- * Gets the size ensemble.
- *
- * @return the size ensemble
- */
- public int getSizeEnsemble() {
- return sizeEnsemble;
- }
-
- /**
- * Sets the size ensemble.
- *
- * @param sizeEnsemble
- * the new size ensemble
- */
- public void setSizeEnsemble(int sizeEnsemble) {
- this.sizeEnsemble = sizeEnsemble;
- }
-
- /*
- * (non-Javadoc)
- *
- * @see samoa.core.Processor#newProcessor(samoa.core.Processor)
- */
@Override
public Processor newProcessor(Processor sourceProcessor) {
BaggingDistributorProcessor newProcessor = new BaggingDistributorProcessor();
BaggingDistributorProcessor originProcessor = (BaggingDistributorProcessor) sourceProcessor;
- if (originProcessor.getPredictionStream() != null) {
- newProcessor.setPredictionStream(originProcessor.getPredictionStream());
+ if (originProcessor.getOutputStreams() != null) {
+ newProcessor.setOutputStreams(Arrays.copyOf(originProcessor.getOutputStreams(),
+ originProcessor.getOutputStreams().length));
}
- if (originProcessor.getTrainingStream() != null) {
- newProcessor.setOutputStream(originProcessor.getTrainingStream());
- }
- newProcessor.setSizeEnsemble(originProcessor.getSizeEnsemble());
+ newProcessor.setEnsembleSize(originProcessor.getEnsembleSize());
/*
* if (originProcessor.getLearningCurve() != null){
* newProcessor.setLearningCurve((LearningCurve)
@@ -204,5 +145,4 @@
*/
return newProcessor;
}
-
}
diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/PredictionCombinerProcessor.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/PredictionCombinerProcessor.java
index 2e5f335..76e84f8 100644
--- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/PredictionCombinerProcessor.java
+++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/PredictionCombinerProcessor.java
@@ -33,14 +33,14 @@
import org.apache.samoa.topology.Stream;
/**
- * The Class PredictionCombinerProcessor.
+ * Combines predictions coming from an ensemble. Equivalent to a majority-vote classifier.
*/
public class PredictionCombinerProcessor implements Processor {
private static final long serialVersionUID = -1606045723451191132L;
/**
- * The size ensemble.
+ * The ensemble size.
*/
protected int ensembleSize;
@@ -73,7 +73,7 @@
*
* @return the ensembleSize
*/
- public int getSizeEnsemble() {
+ public int getEnsembleSize() {
return ensembleSize;
}
@@ -83,7 +83,7 @@
* @param ensembleSize
* the new size ensemble
*/
- public void setSizeEnsemble(int ensembleSize) {
+ public void setEnsembleSize(int ensembleSize) {
this.ensembleSize = ensembleSize;
}
@@ -143,7 +143,7 @@
if (originProcessor.getOutputStream() != null) {
newProcessor.setOutputStream(originProcessor.getOutputStream());
}
- newProcessor.setSizeEnsemble(originProcessor.getSizeEnsemble());
+ newProcessor.setEnsembleSize(originProcessor.getEnsembleSize());
return newProcessor;
}
diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/VerticalHoeffdingTree.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/VerticalHoeffdingTree.java
index ea7e53d..6534cee 100644
--- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/VerticalHoeffdingTree.java
+++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/VerticalHoeffdingTree.java
@@ -20,8 +20,6 @@
* #L%
*/
-import com.google.common.collect.ImmutableSet;
-
import java.util.Set;
import org.apache.samoa.core.Processor;
@@ -41,6 +39,7 @@
import com.github.javacliparser.FlagOption;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
+import com.google.common.collect.ImmutableSet;
/**
* Vertical Hoeffding Tree.
@@ -172,7 +171,7 @@
public void setChangeDetector(ChangeDetector cd) {
this.changeDetector = cd;
}
-
+
static class LearningNodeIdGenerator {
// TODO: add code to warn user of when value reaches Long.MAX_VALUES