SAMOA-34: Fix AdaptiveBagging
diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/AdaptiveLearner.java b/samoa-api/src/main/java/org/apache/samoa/learners/AdaptiveLearner.java
index 28d0059..54af7b6 100644
--- a/samoa-api/src/main/java/org/apache/samoa/learners/AdaptiveLearner.java
+++ b/samoa-api/src/main/java/org/apache/samoa/learners/AdaptiveLearner.java
@@ -25,14 +25,13 @@
*/
import org.apache.samoa.moa.classifiers.core.driftdetection.ChangeDetector;
-import org.apache.samoa.topology.Stream;
/**
* The Interface Adaptive Learner. Initializing Classifier should initialize PI to connect the Classifier with the input
* stream and initialize result stream so that other PI can connect to the classification result of this classifier
*/
-public interface AdaptiveLearner {
+public interface AdaptiveLearner extends Learner{
/**
* Gets the change detector item.
diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/AdaptiveBagging.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/AdaptiveBagging.java
index 9ffba2a..4b2c531 100644
--- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/AdaptiveBagging.java
+++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/AdaptiveBagging.java
@@ -45,19 +45,16 @@
import com.github.javacliparser.IntOption;
/**
- * The Bagging Classifier by Oza and Russell.
+ * An adaptive version of the Bagging Classifier by Oza and Russell.
*/
public class AdaptiveBagging implements Learner, Configurable {
- /** Logger */
+ private static final long serialVersionUID = 8217274236558839040L;
private static final Logger logger = LoggerFactory.getLogger(AdaptiveBagging.class);
- /** The Constant serialVersionUID. */
- private static final long serialVersionUID = -2971850264864952099L;
-
/** The base learner option. */
public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l',
- "Classifier to train.", Learner.class, VerticalHoeffdingTree.class.getName());
+ "Classifier to train.", AdaptiveLearner.class, VerticalHoeffdingTree.class.getName());
/** The ensemble size option. */
public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's',
@@ -69,59 +66,63 @@
/** The distributor processor. */
private BaggingDistributorProcessor distributorP;
+ /** The input streams for the ensemble, one per member. */
+ private Stream[] ensembleStreams;
+
/** The result stream. */
protected Stream resultStream;
/** The dataset. */
private Instances dataset;
- protected Learner classifier;
-
- protected int parallelism;
+ protected AdaptiveLearner[] ensemble;
/**
* Sets the layout.
*/
protected void setLayout() {
-
- int sizeEnsemble = this.ensembleSizeOption.getValue();
+ int ensembleSize = this.ensembleSizeOption.getValue();
distributorP = new BaggingDistributorProcessor();
- distributorP.setSizeEnsemble(sizeEnsemble);
- this.builder.addProcessor(distributorP, 1);
+ distributorP.setEnsembleSize(ensembleSize);
+ builder.addProcessor(distributorP, 1);
// instantiate classifier
- classifier = this.baseLearnerOption.getValue();
- if (classifier instanceof AdaptiveLearner) {
- // logger.info("Building an AdaptiveLearner {}",
- // classifier.getClass().getName());
- AdaptiveLearner ada = (AdaptiveLearner) classifier;
- ada.setChangeDetector((ChangeDetector) this.driftDetectionMethodOption.getValue());
+ ensemble = new AdaptiveLearner[ensembleSize];
+ for (int i = 0; i < ensembleSize; i++) {
+ try {
+ ensemble[i] = (AdaptiveLearner) ClassOption.createObject(baseLearnerOption.getValueAsCLIString(),
+ baseLearnerOption.getRequiredType());
+ } catch (Exception e) {
+ logger.error("Unable to create members of the ensemble. Please check your CLI parameters");
+ e.printStackTrace();
+ throw new IllegalArgumentException(e);
+ }
+ ensemble[i].setChangeDetector((ChangeDetector) this.driftDetectionMethodOption.getValue());
+ ensemble[i].init(builder, this.dataset, 1); // sequential
}
- classifier.init(builder, this.dataset, sizeEnsemble);
PredictionCombinerProcessor predictionCombinerP = new PredictionCombinerProcessor();
- predictionCombinerP.setSizeEnsemble(sizeEnsemble);
+ predictionCombinerP.setEnsembleSize(ensembleSize);
this.builder.addProcessor(predictionCombinerP, 1);
// Streams
- resultStream = this.builder.createStream(predictionCombinerP);
+ resultStream = builder.createStream(predictionCombinerP);
predictionCombinerP.setOutputStream(resultStream);
- for (Stream subResultStream : classifier.getResultStreams()) {
- this.builder.connectInputKeyStream(subResultStream, predictionCombinerP);
+ for (AdaptiveLearner member : ensemble) {
+ for (Stream subResultStream : member.getResultStreams()) { // a learner can have multiple output streams
+ this.builder.connectInputKeyStream(subResultStream, predictionCombinerP); // the key is the instance id to combine predictions
+ }
}
- /* The training stream. */
- Stream testingStream = this.builder.createStream(distributorP);
- this.builder.connectInputKeyStream(testingStream, classifier.getInputProcessor());
+ ensembleStreams = new Stream[ensembleSize];
+ for (int i = 0; i < ensembleSize; i++) {
+ ensembleStreams[i] = builder.createStream(distributorP);
+ builder.connectInputShuffleStream(ensembleStreams[i], ensemble[i].getInputProcessor()); // connect streams one-to-one with ensemble members (the type of connection does not matter)
+ }
- /* The prediction stream. */
- Stream predictionStream = this.builder.createStream(distributorP);
- this.builder.connectInputKeyStream(predictionStream, classifier.getInputProcessor());
-
- distributorP.setOutputStream(testingStream);
- distributorP.setPredictionStream(predictionStream);
+ distributorP.setOutputStreams(ensembleStreams);
}
/** The builder. */
@@ -131,7 +132,6 @@
public void init(TopologyBuilder builder, Instances dataset, int parallelism) {
this.builder = builder;
this.dataset = dataset;
- this.parallelism = parallelism;
this.setLayout();
}
diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java
index 43bc07c..5d7bbfc 100644
--- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java
+++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java
@@ -71,8 +71,6 @@
protected Learner[] ensemble;
- protected int parallelism;
-
/**
* Sets the layout.
*
@@ -129,7 +127,6 @@
public void init(TopologyBuilder builder, Instances dataset, int parallelism) {
this.builder = builder;
this.dataset = dataset;
- this.parallelism = parallelism;
this.setLayout();
}