SAMOA-35: Add Sharding ensemble method
diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java
index 7178738..967684f 100644
--- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java
+++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java
@@ -143,7 +143,6 @@
*/
@Override
public Set<Stream> getResultStreams() {
- Set<Stream> streams = ImmutableSet.of(this.resultStream);
- return streams;
+ return ImmutableSet.of(this.resultStream);
}
}
diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Sharding.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Sharding.java
new file mode 100644
index 0000000..588d9f2
--- /dev/null
+++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Sharding.java
@@ -0,0 +1,153 @@
+package org.apache.samoa.learners.classifiers.ensemble;
+
+/*
+ * #%L
+ * SAMOA
+ * %%
+ * Copyright (C) 2014 - 2015 Apache Software Foundation
+ * %%
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * #L%
+ */
+
+import java.util.Set;
+
+import org.apache.samoa.core.Processor;
+import org.apache.samoa.instances.Instances;
+import org.apache.samoa.learners.Learner;
+import org.apache.samoa.learners.classifiers.trees.VerticalHoeffdingTree;
+import org.apache.samoa.topology.Stream;
+import org.apache.samoa.topology.TopologyBuilder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.github.javacliparser.ClassOption;
+import com.github.javacliparser.Configurable;
+import com.github.javacliparser.IntOption;
+import com.google.common.collect.ImmutableSet;
+
+/**
+ * Simple sharding meta-classifier. It trains an ensemble of learners by shuffling the training stream among them, so
+ * that each learner is trained independently of the others. The result streams of all members are connected to a
+ * single {@link PredictionCombinerProcessor}, whose output stream is the result stream of this learner.
+ */
+public class Sharding implements Learner, Configurable {
+
+  private static final long serialVersionUID = -2971850264864952099L;
+  private static final Logger logger = LoggerFactory.getLogger(Sharding.class);
+
+  /** The base learner class. */
+  public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l',
+      "Classifier to train.", Learner.class, VerticalHoeffdingTree.class.getName());
+
+  /** The ensemble size option. */
+  public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's',
+      "The number of models in the ensemble.", 10, 1, Integer.MAX_VALUE);
+
+  /** The distributor processor, which is also the input processor of this learner. */
+  private ShardingDistributorProcessor distributor;
+
+  /** The input streams for the ensemble, one per member. */
+  private Stream[] ensembleStreams;
+
+  /** The result stream, created from the prediction combiner. */
+  protected Stream resultStream;
+
+  /** The dataset handed to every ensemble member on initialization. */
+  private Instances dataset;
+
+  /** The ensemble members, one instance of the base learner each. */
+  protected Learner[] ensemble;
+
+  /** The builder used to assemble the topology. */
+  private TopologyBuilder builder;
+
+  /**
+   * Sets the layout: the distributor feeds one stream per ensemble member, and every member feeds the combiner.
+   *
+   * @throws IllegalArgumentException
+   *           if a member of the ensemble cannot be instantiated from the base learner option
+   */
+  protected void setLayout() {
+    int ensembleSize = this.ensembleSizeOption.getValue();
+
+    distributor = new ShardingDistributorProcessor();
+    distributor.setEnsembleSize(ensembleSize);
+    builder.addProcessor(distributor, 1);
+
+    // instantiate the ensemble members
+    ensemble = new Learner[ensembleSize];
+    for (int i = 0; i < ensembleSize; i++) {
+      try {
+        ensemble[i] = (Learner) ClassOption.createObject(baseLearnerOption.getValueAsCLIString(),
+            baseLearnerOption.getRequiredType());
+      } catch (Exception e) {
+        // pass the cause to the logger instead of printing the stack trace to stderr
+        logger.error("Unable to create members of the ensemble. Please check your CLI parameters", e);
+        throw new IllegalArgumentException(e);
+      }
+      ensemble[i].init(builder, this.dataset, 1); // sequential
+    }
+
+    PredictionCombinerProcessor predictionCombiner = new PredictionCombinerProcessor();
+    predictionCombiner.setEnsembleSize(ensembleSize);
+    builder.addProcessor(predictionCombiner, 1);
+
+    // Streams
+    resultStream = builder.createStream(predictionCombiner);
+    predictionCombiner.setOutputStream(resultStream);
+
+    for (Learner member : ensemble) {
+      // a learner can have multiple output streams
+      for (Stream subResultStream : member.getResultStreams()) {
+        // the key is the instance id to combine predictions
+        builder.connectInputKeyStream(subResultStream, predictionCombiner);
+      }
+    }
+
+    ensembleStreams = new Stream[ensembleSize];
+    for (int i = 0; i < ensembleSize; i++) {
+      ensembleStreams[i] = builder.createStream(distributor);
+      // connect streams one-to-one with ensemble members (the type of connection does not matter)
+      builder.connectInputShuffleStream(ensembleStreams[i], ensemble[i].getInputProcessor());
+    }
+
+    distributor.setOutputStreams(ensembleStreams);
+  }
+
+  /**
+   * Stores the builder and dataset, then builds the ensemble topology. The parallelism hint is not used: the
+   * distributor, the combiner and every ensemble member are set up with a parallelism of 1.
+   */
+  @Override
+  public void init(TopologyBuilder builder, Instances dataset, int parallelism) {
+    this.builder = builder;
+    this.dataset = dataset;
+    this.setLayout();
+  }
+
+  @Override
+  public Processor getInputProcessor() {
+    return distributor;
+  }
+
+  /*
+   * (non-Javadoc)
+   *
+   * @see org.apache.samoa.learners.Learner#getResultStreams()
+   */
+  @Override
+  public Set<Stream> getResultStreams() {
+    return ImmutableSet.of(this.resultStream);
+  }
+}
diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/ShardingDistributorProcessor.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/ShardingDistributorProcessor.java
new file mode 100644
index 0000000..0e936d7
--- /dev/null
+++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/ShardingDistributorProcessor.java
@@ -0,0 +1,163 @@
+package org.apache.samoa.learners.classifiers.ensemble;
+
+import java.util.Arrays;
+import java.util.Random;
+
+/*
+ * #%L
+ * SAMOA
+ * %%
+ * Copyright (C) 2014 - 2015 Apache Software Foundation
+ * %%
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * #L%
+ */
+
+/**
+ * License
+ */
+
+import org.apache.samoa.core.ContentEvent;
+import org.apache.samoa.core.Processor;
+import org.apache.samoa.instances.Instance;
+import org.apache.samoa.learners.InstanceContentEvent;
+import org.apache.samoa.topology.Stream;
+
+/**
+ * Input processor of the {@link Sharding} ensemble. It forwards each training instance to one randomly chosen member
+ * of the ensemble and each testing instance to all of them.
+ */
+public class ShardingDistributorProcessor implements Processor {
+
+  private static final long serialVersionUID = -1550901409625192730L;
+
+  /** The ensemble size. */
+  private int ensembleSize;
+
+  /** The output streams, one per ensemble member. */
+  private Stream[] ensembleStreams;
+
+  /** Random number generator used to pick the member that receives a training instance. */
+  protected Random random = new Random(); //TODO make random seed configurable
+
+  /**
+   * Routes an incoming instance event to the ensemble. The last event is forwarded unchanged to every member. A testing
+   * instance is copied to every member; a training instance is sent to a single member (see
+   * {@link #train(InstanceContentEvent)}).
+   *
+   * @param event
+   *          the event, cast to {@link InstanceContentEvent}
+   * @return always false
+   */
+  @Override
+  public boolean process(ContentEvent event) {
+    InstanceContentEvent inEvent = (InstanceContentEvent) event;
+    if (inEvent.isLastEvent()) {
+      // end learning
+      for (Stream stream : ensembleStreams) {
+        stream.put(event);
+      }
+      return false;
+    }
+
+    if (inEvent.isTesting()) {
+      Instance testInstance = inEvent.getInstance();
+      for (int i = 0; i < ensembleSize; i++) {
+        Instance instanceCopy = testInstance.copy();
+        InstanceContentEvent instanceContentEvent = new InstanceContentEvent(inEvent.getInstanceIndex(), instanceCopy,
+            false, true);
+        instanceContentEvent.setClassifierIndex(i); //TODO probably not needed anymore
+        instanceContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex()); //TODO probably not needed anymore
+        ensembleStreams[i].put(instanceContentEvent);
+      }
+    }
+
+    // estimate model parameters using the training data
+    if (inEvent.isTraining()) {
+      train(inEvent);
+    }
+    return false;
+  }
+
+  /**
+   * Sends a copy of the training instance to one member of the ensemble, chosen at random.
+   *
+   * @param inEvent
+   *          the training event
+   */
+  protected void train(InstanceContentEvent inEvent) {
+    Instance trainInst = inEvent.getInstance().copy();
+    InstanceContentEvent instanceContentEvent = new InstanceContentEvent(inEvent.getInstanceIndex(), trainInst,
+        true, false);
+    int i = random.nextInt(ensembleSize);
+    instanceContentEvent.setClassifierIndex(i);
+    instanceContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex());
+    ensembleStreams[i].put(instanceContentEvent);
+  }
+
+  /*
+   * (non-Javadoc)
+   *
+   * @see org.apache.samoa.core.Processor#onCreate(int)
+   */
+  @Override
+  public void onCreate(int id) {
+    // do nothing
+  }
+
+  /** Gets the output streams (the internal array itself, not a copy). */
+  public Stream[] getOutputStreams() {
+    return ensembleStreams;
+  }
+
+  /** Sets the output streams; the array is stored as given, not copied. */
+  public void setOutputStreams(Stream[] ensembleStreams) {
+    this.ensembleStreams = ensembleStreams;
+  }
+
+  /**
+   * Gets the ensemble size.
+   *
+   * @return the ensemble size
+   */
+  public int getEnsembleSize() {
+    return ensembleSize;
+  }
+
+  /**
+   * Sets the ensemble size.
+   *
+   * @param ensembleSize
+   *          the new ensemble size
+   */
+  public void setEnsembleSize(int ensembleSize) {
+    this.ensembleSize = ensembleSize;
+  }
+
+  /*
+   * (non-Javadoc)
+   *
+   * @see org.apache.samoa.core.Processor#newProcessor(org.apache.samoa.core.Processor)
+   */
+  @Override
+  public Processor newProcessor(Processor sourceProcessor) {
+    ShardingDistributorProcessor newProcessor = new ShardingDistributorProcessor();
+    ShardingDistributorProcessor originProcessor = (ShardingDistributorProcessor) sourceProcessor;
+    if (originProcessor.getOutputStreams() != null) {
+      newProcessor.setOutputStreams(Arrays.copyOf(originProcessor.getOutputStreams(),
+          originProcessor.getOutputStreams().length));
+    }
+    newProcessor.setEnsembleSize(originProcessor.getEnsembleSize());
+    return newProcessor;
+  }
+}