// blob: 2131db49ef2ba6fbaf3a05862e10c0a07239deaf [file] [log] [blame]
package org.apache.samoa.learners.classifiers.rules.distributed;
/*
* #%L
* SAMOA
* %%
* Copyright (C) 2014 - 2015 Apache Software Foundation
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import org.apache.samoa.core.ContentEvent;
import org.apache.samoa.core.Processor;
import org.apache.samoa.instances.Instance;
import org.apache.samoa.instances.Instances;
import org.apache.samoa.learners.InstanceContentEvent;
import org.apache.samoa.learners.ResultContentEvent;
import org.apache.samoa.learners.classifiers.rules.common.ActiveRule;
import org.apache.samoa.learners.classifiers.rules.common.LearningRule;
import org.apache.samoa.learners.classifiers.rules.common.PassiveRule;
import org.apache.samoa.learners.classifiers.rules.common.Perceptron;
import org.apache.samoa.learners.classifiers.rules.common.RuleActiveRegressionNode;
import org.apache.samoa.moa.classifiers.rules.core.attributeclassobservers.FIMTDDNumericAttributeClassLimitObserver;
import org.apache.samoa.moa.classifiers.rules.core.voting.ErrorWeightedVote;
import org.apache.samoa.moa.classifiers.rules.core.voting.InverseErrorWeightedVote;
import org.apache.samoa.moa.classifiers.rules.core.voting.UniformWeightedVote;
import org.apache.samoa.topology.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Model Aggregator Processor (VAMR).
*
* @author Anh Thu Vu
*
*/
public class AMRulesAggregatorProcessor implements Processor {
/** Serialization version identifier. */
private static final long serialVersionUID = 6303385725332704251L;
// Class logger; no method in this file currently writes to it.
private static final Logger logger =
LoggerFactory.getLogger(AMRulesAggregatorProcessor.class);
// Identifier received in onCreate(int); stamped on outgoing results as the classifier index.
private int processorId;
// Rules & default rule. All four fields below are transient and (re)initialized in onCreate(int).
// Rules that were split off from the default rule, in the order they were added.
protected transient List<PassiveRule> ruleSet;
// Rule used to predict for, and learn from, instances the rule set did not handle.
protected transient ActiveRule defaultRule;
// Most recently assigned rule ID; pre-incremented each time a new ID is handed out.
protected transient int ruleNumberID;
// Set to three zeros in onCreate(int); not read anywhere else in this file.
protected transient double[] statistics;
// SAMOA Stream
// Output stream for AssignmentContentEvent and RuleContentEvent messages (the statistics PIs).
private Stream statisticsStream;
// Output stream for ResultContentEvent predictions.
private Stream resultStream;
// Options (copied from the Builder in the constructor)
// Page-Hinckley threshold, alpha and on/off switch handed to every ActiveRule created here.
protected int pageHinckleyThreshold;
protected double pageHinckleyAlpha;
protected boolean driftDetection;
protected int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2
// Stored and copied between processor and builder, but not otherwise read in this file.
protected boolean constantLearningRatioDecay;
// Learning ratio given to new rules and to perceptrons copied into them.
protected double learningRatio;
// Arguments passed to defaultRule.tryToExpand(...).
protected double splitConfidence;
protected double tieThreshold;
// The default rule attempts an expansion whenever its instance count is a multiple of this value.
protected int gracePeriod;
// When true, isAnomaly(...) always answers false (anomaly detection disabled).
protected boolean noAnomalyDetection;
// Thresholds forwarded to LearningRule.isAnomaly(...).
protected double multivariateAnomalyProbabilityThreshold;
protected double univariateAnomalyprobabilityThreshold;
// Minimum number of instances a rule must have seen before the anomaly check is run.
protected int anomalyNumInstThreshold;
// When false, only the first covering rule predicts and only the first non-anomalous one trains.
protected boolean unorderedRules;
// Numeric attribute observer handed to every ActiveRule created here.
protected FIMTDDNumericAttributeClassLimitObserver numericObserver;
// 1 selects UniformWeightedVote; any other value selects InverseErrorWeightedVote.
protected int voteType;
/*
* Constructor
*/
/**
 * Creates a processor whose options are copied one-to-one from the given builder.
 * Output streams and the transient rule state are not set here; see the stream
 * setters and {@code onCreate(int)}.
 *
 * @param builder
 *          source of all option values
 */
public AMRulesAggregatorProcessor(Builder builder) {
  // Drift detection
  pageHinckleyThreshold = builder.pageHinckleyThreshold;
  pageHinckleyAlpha = builder.pageHinckleyAlpha;
  driftDetection = builder.driftDetection;

  // Prediction and learning
  predictionFunction = builder.predictionFunction;
  constantLearningRatioDecay = builder.constantLearningRatioDecay;
  learningRatio = builder.learningRatio;

  // Rule expansion
  splitConfidence = builder.splitConfidence;
  tieThreshold = builder.tieThreshold;
  gracePeriod = builder.gracePeriod;

  // Anomaly detection
  noAnomalyDetection = builder.noAnomalyDetection;
  multivariateAnomalyProbabilityThreshold = builder.multivariateAnomalyProbabilityThreshold;
  univariateAnomalyprobabilityThreshold = builder.univariateAnomalyprobabilityThreshold;
  anomalyNumInstThreshold = builder.anomalyNumInstThreshold;

  // Rule set behaviour
  unorderedRules = builder.unorderedRules;
  numericObserver = builder.numericObserver;
  voteType = builder.voteType;
}
/*
* Process
*/
/**
 * Dispatches an incoming event by type: instance events are predicted on and/or
 * trained with, predicate events update the matching rule, and rule events are
 * honoured only when they request a removal. Any other event is ignored.
 *
 * @param event
 *          the incoming event
 * @return always {@code true}
 */
@Override
public boolean process(ContentEvent event) {
  if (event instanceof InstanceContentEvent) {
    processInstanceEvent((InstanceContentEvent) event);
    return true;
  }

  if (event instanceof PredicateContentEvent) {
    updateRuleSplitNode((PredicateContentEvent) event);
    return true;
  }

  if (event instanceof RuleContentEvent) {
    RuleContentEvent ruleEvent = (RuleContentEvent) event;
    // Only removals are handled here; non-removing rule events are dropped.
    if (ruleEvent.isRemoving()) {
      removeRule(ruleEvent.getRuleNumberID());
    }
  }
  return true;
}
/**
 * Handles one instance event. Prediction and training are merged into a single
 * pass over the rule set so that rule coverage is checked only once per rule.
 *
 * <p>Prediction (testing instances only): every covering rule adds a vote weighted
 * by its current error; with ordered rules only the first covering rule votes. If
 * no rule voted, the default rule predicts. One {@link ResultContentEvent} is put
 * on the result stream per testing instance, and none for a training-only instance.
 *
 * <p>Training (training instances only): every covering rule that does not flag
 * the instance as an anomaly updates its statistics and the instance is forwarded
 * to the statistics stream; with ordered rules only the first such rule trains.
 * If no rule trained, the default rule is trained instead.
 *
 * @param instanceEvent
 *          the event carrying the instance and its testing/training flags
 */
private void processInstanceEvent(InstanceContentEvent instanceEvent) {
  Instance instance = instanceEvent.getInstance();
  boolean predictionCovered = false;
  boolean trainingCovered = false;
  boolean continuePrediction = instanceEvent.isTesting();
  boolean continueTraining = instanceEvent.isTraining();

  ErrorWeightedVote errorWeightedVote = newErrorWeightedVote();
  Iterator<PassiveRule> ruleIterator = this.ruleSet.iterator();
  // Stop as soon as neither prediction nor training needs another rule.
  while ((continuePrediction || continueTraining) && ruleIterator.hasNext()) {
    PassiveRule rule = ruleIterator.next();
    if (!rule.isCovering(instance)) {
      continue;
    }

    if (continuePrediction) {
      // Mark the prediction as covered only when a vote is actually collected;
      // otherwise a training-only instance would emit a result built from zero votes.
      predictionCovered = true;
      double[] vote = rule.getPrediction(instance);
      double error = rule.getCurrentError();
      errorWeightedVote.addVote(vote, error);
      if (!this.unorderedRules) {
        continuePrediction = false;
      }
    }

    if (continueTraining && !isAnomaly(instance, rule)) {
      trainingCovered = true;
      rule.updateStatistics(instance);
      // Send instance to statistics PIs
      sendInstanceToRule(instance, rule.getRuleNumberID());
      if (!this.unorderedRules) {
        continueTraining = false;
      }
    }
  }

  if (predictionCovered) {
    // Combined prediction of all rules that voted
    resultStream.put(newResultContentEvent(errorWeightedVote.computeWeightedVote(), instanceEvent));
  } else if (instanceEvent.isTesting()) {
    // No rule voted: predict with the default rule
    resultStream.put(newResultContentEvent(defaultRule.getPrediction(instance), instanceEvent));
  }

  if (!trainingCovered && instanceEvent.isTraining()) {
    trainDefaultRule(instance);
  }
}

/**
 * Trains the default rule with an instance that no rule in the set learned from
 * and, whenever its instance count is a multiple of the grace period, tries to
 * expand it into a new rule.
 */
private void trainDefaultRule(Instance instance) {
  defaultRule.updateStatistics(instance);
  if (defaultRule.getInstancesSeen() % this.gracePeriod != 0) {
    return;
  }
  if (!defaultRule.tryToExpand(this.splitConfidence, this.tieThreshold)) {
    return;
  }

  // Build the replacement default rule BEFORE split(): it keeps the current rule ID
  // and is seeded from the current learning node and its other-branch statistics.
  ActiveRule newDefaultRule = newRule(defaultRule.getRuleNumberID(),
      (RuleActiveRegressionNode) defaultRule.getLearningNode(),
      ((RuleActiveRegressionNode) defaultRule.getLearningNode()).getStatisticsOtherBranchSplit()); // other branch

  // The expanded rule gets a fresh ID and joins the rule set as a passive copy.
  defaultRule.split();
  defaultRule.setRuleNumberID(++ruleNumberID);
  this.ruleSet.add(new PassiveRule(this.defaultRule));
  // send to statistics PI
  sendAddRuleEvent(defaultRule.getRuleNumberID(), this.defaultRule);
  defaultRule = newDefaultRule;
}
/**
 * Builds the result event for a prediction made on the given instance event.
 * Instance index, instance, class id, last-event flag and evaluation index are
 * copied from the input event; the classifier index is set to this processor's id.
 *
 * @param prediction
 *          the vote produced by the rules (or by the default rule) for the instance
 * @param inEvent
 *          the instance content event the prediction belongs to
 * @return the ResultContentEvent to put on the result stream
 */
private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContentEvent inEvent) {
  ResultContentEvent result = new ResultContentEvent(
      inEvent.getInstanceIndex(),
      inEvent.getInstance(),
      inEvent.getClassId(),
      prediction,
      inEvent.isLastEvent());
  result.setClassifierIndex(processorId);
  result.setEvaluationIndex(inEvent.getEvaluationIndex());
  return result;
}
/**
 * Creates an empty vote aggregator according to {@code voteType}.
 *
 * @return a {@link UniformWeightedVote} when {@code voteType} is 1, otherwise an
 *         {@link InverseErrorWeightedVote}
 */
public ErrorWeightedVote newErrorWeightedVote() {
  final ErrorWeightedVote aggregator;
  if (voteType != 1) {
    aggregator = new InverseErrorWeightedVote();
  } else {
    aggregator = new UniformWeightedVote();
  }
  return aggregator;
}
/**
 * Tells whether the instance is an anomaly with respect to the given rule.
 * The rule is consulted only when anomaly detection is enabled and the rule has
 * seen at least {@code anomalyNumInstThreshold} instances; otherwise the answer
 * is {@code false}.
 *
 * @param instance
 *          the instance to check
 * @param rule
 *          the rule whose anomaly test is applied
 * @return {@code true} if the rule's anomaly test was run and flagged the instance
 */
private boolean isAnomaly(Instance instance, LearningRule rule) {
  // Short-circuit order matters: the rule is only queried when detection is on
  // and it has seen enough instances.
  return !this.noAnomalyDetection
      && rule.getInstancesSeen() >= this.anomalyNumInstThreshold
      && rule.isAnomaly(instance,
          this.univariateAnomalyprobabilityThreshold,
          this.multivariateAnomalyProbabilityThreshold,
          this.anomalyNumInstThreshold);
}
/*
* Create new rules
*/
/**
 * Creates a rule with the given ID and seeds it from an existing learning node
 * and/or a statistics array.
 *
 * <p>If {@code node} is given, its perceptron (when present) is copied into the new
 * rule and set to this processor's learning ratio. The target mean of the new rule
 * is reset from {@code statistics} when that array is given, otherwise from the
 * node statistics; in both cases element 1 divided by element 0 is the mean, and
 * the reset is skipped unless element 0 is positive.
 *
 * @param ID
 *          ID of the new rule
 * @param node
 *          learning node to copy from, may be {@code null}
 * @param statistics
 *          statistics used for the target mean, may be {@code null}
 * @return the new rule
 */
private ActiveRule newRule(int ID, RuleActiveRegressionNode node, double[] statistics) {
  ActiveRule created = newRule(ID);

  if (node != null) {
    if (node.getPerceptron() != null) {
      created.getLearningNode().setPerceptron(new Perceptron(node.getPerceptron()));
      created.getLearningNode().getPerceptron().setLearningRatio(this.learningRatio);
    }
    // Without explicit statistics, fall back to the node's own statistics.
    if (statistics == null && node.getNodeStatistics().getValue(0) > 0) {
      final double nodeMean = node.getNodeStatistics().getValue(1) / node.getNodeStatistics().getValue(0);
      created.getLearningNode().getTargetMean().reset(nodeMean, 1);
    }
  }

  if (statistics != null
      && ((RuleActiveRegressionNode) created.getLearningNode()).getTargetMean() != null
      && statistics[0] > 0) {
    final double statsMean = statistics[1] / statistics[0];
    ((RuleActiveRegressionNode) created.getLearningNode()).getTargetMean().reset(statsMean, (long) statistics[0]);
  }

  return created;
}
/**
 * Creates an empty rule with the given ID, configured with this processor's
 * drift-detection, prediction and observer options and a fresh three-element
 * statistics array.
 *
 * @param ID
 *          ID of the new rule
 * @return the new rule
 */
private ActiveRule newRule(int ID) {
  return new ActiveRule.Builder()
      .threshold(this.pageHinckleyThreshold)
      .alpha(this.pageHinckleyAlpha)
      .changeDetection(this.driftDetection)
      .predictionFunction(this.predictionFunction)
      .statistics(new double[3])
      .learningRatio(this.learningRatio)
      .numericObserver(numericObserver)
      .id(ID)
      .build();
}
/*
* Add predicate/RuleSplitNode for a rule
*/
/**
 * Applies a predicate event to every rule in the set whose ID matches the event:
 * the event's split node (if any) is appended to the rule and the event's
 * learning node (if any) replaces the rule's learning node.
 *
 * @param pce
 *          the predicate event to apply
 */
private void updateRuleSplitNode(PredicateContentEvent pce) {
  final int targetId = pce.getRuleNumberID();
  for (PassiveRule candidate : ruleSet) {
    if (candidate.getRuleNumberID() != targetId) {
      continue;
    }
    if (pce.getRuleSplitNode() != null) {
      candidate.nodeListAdd(pce.getRuleSplitNode());
    }
    if (pce.getLearningNode() != null) {
      candidate.setLearningNode(pce.getLearningNode());
    }
  }
}
/*
* Remove rule
*/
/**
 * Removes the first rule in the set whose ID equals {@code ruleID}; does nothing
 * if no rule matches.
 *
 * <p>Removal goes through the iterator so that exactly the matched element is
 * unlinked, without a second equality-based scan of the list and without
 * modifying the list behind an active iteration.
 *
 * @param ruleID
 *          ID of the rule to remove
 */
private void removeRule(int ruleID) {
  Iterator<PassiveRule> rules = ruleSet.iterator();
  while (rules.hasNext()) {
    if (rules.next().getRuleNumberID() == ruleID) {
      rules.remove();
      return;
    }
  }
}
/**
 * Initializes the transient state of this processor: stores the processor id,
 * starts with an empty rule set and zeroed statistics, and creates the default
 * rule with rule ID 1.
 *
 * @param id
 *          identifier of this processor instance
 */
@Override
public void onCreate(int id) {
  processorId = id;
  ruleSet = new LinkedList<PassiveRule>();
  statistics = new double[3];
  // The default rule takes the first ID; later rules continue counting from it.
  ruleNumberID = 1;
  defaultRule = newRule(ruleNumberID);
}
/*
* Clone processor
*/
/**
 * Creates a new processor with the same options as the given one and sharing its
 * result and statistics streams. Rule state is not copied.
 *
 * @param p
 *          the processor to copy; must be an {@code AMRulesAggregatorProcessor}
 * @return the new processor
 */
@Override
public Processor newProcessor(Processor p) {
  AMRulesAggregatorProcessor source = (AMRulesAggregatorProcessor) p;
  AMRulesAggregatorProcessor copy = new Builder(source).build();
  copy.resultStream = source.resultStream;
  copy.statisticsStream = source.statisticsStream;
  return copy;
}
/*
* Send events
*/
/**
 * Puts an {@code AssignmentContentEvent} pairing the instance with the given rule
 * ID on the statistics stream.
 */
private void sendInstanceToRule(Instance instance, int ruleID) {
  statisticsStream.put(new AssignmentContentEvent(ruleID, instance));
}
/**
 * Puts a non-removing {@code RuleContentEvent} for the given rule on the
 * statistics stream.
 */
private void sendAddRuleEvent(int ruleID, ActiveRule rule) {
  statisticsStream.put(new RuleContentEvent(ruleID, rule, false));
}
/*
* Output streams
*/
/**
 * Sets the stream on which instance assignments and new-rule events are sent.
 *
 * @param statisticsStream
 *          the output stream to use
 */
public void setStatisticsStream(Stream statisticsStream) {
this.statisticsStream = statisticsStream;
}
/**
 * @return the stream on which instance assignments and new-rule events are sent
 */
public Stream getStatisticsStream() {
  return statisticsStream;
}
/**
 * Sets the stream on which prediction results are sent.
 *
 * @param resultStream
 *          the output stream to use
 */
public void setResultStream(Stream resultStream) {
this.resultStream = resultStream;
}
/**
 * @return the stream on which prediction results are sent
 */
public Stream getResultStream() {
  return resultStream;
}
/*
* Others
*/
/**
 * Always returns {@code true}.
 * NOTE(review): nothing in this file calls this method; presumably it is queried
 * from outside the class — confirm the caller and the intended meaning.
 */
public boolean isRandomizable() {
return true;
}
/*
* Builder
*/
/**
 * Fluent builder for {@link AMRulesAggregatorProcessor}. Every option setter
 * stores its argument and returns this builder; options that are never set keep
 * their Java default value (0, 0.0, false or null).
 */
public static class Builder {
  // Drift detection
  private int pageHinckleyThreshold;
  private double pageHinckleyAlpha;
  private boolean driftDetection;

  // Prediction and learning
  private int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2
  private boolean constantLearningRatioDecay;
  private double learningRatio;

  // Rule expansion
  private double splitConfidence;
  private double tieThreshold;
  private int gracePeriod;

  // Anomaly detection
  private boolean noAnomalyDetection;
  private double multivariateAnomalyProbabilityThreshold;
  private double univariateAnomalyprobabilityThreshold;
  private int anomalyNumInstThreshold;

  // Rule set behaviour
  private boolean unorderedRules;
  private FIMTDDNumericAttributeClassLimitObserver numericObserver;
  private int voteType;

  // Kept from the constructor argument; not read anywhere in this file.
  private Instances dataset;

  /**
   * Starts a builder with all options at their default values.
   *
   * @param dataset
   *          dataset reference stored in the builder
   */
  public Builder(Instances dataset) {
    this.dataset = dataset;
  }

  /**
   * Starts a builder pre-filled with the options of an existing processor.
   * The dataset reference is left unset.
   *
   * @param processor
   *          processor whose options are copied
   */
  public Builder(AMRulesAggregatorProcessor processor) {
    pageHinckleyThreshold = processor.pageHinckleyThreshold;
    pageHinckleyAlpha = processor.pageHinckleyAlpha;
    driftDetection = processor.driftDetection;
    predictionFunction = processor.predictionFunction;
    constantLearningRatioDecay = processor.constantLearningRatioDecay;
    learningRatio = processor.learningRatio;
    splitConfidence = processor.splitConfidence;
    tieThreshold = processor.tieThreshold;
    gracePeriod = processor.gracePeriod;
    noAnomalyDetection = processor.noAnomalyDetection;
    multivariateAnomalyProbabilityThreshold = processor.multivariateAnomalyProbabilityThreshold;
    univariateAnomalyprobabilityThreshold = processor.univariateAnomalyprobabilityThreshold;
    anomalyNumInstThreshold = processor.anomalyNumInstThreshold;
    unorderedRules = processor.unorderedRules;
    numericObserver = processor.numericObserver;
    voteType = processor.voteType;
  }

  /** Sets the Page-Hinckley threshold. */
  public Builder threshold(int threshold) {
    pageHinckleyThreshold = threshold;
    return this;
  }

  /** Sets the Page-Hinckley alpha. */
  public Builder alpha(double alpha) {
    pageHinckleyAlpha = alpha;
    return this;
  }

  /** Enables or disables drift (change) detection. */
  public Builder changeDetection(boolean changeDetection) {
    driftDetection = changeDetection;
    return this;
  }

  /** Sets the prediction function: Adaptive=0, Perceptron=1, TargetMean=2. */
  public Builder predictionFunction(int predictionFunction) {
    this.predictionFunction = predictionFunction;
    return this;
  }

  /** Sets the constant-learning-ratio-decay flag. */
  public Builder constantLearningRatioDecay(boolean constantDecay) {
    constantLearningRatioDecay = constantDecay;
    return this;
  }

  /** Sets the learning ratio. */
  public Builder learningRatio(double learningRatio) {
    this.learningRatio = learningRatio;
    return this;
  }

  /** Sets the split confidence used when trying to expand the default rule. */
  public Builder splitConfidence(double splitConfidence) {
    this.splitConfidence = splitConfidence;
    return this;
  }

  /** Sets the tie threshold used when trying to expand the default rule. */
  public Builder tieThreshold(double tieThreshold) {
    this.tieThreshold = tieThreshold;
    return this;
  }

  /** Sets the number of instances between expansion attempts of the default rule. */
  public Builder gracePeriod(int gracePeriod) {
    this.gracePeriod = gracePeriod;
    return this;
  }

  /** Disables anomaly detection when {@code true}. */
  public Builder noAnomalyDetection(boolean noAnomalyDetection) {
    this.noAnomalyDetection = noAnomalyDetection;
    return this;
  }

  /** Sets the multivariate anomaly probability threshold. */
  public Builder multivariateAnomalyProbabilityThreshold(double mAnomalyThreshold) {
    multivariateAnomalyProbabilityThreshold = mAnomalyThreshold;
    return this;
  }

  /** Sets the univariate anomaly probability threshold. */
  public Builder univariateAnomalyProbabilityThreshold(double uAnomalyThreshold) {
    univariateAnomalyprobabilityThreshold = uAnomalyThreshold;
    return this;
  }

  /** Sets the number of instances a rule must have seen before anomaly checks run. */
  public Builder anomalyNumberOfInstancesThreshold(int anomalyNumInstThreshold) {
    this.anomalyNumInstThreshold = anomalyNumInstThreshold;
    return this;
  }

  /** Chooses unordered ({@code true}) or ordered ({@code false}) rules. */
  public Builder unorderedRules(boolean unorderedRules) {
    this.unorderedRules = unorderedRules;
    return this;
  }

  /** Sets the numeric attribute observer handed to new rules. */
  public Builder numericObserver(FIMTDDNumericAttributeClassLimitObserver numericObserver) {
    this.numericObserver = numericObserver;
    return this;
  }

  /** Sets the vote type: 1 for uniform weighting, anything else for inverse-error weighting. */
  public Builder voteType(int voteType) {
    this.voteType = voteType;
    return this;
  }

  /**
   * @return a new processor configured with the options currently held by this builder
   */
  public AMRulesAggregatorProcessor build() {
    return new AMRulesAggregatorProcessor(this);
  }
}
}