blob: 7ce46ecff478642801f37a6d5f7f1da0a67ee43e [file] [log] [blame]
package org.apache.samoa.learners.classifiers.trees;
/*
* #%L
* SAMOA
* %%
* Copyright (C) 2014 - 2015 Apache Software Foundation
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Vector;
import org.apache.samoa.core.ContentEvent;
import org.apache.samoa.core.Processor;
import org.apache.samoa.moa.classifiers.core.AttributeSplitSuggestion;
import org.apache.samoa.moa.classifiers.core.attributeclassobservers.AttributeClassObserver;
import org.apache.samoa.moa.classifiers.core.attributeclassobservers.GaussianNumericAttributeClassObserver;
import org.apache.samoa.moa.classifiers.core.attributeclassobservers.NominalAttributeClassObserver;
import org.apache.samoa.moa.classifiers.core.splitcriteria.InfoGainSplitCriterion;
import org.apache.samoa.moa.classifiers.core.splitcriteria.SplitCriterion;
import org.apache.samoa.topology.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
/**
 * Local Statistic Processor contains the local statistics of a subset of the attributes.
 *
 * <p>For every (learning node id, observer index) pair it keeps one {@link AttributeClassObserver}.
 * {@link #process(ContentEvent)} handles three kinds of events:
 * <ul>
 * <li>{@link AttributeBatchContentEvent}: feeds each contained {@link AttributeContentEvent} into the
 * matching observer, creating the observer on first use;</li>
 * <li>{@link ComputeContentEvent}: evaluates the split suggestions of all observers of one learning node
 * and sends the best and second-best ones as a {@link LocalResultContentEvent} on the computation
 * result stream;</li>
 * <li>{@link DeleteContentEvent}: discards all observers of one learning node.</li>
 * </ul>
 *
 * @author Arinto Murdopo
 *
 */
final class LocalStatisticsProcessor implements Processor {

  private static final long serialVersionUID = -3967695130634517631L;
  private static final Logger logger = LoggerFactory.getLogger(LocalStatisticsProcessor.class);

  // Collection of AttributeObservers: row key = ActiveLearningNode id,
  // column key = observer (attribute) index. Created in onCreate().
  private Table<Long, Integer, AttributeClassObserver> localStats;
  private Stream computationResultStream;

  private final SplitCriterion splitCriterion;
  private final boolean binarySplit;
  // Prototype observers: every new (node, attribute) pair receives its own copy().
  // These two are also needed to be set up from the Tree.
  private final AttributeClassObserver nominalClassObserver;
  private final AttributeClassObserver numericClassObserver;

  private LocalStatisticsProcessor(Builder builder) {
    this.splitCriterion = builder.splitCriterion;
    this.binarySplit = builder.binarySplit;
    this.nominalClassObserver = builder.nominalClassObserver;
    this.numericClassObserver = builder.numericClassObserver;
  }

  /**
   * Dispatches an incoming event to the matching handler. Events of any other type are ignored.
   *
   * @param event
   *          the incoming event
   * @return always {@code false}
   */
  @Override
  public boolean process(ContentEvent event) {
    if (event instanceof AttributeBatchContentEvent) {
      updateLocalStatistics((AttributeBatchContentEvent) event);
    } else if (event instanceof ComputeContentEvent) {
      computeLocalResult((ComputeContentEvent) event);
    } else if (event instanceof DeleteContentEvent) {
      DeleteContentEvent dce = (DeleteContentEvent) event;
      Long learningNodeId = dce.getLearningNodeId();
      // drop every observer belonging to this learning node
      localStats.rowMap().remove(learningNodeId);
    }
    return false;
  }

  /** Updates the local statistics with every attribute event contained in the batch. */
  private void updateLocalStatistics(AttributeBatchContentEvent abce) {
    List<ContentEvent> contentEventList = abce.getContentEventList();
    for (ContentEvent contentEvent : contentEventList) {
      AttributeContentEvent ace = (AttributeContentEvent) contentEvent;
      Long learningNodeId = ace.getLearningNodeId();
      Integer obsIndex = ace.getObsIndex();

      AttributeClassObserver obs = localStats.get(learningNodeId, obsIndex);
      if (obs == null) {
        // first value seen for this (node, attribute) pair: create a fresh observer
        obs = ace.isNominal() ? newNominalClassObserver() : newNumericClassObserver();
        localStats.put(learningNodeId, obsIndex, obs);
      }
      obs.observeAttributeClass(ace.getAttrVal(), ace.getClassVal(), ace.getWeight());
    }
  }

  /**
   * Computes the best and second-best local split suggestions for the learning node named in the
   * event and sends them back via the computation result stream.
   */
  private void computeLocalResult(ComputeContentEvent cce) {
    Long learningNodeId = cce.getLearningNodeId();
    double[] preSplitDist = cce.getPreSplitDist();
    Map<Integer, AttributeClassObserver> learningNodeRowMap = localStats.row(learningNodeId);

    List<AttributeSplitSuggestion> suggestions = new ArrayList<>(learningNodeRowMap.size());
    for (Entry<Integer, AttributeClassObserver> entry : learningNodeRowMap.entrySet()) {
      AttributeSplitSuggestion suggestion = entry.getValue().getBestEvaluatedSplitSuggestion(
          splitCriterion, preSplitDist, entry.getKey(), binarySplit);
      if (suggestion != null) {
        suggestions.add(suggestion);
      }
    }

    AttributeSplitSuggestion[] bestSuggestions =
        suggestions.toArray(new AttributeSplitSuggestion[suggestions.size()]);
    // The selection below treats the LAST element after sorting as the best suggestion;
    // this relies on AttributeSplitSuggestion's natural ordering (presumably ascending merit).
    Arrays.sort(bestSuggestions);
    int count = bestSuggestions.length;
    AttributeSplitSuggestion bestSuggestion = count >= 1 ? bestSuggestions[count - 1] : null;
    AttributeSplitSuggestion secondBestSuggestion = count >= 2 ? bestSuggestions[count - 2] : null;

    // create the local result content event
    LocalResultContentEvent lcre =
        new LocalResultContentEvent(cce.getSplitId(), bestSuggestion, secondBestSuggestion);
    computationResultStream.put(lcre);
    logger.debug("Finish compute event");
  }

  /**
   * Initializes the (empty) local statistics table.
   *
   * @param id
   *          processor id (not used here)
   */
  @Override
  public void onCreate(int id) {
    this.localStats = HashBasedTable.create();
  }

  /**
   * Creates a new processor with the same configuration (split criterion, binary-split flag, observer
   * prototypes and computation result stream) as {@code p}. The local statistics are not copied; they
   * are created in {@link #onCreate(int)}.
   *
   * @param p
   *          the processor to copy; must be a {@code LocalStatisticsProcessor}
   * @return the new processor
   */
  @Override
  public Processor newProcessor(Processor p) {
    LocalStatisticsProcessor oldProcessor = (LocalStatisticsProcessor) p;
    LocalStatisticsProcessor newProcessor = new LocalStatisticsProcessor.Builder(oldProcessor).build();
    newProcessor.setComputationResultStream(oldProcessor.computationResultStream);
    return newProcessor;
  }

  /**
   * Method to set the computation result stream when using this processor to build a topology.
   *
   * @param computeStream
   *          stream on which {@link LocalResultContentEvent}s are emitted
   */
  void setComputationResultStream(Stream computeStream) {
    this.computationResultStream = computeStream;
  }

  /** Returns a fresh copy of the nominal observer prototype. */
  private AttributeClassObserver newNominalClassObserver() {
    return (AttributeClassObserver) this.nominalClassObserver.copy();
  }

  /** Returns a fresh copy of the numeric observer prototype. */
  private AttributeClassObserver newNumericClassObserver() {
    return (AttributeClassObserver) this.numericClassObserver.copy();
  }

  /**
   * Builder class to replace constructors with many parameters.
   *
   * <p>Defaults: {@link InfoGainSplitCriterion}, no binary split, {@link NominalAttributeClassObserver}
   * and {@link GaussianNumericAttributeClassObserver}.
   *
   * @author Arinto Murdopo
   *
   */
  static class Builder {
    private SplitCriterion splitCriterion = new InfoGainSplitCriterion();
    private boolean binarySplit = false;
    private AttributeClassObserver nominalClassObserver = new NominalAttributeClassObserver();
    private AttributeClassObserver numericClassObserver = new GaussianNumericAttributeClassObserver();

    Builder() {
    }

    /**
     * Creates a builder pre-populated with the full configuration of an existing processor.
     *
     * @param oldProcessor
     *          the processor whose configuration is copied
     */
    Builder(LocalStatisticsProcessor oldProcessor) {
      this.splitCriterion = oldProcessor.splitCriterion;
      this.binarySplit = oldProcessor.binarySplit;
      // Copy the observer prototypes too; otherwise new processors silently
      // fall back to the default observers. Sharing is safe because the
      // prototypes are only ever used through copy().
      this.nominalClassObserver = oldProcessor.nominalClassObserver;
      this.numericClassObserver = oldProcessor.numericClassObserver;
    }

    Builder splitCriterion(SplitCriterion splitCriterion) {
      this.splitCriterion = splitCriterion;
      return this;
    }

    Builder binarySplit(boolean binarySplit) {
      this.binarySplit = binarySplit;
      return this;
    }

    Builder nominalClassObserver(AttributeClassObserver nominalClassObserver) {
      this.nominalClassObserver = nominalClassObserver;
      return this;
    }

    Builder numericClassObserver(AttributeClassObserver numericClassObserver) {
      this.numericClassObserver = numericClassObserver;
      return this;
    }

    LocalStatisticsProcessor build() {
      return new LocalStatisticsProcessor(this);
    }
  }
}