blob: 9f2b9c20de151d93adff0a2a040f992e349d7219 [file] [log] [blame]
package org.apache.samoa.learners.classifiers.rules.distributed;
/*
* #%L
* SAMOA
* %%
* Copyright (C) 2014 - 2015 Apache Software Foundation
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
import org.apache.samoa.core.ContentEvent;
import org.apache.samoa.core.Processor;
import org.apache.samoa.instances.Instance;
import org.apache.samoa.instances.Instances;
import org.apache.samoa.learners.InstanceContentEvent;
import org.apache.samoa.learners.ResultContentEvent;
import org.apache.samoa.learners.classifiers.rules.common.ActiveRule;
import org.apache.samoa.learners.classifiers.rules.common.Perceptron;
import org.apache.samoa.learners.classifiers.rules.common.RuleActiveRegressionNode;
import org.apache.samoa.moa.classifiers.rules.core.attributeclassobservers.FIMTDDNumericAttributeClassLimitObserver;
import org.apache.samoa.topology.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Default Rule Learner Processor (HAMR).
 *
 * <p>Holds the default rule of the distributed AMRules learner. For testing instances it emits the
 * default rule's prediction on the result stream. For training instances it updates the default
 * rule and, every {@code gracePeriod} instances seen by that rule, tries to expand it. On a
 * successful expansion the expanded rule is sent on the rule stream under a new ID, and a new
 * default rule (keeping the previous default rule's ID) takes its place.
 *
 * @author Anh Thu Vu
 *
 */
public class AMRDefaultRuleProcessor implements Processor {

  private static final long serialVersionUID = 23702084591044447L;
  private static final Logger logger = LoggerFactory.getLogger(AMRDefaultRuleProcessor.class);

  /** ID received in {@link #onCreate(int)}; stamped on outgoing results as the classifier index. */
  private int processorId;

  // Default rule state: transient, (re)initialised in onCreate().
  protected transient ActiveRule defaultRule;
  /** Last rule ID handed out; new IDs are obtained by pre-incrementing this counter. */
  protected transient int ruleNumberID;
  /** Initialised in {@link #onCreate(int)}; not read elsewhere in this class. */
  protected transient double[] statistics;

  // SAMOA output streams
  private Stream ruleStream;
  private Stream resultStream;

  // Options (copied from the Builder)
  protected int pageHinckleyThreshold;
  protected double pageHinckleyAlpha;
  protected boolean driftDetection;
  protected int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2
  /** Stored and copied by the Builder; not otherwise used in this class. */
  protected boolean constantLearningRatioDecay;
  protected double learningRatio;
  protected double splitConfidence;
  protected double tieThreshold;
  /** Number of instances between expansion attempts; values {@code <= 0} disable expansion. */
  protected int gracePeriod;
  protected FIMTDDNumericAttributeClassLimitObserver numericObserver;

  /**
   * Creates a processor configured from {@code builder}. The default rule itself is only created in
   * {@link #onCreate(int)}.
   *
   * @param builder source of all learner options
   */
  public AMRDefaultRuleProcessor(Builder builder) {
    this.pageHinckleyThreshold = builder.pageHinckleyThreshold;
    this.pageHinckleyAlpha = builder.pageHinckleyAlpha;
    this.driftDetection = builder.driftDetection;
    this.predictionFunction = builder.predictionFunction;
    this.constantLearningRatioDecay = builder.constantLearningRatioDecay;
    this.learningRatio = builder.learningRatio;
    this.splitConfidence = builder.splitConfidence;
    this.tieThreshold = builder.tieThreshold;
    this.gracePeriod = builder.gracePeriod;
    this.numericObserver = builder.numericObserver;
  }

  /**
   * Handles one instance event. If it is a testing event, a prediction is emitted first; if it is
   * also a training event, the default rule is then trained on it.
   *
   * @param event the incoming event; it is cast unconditionally to {@link InstanceContentEvent}
   * @return always {@code false}
   */
  @Override
  public boolean process(ContentEvent event) {
    InstanceContentEvent instanceEvent = (InstanceContentEvent) event;
    if (instanceEvent.isTesting()) {
      this.predictOnInstance(instanceEvent);
    }
    if (instanceEvent.isTraining()) {
      this.trainOnInstance(instanceEvent);
    }
    return false;
  }

  /* ---------------------------------------------------------------- Prediction */

  /** Predicts with the default rule and puts the result on the result stream. */
  private void predictOnInstance(InstanceContentEvent instanceEvent) {
    double[] vote = defaultRule.getPrediction(instanceEvent.getInstance());
    resultStream.put(newResultContentEvent(vote, instanceEvent));
  }

  /** Wraps {@code prediction} in a result event that mirrors the identifiers of {@code inEvent}. */
  private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContentEvent inEvent) {
    ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(),
        inEvent.getClassId(), prediction, inEvent.isLastEvent());
    rce.setClassifierIndex(this.processorId);
    rce.setEvaluationIndex(inEvent.getEvaluationIndex());
    return rce;
  }

  /* ---------------------------------------------------------------- Training */

  private void trainOnInstance(InstanceContentEvent instanceEvent) {
    this.trainOnInstanceImpl(instanceEvent.getInstance());
  }

  /**
   * Updates the default rule with {@code instance} and, every {@code gracePeriod} instances seen by
   * the rule, tries to expand it.
   *
   * @param instance the training instance
   */
  public void trainOnInstanceImpl(Instance instance) {
    defaultRule.updateStatistics(instance);
    // gracePeriod is 0 when Builder.gracePeriod(...) was never called; for an integral instance
    // count "% 0" would throw ArithmeticException, so a non-positive grace period disables expansion.
    // tryToExpand is only evaluated when the grace-period check passes (short-circuit), as before.
    if (this.gracePeriod > 0
        && defaultRule.getInstancesSeen() % this.gracePeriod == 0
        && defaultRule.tryToExpand(this.splitConfidence, this.tieThreshold)) {
      expandDefaultRule();
    }
  }

  /**
   * Publishes the just-expanded default rule under a fresh ID and installs a replacement default
   * rule that keeps the previous default rule's ID.
   */
  private void expandDefaultRule() {
    RuleActiveRegressionNode expandedNode = (RuleActiveRegressionNode) defaultRule.getLearningNode();
    // Build the replacement before split() so it is seeded from the pre-split node and from the
    // statistics of the other branch of the split.
    ActiveRule newDefaultRule = newRule(defaultRule.getRuleNumberID(), expandedNode,
        expandedNode.getStatisticsOtherBranchSplit());
    defaultRule.split();
    defaultRule.setRuleNumberID(++ruleNumberID);
    // send out the new rule
    sendAddRuleEvent(defaultRule.getRuleNumberID(), this.defaultRule);
    defaultRule = newDefaultRule;
  }

  /* ---------------------------------------------------------------- Rule creation */

  /**
   * Creates a rule with the given ID, optionally seeded from an existing node.
   *
   * <p>If {@code node} has a perceptron, a copy of it (with this processor's learning ratio) is
   * installed. The target mean is reset from {@code branchStatistics} when given. Otherwise it is
   * reset from the node's own statistics. In both cases index 0 is treated as the weight/count and
   * index 1 as the sum, so the mean is {@code [1] / [0]}, applied only when {@code [0] > 0}.
   *
   * @param id rule ID
   * @param node node to seed from; may be {@code null}
   * @param branchStatistics statistics to seed the target mean from; may be {@code null}
   * @return the new rule
   */
  private ActiveRule newRule(int id, RuleActiveRegressionNode node, double[] branchStatistics) {
    ActiveRule r = newRule(id);
    if (node != null) {
      if (node.getPerceptron() != null) {
        r.getLearningNode().setPerceptron(new Perceptron(node.getPerceptron()));
        r.getLearningNode().getPerceptron().setLearningRatio(this.learningRatio);
      }
      if (branchStatistics == null && node.getNodeStatistics().getValue(0) > 0) {
        double mean = node.getNodeStatistics().getValue(1) / node.getNodeStatistics().getValue(0);
        r.getLearningNode().getTargetMean().reset(mean, 1);
      }
    }
    if (branchStatistics != null
        && ((RuleActiveRegressionNode) r.getLearningNode()).getTargetMean() != null
        && branchStatistics[0] > 0) {
      double mean = branchStatistics[1] / branchStatistics[0];
      ((RuleActiveRegressionNode) r.getLearningNode()).getTargetMean()
          .reset(mean, (long) branchStatistics[0]);
    }
    return r;
  }

  /** Creates an empty rule with the given ID, configured from this processor's options. */
  private ActiveRule newRule(int id) {
    return new ActiveRule.Builder()
        .threshold(this.pageHinckleyThreshold)
        .alpha(this.pageHinckleyAlpha)
        .changeDetection(this.driftDetection)
        .predictionFunction(this.predictionFunction)
        .statistics(new double[3])
        .learningRatio(this.learningRatio)
        .numericObserver(numericObserver)
        .id(id)
        .build();
  }

  /**
   * Records the processor ID and creates the initial default rule, which receives rule ID 1.
   *
   * @param id processor ID
   */
  @Override
  public void onCreate(int id) {
    this.processorId = id;
    this.statistics = new double[] { 0.0, 0, 0 };
    this.ruleNumberID = 0;
    this.defaultRule = newRule(++this.ruleNumberID);
  }

  /**
   * Creates a copy of {@code p} with the same options and the same output streams. Rule state is
   * not copied; it is created when {@link #onCreate(int)} is called on the copy.
   *
   * @param p the processor to copy; must be an {@code AMRDefaultRuleProcessor}
   * @return the new processor
   */
  @Override
  public Processor newProcessor(Processor p) {
    AMRDefaultRuleProcessor oldProcessor = (AMRDefaultRuleProcessor) p;
    AMRDefaultRuleProcessor newProcessor = new Builder(oldProcessor).build();
    newProcessor.resultStream = oldProcessor.resultStream;
    newProcessor.ruleStream = oldProcessor.ruleStream;
    return newProcessor;
  }

  /* ---------------------------------------------------------------- Events */

  /** Puts a {@link RuleContentEvent} carrying {@code rule} on the rule stream. */
  private void sendAddRuleEvent(int ruleID, ActiveRule rule) {
    RuleContentEvent rce = new RuleContentEvent(ruleID, rule, false);
    this.ruleStream.put(rce);
  }

  /* ---------------------------------------------------------------- Output streams */

  /** @param ruleStream stream on which expanded rules are published */
  public void setRuleStream(Stream ruleStream) {
    this.ruleStream = ruleStream;
  }

  /** @return stream on which expanded rules are published */
  public Stream getRuleStream() {
    return this.ruleStream;
  }

  /** @param resultStream stream on which predictions are published */
  public void setResultStream(Stream resultStream) {
    this.resultStream = resultStream;
  }

  /** @return stream on which predictions are published */
  public Stream getResultStream() {
    return this.resultStream;
  }

  /**
   * Fluent builder for {@link AMRDefaultRuleProcessor}. Options not set explicitly keep Java's
   * default field values (0, 0.0, {@code false}, {@code null}).
   */
  public static class Builder {
    private int pageHinckleyThreshold;
    private double pageHinckleyAlpha;
    private boolean driftDetection;
    private int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2
    private boolean constantLearningRatioDecay;
    private double learningRatio;
    private double splitConfidence;
    private double tieThreshold;
    private int gracePeriod;
    private FIMTDDNumericAttributeClassLimitObserver numericObserver;
    /** Stored but not passed to the processor. */
    private Instances dataset;

    /** @param dataset dataset header; kept by the builder only */
    public Builder(Instances dataset) {
      this.dataset = dataset;
    }

    /**
     * Copies all options from an existing processor. The dataset is not copied because the
     * processor does not hold one.
     *
     * @param processor processor to copy options from
     */
    public Builder(AMRDefaultRuleProcessor processor) {
      this.pageHinckleyThreshold = processor.pageHinckleyThreshold;
      this.pageHinckleyAlpha = processor.pageHinckleyAlpha;
      this.driftDetection = processor.driftDetection;
      this.predictionFunction = processor.predictionFunction;
      this.constantLearningRatioDecay = processor.constantLearningRatioDecay;
      this.learningRatio = processor.learningRatio;
      this.splitConfidence = processor.splitConfidence;
      this.tieThreshold = processor.tieThreshold;
      this.gracePeriod = processor.gracePeriod;
      this.numericObserver = processor.numericObserver;
    }

    /** Sets the Page-Hinckley threshold passed to every created rule. */
    public Builder threshold(int threshold) {
      this.pageHinckleyThreshold = threshold;
      return this;
    }

    /** Sets the Page-Hinckley alpha passed to every created rule. */
    public Builder alpha(double alpha) {
      this.pageHinckleyAlpha = alpha;
      return this;
    }

    /** Enables or disables change detection in created rules. */
    public Builder changeDetection(boolean changeDetection) {
      this.driftDetection = changeDetection;
      return this;
    }

    /** Sets the prediction function (Adaptive=0, Perceptron=1, TargetMean=2). */
    public Builder predictionFunction(int predictionFunction) {
      this.predictionFunction = predictionFunction;
      return this;
    }

    /** Sets the constant-learning-ratio-decay flag. */
    public Builder constantLearningRatioDecay(boolean constantDecay) {
      this.constantLearningRatioDecay = constantDecay;
      return this;
    }

    /** Sets the learning ratio used by created rules and copied perceptrons. */
    public Builder learningRatio(double learningRatio) {
      this.learningRatio = learningRatio;
      return this;
    }

    /** Sets the split confidence passed to {@code tryToExpand}. */
    public Builder splitConfidence(double splitConfidence) {
      this.splitConfidence = splitConfidence;
      return this;
    }

    /** Sets the tie threshold passed to {@code tryToExpand}. */
    public Builder tieThreshold(double tieThreshold) {
      this.tieThreshold = tieThreshold;
      return this;
    }

    /** Sets the number of instances between expansion attempts; {@code <= 0} disables expansion. */
    public Builder gracePeriod(int gracePeriod) {
      this.gracePeriod = gracePeriod;
      return this;
    }

    /** Sets the numeric attribute observer passed to every created rule. */
    public Builder numericObserver(FIMTDDNumericAttributeClassLimitObserver numericObserver) {
      this.numericObserver = numericObserver;
      return this;
    }

    /** @return a new processor configured with this builder's options */
    public AMRDefaultRuleProcessor build() {
      return new AMRDefaultRuleProcessor(this);
    }
  }
}