SAMOA-34: Fix Bagging
diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java
index 7355b1a..43bc07c 100644
--- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java
+++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java
@@ -24,8 +24,6 @@
  * License
  */
 
-import com.google.common.collect.ImmutableSet;
-
 import java.util.Set;
 
 import org.apache.samoa.core.Processor;
@@ -34,10 +32,13 @@
 import org.apache.samoa.learners.classifiers.trees.VerticalHoeffdingTree;
 import org.apache.samoa.topology.Stream;
 import org.apache.samoa.topology.TopologyBuilder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import com.github.javacliparser.ClassOption;
 import com.github.javacliparser.Configurable;
 import com.github.javacliparser.IntOption;
+import com.google.common.collect.ImmutableSet;
 
 /**
  * The Bagging Classifier by Oza and Russell.
@@ -46,6 +47,7 @@
 
   /** The Constant serialVersionUID. */
   private static final long serialVersionUID = -2971850264864952099L;
+  private static final Logger logger = LoggerFactory.getLogger(Bagging.class);
 
   /** The base learner option. */
   public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l',
@@ -58,11 +60,8 @@
   /** The distributor processor. */
   private BaggingDistributorProcessor distributorP;
 
-  /** The training stream. */
-  private Stream testingStream;
-
-  /** The prediction stream. */
-  private Stream predictionStream;
+  /** The input streams for the ensemble, one per member. */
+  private Stream[] ensembleStreams;
 
   /** The result stream. */
   protected Stream resultStream;
@@ -70,45 +69,57 @@
   /** The dataset. */
   private Instances dataset;
 
-  protected Learner classifier;
+  protected Learner[] ensemble;
 
   protected int parallelism;
 
   /**
    * Sets the layout.
+   * 
+   * @throws IllegalArgumentException if an ensemble member cannot be created from the base learner option
    */
   protected void setLayout() {
-
-    int sizeEnsemble = this.ensembleSizeOption.getValue();
+    int ensembleSize = this.ensembleSizeOption.getValue();
 
     distributorP = new BaggingDistributorProcessor();
-    distributorP.setSizeEnsemble(sizeEnsemble);
-    this.builder.addProcessor(distributorP, 1);
+    distributorP.setEnsembleSize(ensembleSize);
+    builder.addProcessor(distributorP, 1);
 
     // instantiate classifier
-    classifier = (Learner) this.baseLearnerOption.getValue();
-    classifier.init(builder, this.dataset, sizeEnsemble);
+    ensemble = new Learner[ensembleSize];
+    for (int i = 0; i < ensembleSize; i++) {
+      try {
+        ensemble[i] = (Learner) ClassOption.createObject(baseLearnerOption.getValueAsCLIString(),
+            baseLearnerOption.getRequiredType());
+      } catch (Exception e) {
+        logger.error("Unable to create members of the ensemble. Please check your CLI parameters");
+        e.printStackTrace();
+        throw new IllegalArgumentException(e);
+      }
+      ensemble[i].init(builder, this.dataset, 1); // sequential
+    }
 
     PredictionCombinerProcessor predictionCombinerP = new PredictionCombinerProcessor();
-    predictionCombinerP.setSizeEnsemble(sizeEnsemble);
+    predictionCombinerP.setEnsembleSize(ensembleSize);
     this.builder.addProcessor(predictionCombinerP, 1);
 
     // Streams
-    resultStream = this.builder.createStream(predictionCombinerP);
+    resultStream = builder.createStream(predictionCombinerP);
     predictionCombinerP.setOutputStream(resultStream);
 
-    for (Stream subResultStream : classifier.getResultStreams()) {
-      this.builder.connectInputKeyStream(subResultStream, predictionCombinerP);
+    for (Learner member : ensemble) {
+      for (Stream subResultStream : member.getResultStreams()) { // a learner can have multiple output streams
+        this.builder.connectInputKeyStream(subResultStream, predictionCombinerP); // the key is the instance id to combine predictions
+      }
     }
 
-    testingStream = this.builder.createStream(distributorP);
-    this.builder.connectInputKeyStream(testingStream, classifier.getInputProcessor());
+    ensembleStreams = new Stream[ensembleSize];
+    for (int i = 0; i < ensembleSize; i++) {
+      ensembleStreams[i] = builder.createStream(distributorP);
+      builder.connectInputShuffleStream(ensembleStreams[i], ensemble[i].getInputProcessor()); // connect streams one-to-one with ensemble members (the type of connection does not matter)
+    }
 
-    predictionStream = this.builder.createStream(distributorP);
-    this.builder.connectInputKeyStream(predictionStream, classifier.getInputProcessor());
-
-    distributorP.setOutputStream(testingStream);
-    distributorP.setPredictionStream(predictionStream);
+    distributorP.setOutputStreams(ensembleStreams);
   }
 
   /** The builder. */
diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/BaggingDistributorProcessor.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/BaggingDistributorProcessor.java
index 33615db..6c88d94 100644
--- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/BaggingDistributorProcessor.java
+++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/BaggingDistributorProcessor.java
@@ -24,6 +24,7 @@
  * License
  */
 
+import java.util.Arrays;
 import java.util.Random;
 
 import org.apache.samoa.core.ContentEvent;
@@ -38,19 +39,16 @@
  */
 public class BaggingDistributorProcessor implements Processor {
 
-  /**
-	 * 
-	 */
   private static final long serialVersionUID = -1550901409625192730L;
 
-  /** The size ensemble. */
-  private int sizeEnsemble;
+  /** The ensemble size. */
+  private int ensembleSize;
 
-  /** The training stream. */
-  private Stream trainingStream;
+  /** The output streams, one per ensemble member. */
+  private Stream[] ensembleStreams;
 
-  /** The prediction stream. */
-  private Stream predictionStream;
+  /** Random number generator. */
+  protected Random random = new Random(); //TODO make random seed configurable
 
   /**
    * On event.
@@ -60,38 +58,34 @@
    * @return true, if successful
    */
   public boolean process(ContentEvent event) {
-    InstanceContentEvent inEvent = (InstanceContentEvent) event; // ((s4Event)event).getContentEvent();
-    // InstanceEvent inEvent = (InstanceEvent) event;
+    InstanceContentEvent inEvent = (InstanceContentEvent) event;
 
     if (inEvent.getInstanceIndex() < 0) {
-      // End learning
-      predictionStream.put(event);
+      // end learning
+      for (Stream stream : ensembleStreams)
+        stream.put(event);
       return false;
     }
 
     if (inEvent.isTesting()) {
-      Instance trainInst = inEvent.getInstance();
-      for (int i = 0; i < sizeEnsemble; i++) {
-        Instance weightedInst = trainInst.copy();
-        // weightedInst.setWeight(trainInst.weight() * k);
-        InstanceContentEvent instanceContentEvent = new InstanceContentEvent(
-            inEvent.getInstanceIndex(), weightedInst, false, true);
-        instanceContentEvent.setClassifierIndex(i);
-        instanceContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex());
-        predictionStream.put(instanceContentEvent);
+      Instance testInstance = inEvent.getInstance();
+      for (int i = 0; i < ensembleSize; i++) {
+        Instance instanceCopy = testInstance.copy();
+        InstanceContentEvent instanceContentEvent = new InstanceContentEvent(inEvent.getInstanceIndex(), instanceCopy,
+            false, true);
+        instanceContentEvent.setClassifierIndex(i); //TODO probably not needed anymore
+        instanceContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex()); //TODO probably not needed anymore
+        ensembleStreams[i].put(instanceContentEvent);
       }
     }
 
-    /* Estimate model parameters using the training data. */
+    // estimate model parameters using the training data
     if (inEvent.isTraining()) {
       train(inEvent);
     }
-    return false;
+    return true;
   }
 
-  /** The random. */
-  protected Random random = new Random();
-
   /**
    * Train.
    * 
@@ -99,104 +93,51 @@
    *          the in event
    */
   protected void train(InstanceContentEvent inEvent) {
-    Instance trainInst = inEvent.getInstance();
-    for (int i = 0; i < sizeEnsemble; i++) {
+    Instance trainInstance = inEvent.getInstance();
+    for (int i = 0; i < ensembleSize; i++) {
       int k = MiscUtils.poisson(1.0, this.random);
       if (k > 0) {
-        Instance weightedInst = trainInst.copy();
-        weightedInst.setWeight(trainInst.weight() * k);
-        InstanceContentEvent instanceContentEvent = new InstanceContentEvent(
-            inEvent.getInstanceIndex(), weightedInst, true, false);
+        Instance weightedInstance = trainInstance.copy();
+        weightedInstance.setWeight(trainInstance.weight() * k);
+        InstanceContentEvent instanceContentEvent = new InstanceContentEvent(inEvent.getInstanceIndex(),
+            weightedInstance, true, false);
         instanceContentEvent.setClassifierIndex(i);
         instanceContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex());
-        trainingStream.put(instanceContentEvent);
+        ensembleStreams[i].put(instanceContentEvent);
       }
     }
   }
 
-  /*
-   * (non-Javadoc)
-   * 
-   * @see org.apache.s4.core.ProcessingElement#onCreate()
-   */
   @Override
   public void onCreate(int id) {
     // do nothing
   }
 
-  /**
-   * Gets the training stream.
-   * 
-   * @return the training stream
-   */
-  public Stream getTrainingStream() {
-    return trainingStream;
+  public Stream[] getOutputStreams() {
+    return ensembleStreams;
   }
 
-  /**
-   * Sets the training stream.
-   * 
-   * @param trainingStream
-   *          the new training stream
-   */
-  public void setOutputStream(Stream trainingStream) {
-    this.trainingStream = trainingStream;
+  public void setOutputStreams(Stream[] ensembleStreams) {
+    this.ensembleStreams = ensembleStreams;
   }
 
-  /**
-   * Gets the prediction stream.
-   * 
-   * @return the prediction stream
-   */
-  public Stream getPredictionStream() {
-    return predictionStream;
+  public int getEnsembleSize() {
+    return ensembleSize;
   }
 
-  /**
-   * Sets the prediction stream.
-   * 
-   * @param predictionStream
-   *          the new prediction stream
-   */
-  public void setPredictionStream(Stream predictionStream) {
-    this.predictionStream = predictionStream;
+  public void setEnsembleSize(int ensembleSize) {
+    this.ensembleSize = ensembleSize;
   }
 
-  /**
-   * Gets the size ensemble.
-   * 
-   * @return the size ensemble
-   */
-  public int getSizeEnsemble() {
-    return sizeEnsemble;
-  }
-
-  /**
-   * Sets the size ensemble.
-   * 
-   * @param sizeEnsemble
-   *          the new size ensemble
-   */
-  public void setSizeEnsemble(int sizeEnsemble) {
-    this.sizeEnsemble = sizeEnsemble;
-  }
-
-  /*
-   * (non-Javadoc)
-   * 
-   * @see samoa.core.Processor#newProcessor(samoa.core.Processor)
-   */
   @Override
   public Processor newProcessor(Processor sourceProcessor) {
     BaggingDistributorProcessor newProcessor = new BaggingDistributorProcessor();
     BaggingDistributorProcessor originProcessor = (BaggingDistributorProcessor) sourceProcessor;
-    if (originProcessor.getPredictionStream() != null) {
-      newProcessor.setPredictionStream(originProcessor.getPredictionStream());
+    if (originProcessor.getOutputStreams() != null) {
+      newProcessor.setOutputStreams(Arrays.copyOf(originProcessor.getOutputStreams(),
+          originProcessor.getOutputStreams().length));
     }
-    if (originProcessor.getTrainingStream() != null) {
-      newProcessor.setOutputStream(originProcessor.getTrainingStream());
-    }
-    newProcessor.setSizeEnsemble(originProcessor.getSizeEnsemble());
+    newProcessor.setEnsembleSize(originProcessor.getEnsembleSize());
     /*
      * if (originProcessor.getLearningCurve() != null){
      * newProcessor.setLearningCurve((LearningCurve)
@@ -204,5 +145,4 @@
      */
     return newProcessor;
   }
-
 }
diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/PredictionCombinerProcessor.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/PredictionCombinerProcessor.java
index 2e5f335..76e84f8 100644
--- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/PredictionCombinerProcessor.java
+++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/PredictionCombinerProcessor.java
@@ -33,14 +33,14 @@
 import org.apache.samoa.topology.Stream;
 
 /**
- * The Class PredictionCombinerProcessor.
+ * Combines predictions coming from an ensemble. Equivalent to a majority-vote classifier.
  */
 public class PredictionCombinerProcessor implements Processor {
 
   private static final long serialVersionUID = -1606045723451191132L;
 
   /**
-   * The size ensemble.
+   * The ensemble size.
    */
   protected int ensembleSize;
 
@@ -73,7 +73,7 @@
    * 
    * @return the ensembleSize
    */
-  public int getSizeEnsemble() {
+  public int getEnsembleSize() {
     return ensembleSize;
   }
 
@@ -83,7 +83,7 @@
    * @param ensembleSize
    *          the new size ensemble
    */
-  public void setSizeEnsemble(int ensembleSize) {
+  public void setEnsembleSize(int ensembleSize) {
     this.ensembleSize = ensembleSize;
   }
 
@@ -143,7 +143,7 @@
     if (originProcessor.getOutputStream() != null) {
       newProcessor.setOutputStream(originProcessor.getOutputStream());
     }
-    newProcessor.setSizeEnsemble(originProcessor.getSizeEnsemble());
+    newProcessor.setEnsembleSize(originProcessor.getEnsembleSize());
     return newProcessor;
   }
 
diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/VerticalHoeffdingTree.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/VerticalHoeffdingTree.java
index ea7e53d..6534cee 100644
--- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/VerticalHoeffdingTree.java
+++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/VerticalHoeffdingTree.java
@@ -20,8 +20,6 @@
  * #L%
  */
 
-import com.google.common.collect.ImmutableSet;
-
 import java.util.Set;
 
 import org.apache.samoa.core.Processor;
@@ -41,6 +39,7 @@
 import com.github.javacliparser.FlagOption;
 import com.github.javacliparser.FloatOption;
 import com.github.javacliparser.IntOption;
+import com.google.common.collect.ImmutableSet;
 
 /**
  * Vertical Hoeffding Tree.
@@ -172,7 +171,7 @@
   public void setChangeDetector(ChangeDetector cd) {
     this.changeDetector = cd;
   }
-
+  
   static class LearningNodeIdGenerator {
 
     // TODO: add code to warn user of when value reaches Long.MAX_VALUES