// blob: fde2300511bfaa18fbfbe15108f0164248e38600
package org.apache.samoa.learners.classifiers.trees;
/*
* #%L
* SAMOA
* %%
* Copyright (C) 2014 - 2015 Apache Software Foundation
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import org.apache.samoa.core.ContentEvent;
import org.apache.samoa.core.Processor;
import org.apache.samoa.moa.classifiers.core.AttributeSplitSuggestion;
import org.apache.samoa.moa.classifiers.core.attributeclassobservers.AttributeClassObserver;
import org.apache.samoa.moa.classifiers.core.attributeclassobservers.GaussianNumericAttributeClassObserver;
import org.apache.samoa.moa.classifiers.core.attributeclassobservers.NominalAttributeClassObserver;
import org.apache.samoa.moa.classifiers.core.splitcriteria.InfoGainSplitCriterion;
import org.apache.samoa.moa.classifiers.core.splitcriteria.SplitCriterion;
import org.apache.samoa.topology.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
/**
 * Local Statistic Processor contains the local statistics of a subset of the attributes.
 *
 * <p>
 * The statistics are kept as one {@link AttributeClassObserver} per (learning node id, attribute index) pair.
 * {@link #process(ContentEvent)} reacts to four kinds of events:
 * <ul>
 * <li>{@link AttributeBatchContentEvent}: updates the observers with every {@link AttributeContentEvent} in the
 * batch;</li>
 * <li>{@link AttributeSliceEvent}: updates the observers of a contiguous slice of attributes;</li>
 * <li>{@link ComputeContentEvent}: evaluates the split suggestions of a learning node and puts the best and
 * second-best ones on the computation result stream;</li>
 * <li>{@link DeleteContentEvent}: drops every observer stored for a learning node.</li>
 * </ul>
 *
 * @author Arinto Murdopo
 *
 */
public final class LocalStatisticsProcessor implements Processor {

  private static final long serialVersionUID = -3967695130634517631L;

  private static final Logger logger = LoggerFactory.getLogger(LocalStatisticsProcessor.class);

  // Collection of AttributeObservers, for each ActiveLearningNode and
  // AttributeId. Created in onCreate(int), so it is null until then.
  private Table<Long, Integer, AttributeClassObserver> localStats;

  // Stream on which the LocalResultContentEvents are put.
  private Stream computationResultStream;

  private final SplitCriterion splitCriterion;
  private final boolean binarySplit;

  // the two observer classes below are also needed to be setup from the Tree
  private final AttributeClassObserver nominalClassObserver;
  private final AttributeClassObserver numericClassObserver;

  // Identifier received in onCreate(int); currently only stored.
  private int id;

  private LocalStatisticsProcessor(Builder builder) {
    this.splitCriterion = builder.splitCriterion;
    this.binarySplit = builder.binarySplit;
    this.nominalClassObserver = builder.nominalClassObserver;
    this.numericClassObserver = builder.numericClassObserver;
  }

  /**
   * Dispatches the event according to its type: attribute events update the local statistics, a compute event
   * triggers the evaluation of the split suggestions and a delete event removes the statistics of a learning node.
   * Events of any other type are ignored.
   *
   * @param event
   *          the event to process
   * @return always {@code true}
   */
  @Override
  public boolean process(ContentEvent event) {
    if (event instanceof AttributeBatchContentEvent) {
      // process each AttributeContentEvent by updating the subset of local statistics
      AttributeBatchContentEvent abce = (AttributeBatchContentEvent) event;
      List<ContentEvent> contentEventList = abce.getContentEventList();
      for (ContentEvent contentEvent : contentEventList) {
        AttributeContentEvent ace = (AttributeContentEvent) contentEvent;
        Long learningNodeId = ace.getLearningNodeId();
        Integer obsIndex = ace.getObsIndex();
        AttributeClassObserver obs = localStats.get(learningNodeId, obsIndex);
        if (obs == null) {
          obs = createObserver(learningNodeId, obsIndex, ace.isNominal());
        }
        obs.observeAttributeClass(ace.getAttrVal(), ace.getClassVal(), ace.getWeight());
      }
    } else if (event instanceof AttributeSliceEvent) {
      processAttributeSlice((AttributeSliceEvent) event);
    } else if (event instanceof ComputeContentEvent) {
      processComputeEvent((ComputeContentEvent) event);
    } else if (event instanceof DeleteContentEvent) {
      DeleteContentEvent dce = (DeleteContentEvent) event;
      Long learningNodeId = dce.getLearningNodeId();
      // Removing the row discards all the observers of that learning node at once.
      localStats.rowMap().remove(learningNodeId);
    }
    return true;
  }

  /**
   * Evaluates the best split suggestion of every observer stored for the learning node of the event and puts a
   * {@link LocalResultContentEvent} carrying the best and the second-best suggestions on the computation result
   * stream. The best suggestion is {@code null} when the node has no observer and the second-best one is
   * {@code null} when the node has fewer than two observers.
   *
   * @param cce
   *          the compute event, providing the learning node id, the pre-split distribution, the split id and the
   *          ensemble id
   */
  private void processComputeEvent(ComputeContentEvent cce) {
    Long learningNodeId = cce.getLearningNodeId();
    double[] preSplitDist = cce.getPreSplitDist();

    Map<Integer, AttributeClassObserver> learningNodeRowMap = localStats.row(learningNodeId);
    AttributeSplitSuggestion[] suggestions = new AttributeSplitSuggestion[learningNodeRowMap.size()];

    int curIndex = 0;
    for (Entry<Integer, AttributeClassObserver> entry : learningNodeRowMap.entrySet()) {
      AttributeClassObserver obs = entry.getValue();
      AttributeSplitSuggestion suggestion = obs.getBestEvaluatedSplitSuggestion(splitCriterion, preSplitDist,
          entry.getKey(), binarySplit);
      if (suggestion == null) {
        // Keep the array free of nulls so that it can be sorted.
        suggestion = new AttributeSplitSuggestion();
      }
      suggestions[curIndex] = suggestion;
      curIndex++;
    }

    // Doing this sort instead of keeping the max and second max seems faster for some reason
    Arrays.sort(suggestions);

    // The sort is ascending, hence the best suggestions are at the end of the array.
    AttributeSplitSuggestion bestSuggestion = null;
    AttributeSplitSuggestion secondBestSuggestion = null;
    if (suggestions.length >= 1) {
      bestSuggestion = suggestions[suggestions.length - 1];
      if (suggestions.length >= 2) {
        secondBestSuggestion = suggestions[suggestions.length - 2];
      }
    }

    // create the local result content event
    LocalResultContentEvent lcre = new LocalResultContentEvent(cce.getSplitId(), bestSuggestion,
        secondBestSuggestion);
    lcre.setEnsembleId(cce.getEnsembleId());
    computationResultStream.put(lcre);
  }

  /**
   * Updates the observers of a slice of attributes with one observation. The i-th value of the slice belongs to the
   * attribute whose index is {@code i + ase.getAttributeStartingIndex()}.
   *
   * @param ase
   *          the slice event, providing the attribute values, their nominal flags, the starting attribute index,
   *          the learning node id, the class value and the weight
   */
  private void processAttributeSlice(AttributeSliceEvent ase) {
    double[] attributeSlice = ase.getAttributeSlice();
    boolean[] isNominal = ase.getIsNominalSlice();
    int startingIndex = ase.getAttributeStartingIndex();
    Long learningNodeId = ase.getLearningNodeId();
    int classValue = ase.getClassValue();
    double weight = ase.getWeight();

    for (int i = 0; i < attributeSlice.length; i++) {
      Integer obsIndex = i + startingIndex;
      AttributeClassObserver obs = localStats.get(learningNodeId, obsIndex);
      if (obs == null) {
        obs = createObserver(learningNodeId, obsIndex, isNominal[i]);
      }
      obs.observeAttributeClass(attributeSlice[i], classValue, weight);
    }
  }

  /**
   * Creates a new observer and registers it in the local statistics under the given learning node id and attribute
   * index. Callers invoke it only when no observer is stored yet for that pair.
   *
   * @param learningNodeId
   *          the id of the learning node (row key)
   * @param obsIndex
   *          the index of the attribute (column key)
   * @param nominal
   *          {@code true} to create a nominal observer, {@code false} to create a numeric one
   * @return the newly registered observer
   */
  private AttributeClassObserver createObserver(Long learningNodeId, Integer obsIndex, boolean nominal) {
    AttributeClassObserver obs = nominal ? newNominalClassObserver() : newNumericClassObserver();
    localStats.put(learningNodeId, obsIndex, obs);
    return obs;
  }

  /**
   * Stores the given id and creates the (empty) table of local statistics.
   *
   * @param id
   *          the id assigned to this processor
   */
  @Override
  public void onCreate(int id) {
    this.id = id;
    this.localStats = HashBasedTable.create();
  }

  /**
   * Creates a new processor with the same split criterion, binary split flag and observers as the given one, and
   * with the same computation result stream. The local statistics are not copied.
   *
   * @param p
   *          the processor to replicate; it is cast to {@link LocalStatisticsProcessor}
   * @return the new processor
   */
  @Override
  public Processor newProcessor(Processor p) {
    LocalStatisticsProcessor oldProcessor = (LocalStatisticsProcessor) p;
    LocalStatisticsProcessor newProcessor = new LocalStatisticsProcessor.Builder(oldProcessor).build();
    newProcessor.setComputationResultStream(oldProcessor.getComputationResultStream());
    return newProcessor;
  }

  /**
   * Method to set the computation result when using this processor to build a topology.
   *
   * @param computeStream
   *          the stream on which the local results are put
   */
  public void setComputationResultStream(Stream computeStream) {
    this.computationResultStream = computeStream;
  }

  // NOTE(review): the two factory methods below always create a default observer; the observers configured through
  // the Builder are not consulted here.
  private AttributeClassObserver newNominalClassObserver() {
    return new NominalAttributeClassObserver(); // further investigate this change
  }

  private AttributeClassObserver newNumericClassObserver() {
    return new GaussianNumericAttributeClassObserver(); // further investigate this change
  }

  /**
   * Builder class to replace constructors with many parameters. The defaults are an
   * {@link InfoGainSplitCriterion}, no binary split, a {@link NominalAttributeClassObserver} and a
   * {@link GaussianNumericAttributeClassObserver}.
   *
   * @author Arinto Murdopo
   *
   */
  public static class Builder {

    private SplitCriterion splitCriterion = new InfoGainSplitCriterion();
    private boolean binarySplit = false;
    private AttributeClassObserver nominalClassObserver = new NominalAttributeClassObserver();
    private AttributeClassObserver numericClassObserver = new GaussianNumericAttributeClassObserver();

    /**
     * Creates a builder holding the default configuration.
     */
    public Builder() {
    }

    /**
     * Creates a builder holding the configuration of an existing processor.
     *
     * @param oldProcessor
     *          the processor whose split criterion, binary split flag and observers are reused
     */
    public Builder(LocalStatisticsProcessor oldProcessor) {
      this.splitCriterion = oldProcessor.getSplitCriterion();
      this.binarySplit = oldProcessor.isBinarySplit();
      this.nominalClassObserver = oldProcessor.getNominalClassObserver();
      this.numericClassObserver = oldProcessor.getNumericClassObserver();
    }

    /**
     * @param splitCriterion
     *          the criterion used to evaluate the split suggestions
     * @return this builder
     */
    public Builder splitCriterion(SplitCriterion splitCriterion) {
      this.splitCriterion = splitCriterion;
      return this;
    }

    /**
     * @param binarySplit
     *          the binary split flag passed to the observers when evaluating the split suggestions
     * @return this builder
     */
    public Builder binarySplit(boolean binarySplit) {
      this.binarySplit = binarySplit;
      return this;
    }

    /**
     * @param nominalClassObserver
     *          the observer for the nominal attributes
     * @return this builder
     */
    public Builder nominalClassObserver(AttributeClassObserver nominalClassObserver) {
      this.nominalClassObserver = nominalClassObserver;
      return this;
    }

    /**
     * @param numericClassObserver
     *          the observer for the numeric attributes
     * @return this builder
     */
    public Builder numericClassObserver(AttributeClassObserver numericClassObserver) {
      this.numericClassObserver = numericClassObserver;
      return this;
    }

    /**
     * @return a new processor holding the configuration of this builder
     */
    public LocalStatisticsProcessor build() {
      return new LocalStatisticsProcessor(this);
    }
  }

  /**
   * @return the criterion used to evaluate the split suggestions
   */
  public SplitCriterion getSplitCriterion() {
    return splitCriterion;
  }

  /**
   * @return the binary split flag passed to the observers when evaluating the split suggestions
   */
  public boolean isBinarySplit() {
    return binarySplit;
  }

  /**
   * @return the configured observer for the nominal attributes
   */
  public AttributeClassObserver getNominalClassObserver() {
    return nominalClassObserver;
  }

  /**
   * @return the configured observer for the numeric attributes
   */
  public AttributeClassObserver getNumericClassObserver() {
    return numericClassObserver;
  }

  /**
   * @return the stream on which the local results are put
   */
  public Stream getComputationResultStream() {
    return computationResultStream;
  }
}