| package org.apache.samoa.learners.classifiers.trees; |
| |
| /* |
| * #%L |
| * SAMOA |
| * %% |
| * Copyright (C) 2014 - 2015 Apache Software Foundation |
| * %% |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| * #L% |
| */ |
| |
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Vector;

import org.apache.samoa.core.ContentEvent;
import org.apache.samoa.core.Processor;
import org.apache.samoa.moa.classifiers.core.AttributeSplitSuggestion;
import org.apache.samoa.moa.classifiers.core.attributeclassobservers.AttributeClassObserver;
import org.apache.samoa.moa.classifiers.core.attributeclassobservers.GaussianNumericAttributeClassObserver;
import org.apache.samoa.moa.classifiers.core.attributeclassobservers.NominalAttributeClassObserver;
import org.apache.samoa.moa.classifiers.core.splitcriteria.InfoGainSplitCriterion;
import org.apache.samoa.moa.classifiers.core.splitcriteria.SplitCriterion;
import org.apache.samoa.topology.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
| |
| /** |
| * Local Statistic Processor contains the local statistic of a subset of the attributes. |
| * |
| * @author Arinto Murdopo |
| * |
| */ |
| final class LocalStatisticsProcessor implements Processor { |
| |
| /** |
| * |
| */ |
| private static final long serialVersionUID = -3967695130634517631L; |
| private static Logger logger = LoggerFactory.getLogger(LocalStatisticsProcessor.class); |
| |
| // Collection of AttributeObservers, for each ActiveLearningNode and |
| // AttributeId |
| private Table<Long, Integer, AttributeClassObserver> localStats; |
| |
| private Stream computationResultStream; |
| |
| private final SplitCriterion splitCriterion; |
| private final boolean binarySplit; |
| private final AttributeClassObserver nominalClassObserver; |
| private final AttributeClassObserver numericClassObserver; |
| |
| // the two observer classes below are also needed to be setup from the Tree |
| private LocalStatisticsProcessor(Builder builder) { |
| this.splitCriterion = builder.splitCriterion; |
| this.binarySplit = builder.binarySplit; |
| this.nominalClassObserver = builder.nominalClassObserver; |
| this.numericClassObserver = builder.numericClassObserver; |
| } |
| |
| @Override |
| public boolean process(ContentEvent event) { |
| // process AttributeContentEvent by updating the subset of local statistics |
| if (event instanceof AttributeBatchContentEvent) { |
| AttributeBatchContentEvent abce = (AttributeBatchContentEvent) event; |
| List<ContentEvent> contentEventList = abce.getContentEventList(); |
| for (ContentEvent contentEvent : contentEventList) { |
| AttributeContentEvent ace = (AttributeContentEvent) contentEvent; |
| Long learningNodeId = ace.getLearningNodeId(); |
| Integer obsIndex = ace.getObsIndex(); |
| |
| AttributeClassObserver obs = localStats.get( |
| learningNodeId, obsIndex); |
| |
| if (obs == null) { |
| obs = ace.isNominal() ? newNominalClassObserver() |
| : newNumericClassObserver(); |
| localStats.put(ace.getLearningNodeId(), obsIndex, obs); |
| } |
| obs.observeAttributeClass(ace.getAttrVal(), ace.getClassVal(), |
| ace.getWeight()); |
| } |
| |
| /* |
| * if (event instanceof AttributeContentEvent) { AttributeContentEvent ace |
| * = (AttributeContentEvent) event; Long learningNodeId = |
| * Long.valueOf(ace.getLearningNodeId()); Integer obsIndex = |
| * Integer.valueOf(ace.getObsIndex()); |
| * |
| * AttributeClassObserver obs = localStats.get( learningNodeId, obsIndex); |
| * |
| * if (obs == null) { obs = ace.isNominal() ? newNominalClassObserver() : |
| * newNumericClassObserver(); localStats.put(ace.getLearningNodeId(), |
| * obsIndex, obs); } obs.observeAttributeClass(ace.getAttrVal(), |
| * ace.getClassVal(), ace.getWeight()); |
| */ |
| } else if (event instanceof ComputeContentEvent) { |
| // process ComputeContentEvent by calculating the local statistic |
| // and send back the calculation results via computation result stream. |
| ComputeContentEvent cce = (ComputeContentEvent) event; |
| Long learningNodeId = cce.getLearningNodeId(); |
| double[] preSplitDist = cce.getPreSplitDist(); |
| |
| Map<Integer, AttributeClassObserver> learningNodeRowMap = localStats |
| .row(learningNodeId); |
| List<AttributeSplitSuggestion> suggestions = new Vector<>(); |
| |
| for (Entry<Integer, AttributeClassObserver> entry : learningNodeRowMap.entrySet()) { |
| AttributeClassObserver obs = entry.getValue(); |
| AttributeSplitSuggestion suggestion = obs |
| .getBestEvaluatedSplitSuggestion(splitCriterion, |
| preSplitDist, entry.getKey(), binarySplit); |
| if (suggestion != null) { |
| suggestions.add(suggestion); |
| } |
| } |
| |
| AttributeSplitSuggestion[] bestSuggestions = suggestions |
| .toArray(new AttributeSplitSuggestion[suggestions.size()]); |
| |
| Arrays.sort(bestSuggestions); |
| |
| AttributeSplitSuggestion bestSuggestion = null; |
| AttributeSplitSuggestion secondBestSuggestion = null; |
| |
| if (bestSuggestions.length >= 1) { |
| bestSuggestion = bestSuggestions[bestSuggestions.length - 1]; |
| |
| if (bestSuggestions.length >= 2) { |
| secondBestSuggestion = bestSuggestions[bestSuggestions.length - 2]; |
| } |
| } |
| |
| // create the local result content event |
| LocalResultContentEvent lcre = |
| new LocalResultContentEvent(cce.getSplitId(), bestSuggestion, secondBestSuggestion); |
| computationResultStream.put(lcre); |
| logger.debug("Finish compute event"); |
| } else if (event instanceof DeleteContentEvent) { |
| DeleteContentEvent dce = (DeleteContentEvent) event; |
| Long learningNodeId = dce.getLearningNodeId(); |
| localStats.rowMap().remove(learningNodeId); |
| } |
| return false; |
| } |
| |
| @Override |
| public void onCreate(int id) { |
| this.localStats = HashBasedTable.create(); |
| } |
| |
| @Override |
| public Processor newProcessor(Processor p) { |
| LocalStatisticsProcessor oldProcessor = (LocalStatisticsProcessor) p; |
| LocalStatisticsProcessor newProcessor = new LocalStatisticsProcessor.Builder(oldProcessor).build(); |
| |
| newProcessor.setComputationResultStream(oldProcessor.computationResultStream); |
| |
| return newProcessor; |
| } |
| |
| /** |
| * Method to set the computation result when using this processor to build a topology. |
| * |
| * @param computeStream |
| */ |
| void setComputationResultStream(Stream computeStream) { |
| this.computationResultStream = computeStream; |
| } |
| |
| private AttributeClassObserver newNominalClassObserver() { |
| return (AttributeClassObserver) this.nominalClassObserver.copy(); |
| } |
| |
| private AttributeClassObserver newNumericClassObserver() { |
| return (AttributeClassObserver) this.numericClassObserver.copy(); |
| } |
| |
| /** |
| * Builder class to replace constructors with many parameters |
| * |
| * @author Arinto Murdopo |
| * |
| */ |
| static class Builder { |
| |
| private SplitCriterion splitCriterion = new InfoGainSplitCriterion(); |
| private boolean binarySplit = false; |
| private AttributeClassObserver nominalClassObserver = new NominalAttributeClassObserver(); |
| private AttributeClassObserver numericClassObserver = new GaussianNumericAttributeClassObserver(); |
| |
| Builder() { |
| |
| } |
| |
| Builder(LocalStatisticsProcessor oldProcessor) { |
| this.splitCriterion = oldProcessor.splitCriterion; |
| this.binarySplit = oldProcessor.binarySplit; |
| } |
| |
| Builder splitCriterion(SplitCriterion splitCriterion) { |
| this.splitCriterion = splitCriterion; |
| return this; |
| } |
| |
| Builder binarySplit(boolean binarySplit) { |
| this.binarySplit = binarySplit; |
| return this; |
| } |
| |
| Builder nominalClassObserver(AttributeClassObserver nominalClassObserver) { |
| this.nominalClassObserver = nominalClassObserver; |
| return this; |
| } |
| |
| Builder numericClassObserver(AttributeClassObserver numericClassObserver) { |
| this.numericClassObserver = numericClassObserver; |
| return this; |
| } |
| |
| LocalStatisticsProcessor build() { |
| return new LocalStatisticsProcessor(this); |
| } |
| } |
| |
| } |