| package com.yahoo.labs.samoa.learners.classifiers.ensemble; |
| |
| /* |
| * #%L |
| * SAMOA |
| * %% |
| * Copyright (C) 2013 Yahoo! Inc. |
| * %% |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| * #L% |
| */ |
| |
| /** |
| * License |
| */ |
| import java.util.HashMap; |
| import java.util.Map; |
| |
| import com.yahoo.labs.samoa.core.ContentEvent; |
| import com.yahoo.labs.samoa.core.Processor; |
| import com.yahoo.labs.samoa.learners.ResultContentEvent; |
| import com.yahoo.labs.samoa.moa.core.DoubleVector; |
| import com.yahoo.labs.samoa.topology.Stream; |
| |
| /** |
| * The Class PredictionCombinerProcessor. |
| */ |
| public class PredictionCombinerProcessor implements Processor { |
| |
| private static final long serialVersionUID = -1606045723451191132L; |
| |
| /** |
| * The size ensemble. |
| */ |
| protected int ensembleSize; |
| |
| /** |
| * The output stream. |
| */ |
| protected Stream outputStream; |
| |
| /** |
| * Sets the output stream. |
| * |
| * @param stream |
| * the new output stream |
| */ |
| public void setOutputStream(Stream stream) { |
| outputStream = stream; |
| } |
| |
| /** |
| * Gets the output stream. |
| * |
| * @return the output stream |
| */ |
| public Stream getOutputStream() { |
| return outputStream; |
| } |
| |
| /** |
| * Gets the size ensemble. |
| * |
| * @return the ensembleSize |
| */ |
| public int getSizeEnsemble() { |
| return ensembleSize; |
| } |
| |
| /** |
| * Sets the size ensemble. |
| * |
| * @param ensembleSize |
| * the new size ensemble |
| */ |
| public void setSizeEnsemble(int ensembleSize) { |
| this.ensembleSize = ensembleSize; |
| } |
| |
| protected Map<Integer, Integer> mapCountsforInstanceReceived; |
| |
| protected Map<Integer, DoubleVector> mapVotesforInstanceReceived; |
| |
| /** |
| * On event. |
| * |
| * @param event |
| * the event |
| * @return true, if successful |
| */ |
| public boolean process(ContentEvent event) { |
| |
| ResultContentEvent inEvent = (ResultContentEvent) event; |
| double[] prediction = inEvent.getClassVotes(); |
| int instanceIndex = (int) inEvent.getInstanceIndex(); |
| |
| addStatisticsForInstanceReceived(instanceIndex, inEvent.getClassifierIndex(), prediction, 1); |
| |
| if (inEvent.isLastEvent() || hasAllVotesArrivedInstance(instanceIndex)) { |
| DoubleVector combinedVote = this.mapVotesforInstanceReceived.get(instanceIndex); |
| if (combinedVote == null) { |
| combinedVote = new DoubleVector(new double[inEvent.getInstance().numClasses()]); |
| } |
| ResultContentEvent outContentEvent = new ResultContentEvent(inEvent.getInstanceIndex(), |
| inEvent.getInstance(), inEvent.getClassId(), |
| combinedVote.getArrayCopy(), inEvent.isLastEvent()); |
| outContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex()); |
| outputStream.put(outContentEvent); |
| clearStatisticsInstance(instanceIndex); |
| return true; |
| } |
| return false; |
| |
| } |
| |
| @Override |
| public void onCreate(int id) { |
| this.reset(); |
| } |
| |
| public void reset() { |
| } |
| |
| /* |
| * (non-Javadoc) |
| * |
| * @see samoa.core.Processor#newProcessor(samoa.core.Processor) |
| */ |
| @Override |
| public Processor newProcessor(Processor sourceProcessor) { |
| PredictionCombinerProcessor newProcessor = new PredictionCombinerProcessor(); |
| PredictionCombinerProcessor originProcessor = (PredictionCombinerProcessor) sourceProcessor; |
| if (originProcessor.getOutputStream() != null) { |
| newProcessor.setOutputStream(originProcessor.getOutputStream()); |
| } |
| newProcessor.setSizeEnsemble(originProcessor.getSizeEnsemble()); |
| return newProcessor; |
| } |
| |
| protected void addStatisticsForInstanceReceived(int instanceIndex, int classifierIndex, double[] prediction, int add) { |
| if (this.mapCountsforInstanceReceived == null) { |
| this.mapCountsforInstanceReceived = new HashMap<>(); |
| this.mapVotesforInstanceReceived = new HashMap<>(); |
| } |
| DoubleVector vote = new DoubleVector(prediction); |
| if (vote.sumOfValues() > 0.0) { |
| vote.normalize(); |
| DoubleVector combinedVote = this.mapVotesforInstanceReceived.get(instanceIndex); |
| if (combinedVote == null) { |
| combinedVote = new DoubleVector(); |
| } |
| vote.scaleValues(getEnsembleMemberWeight(classifierIndex)); |
| combinedVote.addValues(vote); |
| |
| this.mapVotesforInstanceReceived.put(instanceIndex, combinedVote); |
| } |
| Integer count = this.mapCountsforInstanceReceived.get(instanceIndex); |
| if (count == null) { |
| count = 0; |
| } |
| this.mapCountsforInstanceReceived.put(instanceIndex, count + add); |
| } |
| |
| protected boolean hasAllVotesArrivedInstance(int instanceIndex) { |
| return (this.mapCountsforInstanceReceived.get(instanceIndex) == this.ensembleSize); |
| } |
| |
| protected void clearStatisticsInstance(int instanceIndex) { |
| this.mapCountsforInstanceReceived.remove(instanceIndex); |
| this.mapVotesforInstanceReceived.remove(instanceIndex); |
| } |
| |
| protected double getEnsembleMemberWeight(int i) { |
| return 1.0; |
| } |
| |
| } |