blob: 48e9dbbcc768b4a154258c14ecf2ca65f704decf [file] [log] [blame]
package com.yahoo.labs.samoa.learners.classifiers.rules.centralized;
/*
* #%L
* SAMOA
* %%
* Copyright (C) 2013 - 2014 Yahoo! Inc.
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import com.yahoo.labs.samoa.core.ContentEvent;
import com.yahoo.labs.samoa.core.Processor;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.Instances;
import com.yahoo.labs.samoa.learners.InstanceContentEvent;
import com.yahoo.labs.samoa.learners.ResultContentEvent;
import com.yahoo.labs.samoa.learners.classifiers.rules.common.ActiveRule;
import com.yahoo.labs.samoa.learners.classifiers.rules.common.Perceptron;
import com.yahoo.labs.samoa.learners.classifiers.rules.common.RuleActiveRegressionNode;
import com.yahoo.labs.samoa.moa.classifiers.rules.core.attributeclassobservers.FIMTDDNumericAttributeClassLimitObserver;
import com.yahoo.labs.samoa.moa.classifiers.rules.core.voting.ErrorWeightedVote;
import com.yahoo.labs.samoa.topology.Stream;
/**
* AMRules Regressor Processor is the main (and only) processor for
* AMRulesRegressor task. It is adapted from the AMRules implementation in MOA.
*
* @author Anh Thu Vu
*
*/
public class AMRulesRegressorProcessor implements Processor {
/**
*
*/
private static final long serialVersionUID = 1L;
private int processorId;
// Rules & default rule
protected List<ActiveRule> ruleSet;
protected ActiveRule defaultRule;
protected int ruleNumberID;
protected double[] statistics;
// SAMOA Stream
private Stream resultStream;
// Options
protected int pageHinckleyThreshold;
protected double pageHinckleyAlpha;
protected boolean driftDetection;
protected int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2
protected boolean constantLearningRatioDecay;
protected double learningRatio;
protected double splitConfidence;
protected double tieThreshold;
protected int gracePeriod;
protected boolean noAnomalyDetection;
protected double multivariateAnomalyProbabilityThreshold;
protected double univariateAnomalyprobabilityThreshold;
protected int anomalyNumInstThreshold;
protected boolean unorderedRules;
protected FIMTDDNumericAttributeClassLimitObserver numericObserver;
protected ErrorWeightedVote voteType;
/*
* Constructor
*/
public AMRulesRegressorProcessor(Builder builder) {
this.pageHinckleyThreshold = builder.pageHinckleyThreshold;
this.pageHinckleyAlpha = builder.pageHinckleyAlpha;
this.driftDetection = builder.driftDetection;
this.predictionFunction = builder.predictionFunction;
this.constantLearningRatioDecay = builder.constantLearningRatioDecay;
this.learningRatio = builder.learningRatio;
this.splitConfidence = builder.splitConfidence;
this.tieThreshold = builder.tieThreshold;
this.gracePeriod = builder.gracePeriod;
this.noAnomalyDetection = builder.noAnomalyDetection;
this.multivariateAnomalyProbabilityThreshold = builder.multivariateAnomalyProbabilityThreshold;
this.univariateAnomalyprobabilityThreshold = builder.univariateAnomalyprobabilityThreshold;
this.anomalyNumInstThreshold = builder.anomalyNumInstThreshold;
this.unorderedRules = builder.unorderedRules;
this.numericObserver = builder.numericObserver;
this.voteType = builder.voteType;
}
/*
* Process
*/
@Override
public boolean process(ContentEvent event) {
InstanceContentEvent instanceEvent = (InstanceContentEvent) event;
// predict
if (instanceEvent.isTesting()) {
this.predictOnInstance(instanceEvent);
}
// train
if (instanceEvent.isTraining()) {
this.trainOnInstance(instanceEvent);
}
return true;
}
/*
* Prediction
*/
private void predictOnInstance(InstanceContentEvent instanceEvent) {
double[] prediction = getVotesForInstance(instanceEvent.getInstance());
ResultContentEvent rce = newResultContentEvent(prediction, instanceEvent);
resultStream.put(rce);
}
/**
* Helper method to generate new ResultContentEvent based on an instance and
* its prediction result.
*
* @param prediction
* The predicted class label from the decision tree model.
* @param inEvent
* The associated instance content event
* @return ResultContentEvent to be sent into Evaluator PI or other
* destination PI.
*/
private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContentEvent inEvent) {
ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(),
inEvent.getClassId(), prediction, inEvent.isLastEvent());
rce.setClassifierIndex(this.processorId);
rce.setEvaluationIndex(inEvent.getEvaluationIndex());
return rce;
}
/**
* getVotesForInstance extension of the instance method getVotesForInstance in
* moa.classifier.java returns the prediction of the instance. Called in
* EvaluateModelRegression
*/
private double[] getVotesForInstance(Instance instance) {
ErrorWeightedVote errorWeightedVote = newErrorWeightedVote();
int numberOfRulesCovering = 0;
for (ActiveRule rule : ruleSet) {
if (rule.isCovering(instance) == true) {
numberOfRulesCovering++;
double[] vote = rule.getPrediction(instance);
double error = rule.getCurrentError();
errorWeightedVote.addVote(vote, error);
if (!this.unorderedRules) { // Ordered Rules Option.
break; // Only one rule cover the instance.
}
}
}
if (numberOfRulesCovering == 0) {
double[] vote = defaultRule.getPrediction(instance);
double error = defaultRule.getCurrentError();
errorWeightedVote.addVote(vote, error);
}
double[] weightedVote = errorWeightedVote.computeWeightedVote();
return weightedVote;
}
public ErrorWeightedVote newErrorWeightedVote() {
return voteType.getACopy();
}
/*
* Training
*/
private void trainOnInstance(InstanceContentEvent instanceEvent) {
this.trainOnInstanceImpl(instanceEvent.getInstance());
}
public void trainOnInstanceImpl(Instance instance) {
/**
* AMRules Algorithm
*
* //For each rule in the rule set //If rule covers the instance //if the
* instance is not an anomaly //Update Change Detection Tests //Compute
* prediction error //Call PHTest //If change is detected then //Remove rule
* //Else //Update sufficient statistics of rule //If number of examples in
* rule > Nmin //Expand rule //If ordered set then //break //If none of the
* rule covers the instance //Update sufficient statistics of default rule
* //If number of examples in default rule is multiple of Nmin //Expand
* default rule and add it to the set of rules //Reset the default rule
*/
boolean rulesCoveringInstance = false;
Iterator<ActiveRule> ruleIterator = this.ruleSet.iterator();
while (ruleIterator.hasNext()) {
ActiveRule rule = ruleIterator.next();
if (rule.isCovering(instance) == true) {
rulesCoveringInstance = true;
if (isAnomaly(instance, rule) == false) {
// Update Change Detection Tests
double error = rule.computeError(instance); // Use adaptive mode error
boolean changeDetected = ((RuleActiveRegressionNode) rule.getLearningNode()).updateChangeDetection(error);
if (changeDetected == true) {
ruleIterator.remove();
} else {
rule.updateStatistics(instance);
if (rule.getInstancesSeen() % this.gracePeriod == 0.0) {
if (rule.tryToExpand(this.splitConfidence, this.tieThreshold)) {
rule.split();
}
}
}
if (!this.unorderedRules)
break;
}
}
}
if (rulesCoveringInstance == false) {
defaultRule.updateStatistics(instance);
if (defaultRule.getInstancesSeen() % this.gracePeriod == 0.0) {
if (defaultRule.tryToExpand(this.splitConfidence, this.tieThreshold) == true) {
ActiveRule newDefaultRule = newRule(defaultRule.getRuleNumberID(),
(RuleActiveRegressionNode) defaultRule.getLearningNode(),
((RuleActiveRegressionNode) defaultRule.getLearningNode()).getStatisticsOtherBranchSplit()); // other
// branch
defaultRule.split();
defaultRule.setRuleNumberID(++ruleNumberID);
this.ruleSet.add(this.defaultRule);
defaultRule = newDefaultRule;
}
}
}
}
/**
* Method to verify if the instance is an anomaly.
*
* @param instance
* @param rule
* @return
*/
private boolean isAnomaly(Instance instance, ActiveRule rule) {
// AMRUles is equipped with anomaly detection. If on, compute the anomaly
// value.
boolean isAnomaly = false;
if (this.noAnomalyDetection == false) {
if (rule.getInstancesSeen() >= this.anomalyNumInstThreshold) {
isAnomaly = rule.isAnomaly(instance,
this.univariateAnomalyprobabilityThreshold,
this.multivariateAnomalyProbabilityThreshold,
this.anomalyNumInstThreshold);
}
}
return isAnomaly;
}
/*
* Create new rules
*/
// TODO check this after finish rule, LN
private ActiveRule newRule(int ID, RuleActiveRegressionNode node, double[] statistics) {
ActiveRule r = newRule(ID);
if (node != null)
{
if (node.getPerceptron() != null)
{
r.getLearningNode().setPerceptron(new Perceptron(node.getPerceptron()));
r.getLearningNode().getPerceptron().setLearningRatio(this.learningRatio);
}
if (statistics == null)
{
double mean;
if (node.getNodeStatistics().getValue(0) > 0) {
mean = node.getNodeStatistics().getValue(1) / node.getNodeStatistics().getValue(0);
r.getLearningNode().getTargetMean().reset(mean, 1);
}
}
}
if (statistics != null && ((RuleActiveRegressionNode) r.getLearningNode()).getTargetMean() != null)
{
double mean;
if (statistics[0] > 0) {
mean = statistics[1] / statistics[0];
((RuleActiveRegressionNode) r.getLearningNode()).getTargetMean().reset(mean, (long) statistics[0]);
}
}
return r;
}
private ActiveRule newRule(int ID) {
ActiveRule r = new ActiveRule.Builder().
threshold(this.pageHinckleyThreshold).
alpha(this.pageHinckleyAlpha).
changeDetection(this.driftDetection).
predictionFunction(this.predictionFunction).
statistics(new double[3]).
learningRatio(this.learningRatio).
numericObserver(numericObserver).
id(ID).build();
return r;
}
/*
* Init processor
*/
@Override
public void onCreate(int id) {
this.processorId = id;
this.statistics = new double[] { 0.0, 0, 0 };
this.ruleNumberID = 0;
this.defaultRule = newRule(++this.ruleNumberID);
this.ruleSet = new LinkedList<ActiveRule>();
}
/*
* Clone processor
*/
@Override
public Processor newProcessor(Processor p) {
AMRulesRegressorProcessor oldProcessor = (AMRulesRegressorProcessor) p;
Builder builder = new Builder(oldProcessor);
AMRulesRegressorProcessor newProcessor = builder.build();
newProcessor.resultStream = oldProcessor.resultStream;
return newProcessor;
}
/*
* Output stream
*/
public void setResultStream(Stream resultStream) {
this.resultStream = resultStream;
}
public Stream getResultStream() {
return this.resultStream;
}
/*
* Others
*/
public boolean isRandomizable() {
return true;
}
/*
* Builder
*/
public static class Builder {
private int pageHinckleyThreshold;
private double pageHinckleyAlpha;
private boolean driftDetection;
private int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2
private boolean constantLearningRatioDecay;
private double learningRatio;
private double splitConfidence;
private double tieThreshold;
private int gracePeriod;
private boolean noAnomalyDetection;
private double multivariateAnomalyProbabilityThreshold;
private double univariateAnomalyprobabilityThreshold;
private int anomalyNumInstThreshold;
private boolean unorderedRules;
private FIMTDDNumericAttributeClassLimitObserver numericObserver;
private ErrorWeightedVote voteType;
private Instances dataset;
public Builder(Instances dataset) {
this.dataset = dataset;
}
public Builder(AMRulesRegressorProcessor processor) {
this.pageHinckleyThreshold = processor.pageHinckleyThreshold;
this.pageHinckleyAlpha = processor.pageHinckleyAlpha;
this.driftDetection = processor.driftDetection;
this.predictionFunction = processor.predictionFunction;
this.constantLearningRatioDecay = processor.constantLearningRatioDecay;
this.learningRatio = processor.learningRatio;
this.splitConfidence = processor.splitConfidence;
this.tieThreshold = processor.tieThreshold;
this.gracePeriod = processor.gracePeriod;
this.noAnomalyDetection = processor.noAnomalyDetection;
this.multivariateAnomalyProbabilityThreshold = processor.multivariateAnomalyProbabilityThreshold;
this.univariateAnomalyprobabilityThreshold = processor.univariateAnomalyprobabilityThreshold;
this.anomalyNumInstThreshold = processor.anomalyNumInstThreshold;
this.unorderedRules = processor.unorderedRules;
this.numericObserver = processor.numericObserver;
this.voteType = processor.voteType;
}
public Builder threshold(int threshold) {
this.pageHinckleyThreshold = threshold;
return this;
}
public Builder alpha(double alpha) {
this.pageHinckleyAlpha = alpha;
return this;
}
public Builder changeDetection(boolean changeDetection) {
this.driftDetection = changeDetection;
return this;
}
public Builder predictionFunction(int predictionFunction) {
this.predictionFunction = predictionFunction;
return this;
}
public Builder constantLearningRatioDecay(boolean constantDecay) {
this.constantLearningRatioDecay = constantDecay;
return this;
}
public Builder learningRatio(double learningRatio) {
this.learningRatio = learningRatio;
return this;
}
public Builder splitConfidence(double splitConfidence) {
this.splitConfidence = splitConfidence;
return this;
}
public Builder tieThreshold(double tieThreshold) {
this.tieThreshold = tieThreshold;
return this;
}
public Builder gracePeriod(int gracePeriod) {
this.gracePeriod = gracePeriod;
return this;
}
public Builder noAnomalyDetection(boolean noAnomalyDetection) {
this.noAnomalyDetection = noAnomalyDetection;
return this;
}
public Builder multivariateAnomalyProbabilityThreshold(double mAnomalyThreshold) {
this.multivariateAnomalyProbabilityThreshold = mAnomalyThreshold;
return this;
}
public Builder univariateAnomalyProbabilityThreshold(double uAnomalyThreshold) {
this.univariateAnomalyprobabilityThreshold = uAnomalyThreshold;
return this;
}
public Builder anomalyNumberOfInstancesThreshold(int anomalyNumInstThreshold) {
this.anomalyNumInstThreshold = anomalyNumInstThreshold;
return this;
}
public Builder unorderedRules(boolean unorderedRules) {
this.unorderedRules = unorderedRules;
return this;
}
public Builder numericObserver(FIMTDDNumericAttributeClassLimitObserver numericObserver) {
this.numericObserver = numericObserver;
return this;
}
public Builder voteType(ErrorWeightedVote voteType) {
this.voteType = voteType;
return this;
}
public AMRulesRegressorProcessor build() {
return new AMRulesRegressorProcessor(this);
}
}
}