blob: 8e1009441699235ed0920fe3e0b2de880ee44997 [file] [log] [blame]
package org.apache.samoa.learners.classifiers.ensemble;
/*
* #%L
* SAMOA
* %%
* Copyright (C) 2014 - 2015 Apache Software Foundation
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
/**
* License
*/
import java.util.HashMap;
import java.util.Map;
import org.apache.samoa.core.ContentEvent;
import org.apache.samoa.core.Processor;
import org.apache.samoa.learners.ResultContentEvent;
import org.apache.samoa.moa.core.DoubleVector;
import org.apache.samoa.topology.Stream;
/**
* Combines predictions coming from an ensemble. Equivalent to a majority-vote classifier.
*/
public class PredictionCombinerProcessor implements Processor {
private static final long serialVersionUID = -1606045723451191132L;
/**
* The ensemble size.
*/
protected int ensembleSize;
/**
* The output stream.
*/
protected Stream outputStream;
/**
* Sets the output stream.
*
* @param stream
* the new output stream
*/
public void setOutputStream(Stream stream) {
outputStream = stream;
}
/**
* Gets the output stream.
*
* @return the output stream
*/
public Stream getOutputStream() {
return outputStream;
}
/**
* Gets the size ensemble.
*
* @return the ensembleSize
*/
public int getEnsembleSize() {
return ensembleSize;
}
/**
* Sets the size ensemble.
*
* @param ensembleSize
* the new size ensemble
*/
public void setEnsembleSize(int ensembleSize) {
this.ensembleSize = ensembleSize;
}
protected Map<Integer, Integer> mapCountsforInstanceReceived;
protected Map<Integer, DoubleVector> mapVotesforInstanceReceived;
/**
* On event.
*
* @param event
* the event
* @return true, if successful
*/
public boolean process(ContentEvent event) {
ResultContentEvent inEvent = (ResultContentEvent) event;
double[] prediction = inEvent.getClassVotes();
int instanceIndex = (int) inEvent.getInstanceIndex();
addStatisticsForInstanceReceived(instanceIndex, inEvent.getClassifierIndex(), prediction, 1);
if (hasAllVotesArrivedInstance(instanceIndex)) {
DoubleVector combinedVote = this.mapVotesforInstanceReceived.get(instanceIndex);
if (combinedVote == null) {
combinedVote = new DoubleVector(new double[inEvent.getInstance().numClasses()]);
}
ResultContentEvent outContentEvent = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(),
inEvent.getClassId(), combinedVote.getArrayCopy(), inEvent.isLastEvent());
outContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex());
outputStream.put(outContentEvent);
clearStatisticsInstance(instanceIndex);
return true;
}
return false;
}
@Override
public void onCreate(int id) {
this.reset();
}
public void reset() {
}
/*
* (non-Javadoc)
* @see samoa.core.Processor#newProcessor(samoa.core.Processor)
*/
@Override
public Processor newProcessor(Processor sourceProcessor) {
PredictionCombinerProcessor newProcessor = new PredictionCombinerProcessor();
PredictionCombinerProcessor originProcessor = (PredictionCombinerProcessor) sourceProcessor;
if (originProcessor.getOutputStream() != null) {
newProcessor.setOutputStream(originProcessor.getOutputStream());
}
newProcessor.setEnsembleSize(originProcessor.getEnsembleSize());
return newProcessor;
}
protected void addStatisticsForInstanceReceived(int instanceIndex, int classifierIndex, double[] prediction, int add) {
if (this.mapCountsforInstanceReceived == null) {
this.mapCountsforInstanceReceived = new HashMap<>();
this.mapVotesforInstanceReceived = new HashMap<>();
}
DoubleVector vote = new DoubleVector(prediction);
if (vote.sumOfValues() > 0.0) {
vote.normalize();
DoubleVector combinedVote = this.mapVotesforInstanceReceived.get(instanceIndex);
if (combinedVote == null) {
combinedVote = new DoubleVector();
}
vote.scaleValues(getEnsembleMemberWeight(classifierIndex));
combinedVote.addValues(vote);
this.mapVotesforInstanceReceived.put(instanceIndex, combinedVote);
}
Integer count = this.mapCountsforInstanceReceived.get(instanceIndex);
if (count == null) {
count = 0;
}
this.mapCountsforInstanceReceived.put(instanceIndex, count + add);
}
protected boolean hasAllVotesArrivedInstance(int instanceIndex) {
return (this.mapCountsforInstanceReceived.get(instanceIndex) == this.ensembleSize);
}
protected void clearStatisticsInstance(int instanceIndex) {
this.mapCountsforInstanceReceived.remove(instanceIndex);
this.mapVotesforInstanceReceived.remove(instanceIndex);
}
protected double getEnsembleMemberWeight(int i) {
return 1.0;
}
}