blob: d61d8c8328d3b6bb15dbd836bfa6440a7d7dd726 [file] [log] [blame]
/* Project Knowledge Discovery from Data Streams, FCT LIAAD-INESC TEC,
*
* Contact: jgama@fep.up.pt
*/
package org.apache.samoa.moa.classifiers.core.attributeclassobservers;
/*
* #%L
* SAMOA
* %%
* Copyright (C) 2014 - 2015 Apache Software Foundation
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
import java.io.Serializable;
import org.apache.samoa.moa.classifiers.core.AttributeSplitSuggestion;
import org.apache.samoa.moa.classifiers.core.conditionaltests.NumericAttributeBinaryTest;
import org.apache.samoa.moa.classifiers.core.splitcriteria.SplitCriterion;
import org.apache.samoa.moa.core.DoubleVector;
import org.apache.samoa.moa.core.ObjectRepository;
import org.apache.samoa.moa.tasks.TaskMonitor;
public class FIMTDDNumericAttributeClassObserver extends BinaryTreeNumericAttributeClassObserver implements
NumericAttributeClassObserver {
private static final long serialVersionUID = 1L;
protected class Node implements Serializable {
private static final long serialVersionUID = 1L;
// The split point to use
public double cut_point;
// E-BST statistics
public DoubleVector leftStatistics = new DoubleVector();
public DoubleVector rightStatistics = new DoubleVector();
// Child nodes
public Node left;
public Node right;
public Node(double val, double label, double weight) {
this.cut_point = val;
this.leftStatistics.addToValue(0, 1);
this.leftStatistics.addToValue(1, label);
this.leftStatistics.addToValue(2, label * label);
}
/**
* Insert a new value into the tree, updating both the sum of values and sum of squared values arrays
*/
public void insertValue(double val, double label, double weight) {
// If the new value equals the value stored in a node, update
// the left (<=) node information
if (val == this.cut_point) {
this.leftStatistics.addToValue(0, 1);
this.leftStatistics.addToValue(1, label);
this.leftStatistics.addToValue(2, label * label);
} // If the new value is less than the value in a node, update the
// left distribution and send the value down to the left child node.
// If no left child exists, create one
else if (val <= this.cut_point) {
this.leftStatistics.addToValue(0, 1);
this.leftStatistics.addToValue(1, label);
this.leftStatistics.addToValue(2, label * label);
if (this.left == null) {
this.left = new Node(val, label, weight);
} else {
this.left.insertValue(val, label, weight);
}
} // If the new value is greater than the value in a node, update the
// right (>) distribution and send the value down to the right child node.
// If no right child exists, create one
else { // val > cut_point
this.rightStatistics.addToValue(0, 1);
this.rightStatistics.addToValue(1, label);
this.rightStatistics.addToValue(2, label * label);
if (this.right == null) {
this.right = new Node(val, label, weight);
} else {
this.right.insertValue(val, label, weight);
}
}
}
}
// Root node of the E-BST structure for this attribute
public Node root = null;
// Global variables for use in the FindBestSplit algorithm
double sumTotalLeft;
double sumTotalRight;
double sumSqTotalLeft;
double sumSqTotalRight;
double countRightTotal;
double countLeftTotal;
public void observeAttributeClass(double attVal, double classVal, double weight) {
if (!Double.isNaN(attVal)) {
if (this.root == null) {
this.root = new Node(attVal, classVal, weight);
} else {
this.root.insertValue(attVal, classVal, weight);
}
}
}
@Override
public double probabilityOfAttributeValueGivenClass(double attVal,
int classVal) {
// TODO: NaiveBayes broken until implemented
return 0.0;
}
@Override
public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion(SplitCriterion criterion, double[] preSplitDist,
int attIndex, boolean binaryOnly) {
// Initialise global variables
sumTotalLeft = 0;
sumTotalRight = preSplitDist[1];
sumSqTotalLeft = 0;
sumSqTotalRight = preSplitDist[2];
countLeftTotal = 0;
countRightTotal = preSplitDist[0];
return searchForBestSplitOption(this.root, null, criterion, attIndex);
}
/**
* Implementation of the FindBestSplit algorithm from E.Ikonomovska et al.
*/
protected AttributeSplitSuggestion searchForBestSplitOption(Node currentNode,
AttributeSplitSuggestion currentBestOption, SplitCriterion criterion, int attIndex) {
// Return null if the current node is null or we have finished looking
// through all the possible splits
if (currentNode == null || countRightTotal == 0.0) {
return currentBestOption;
}
if (currentNode.left != null) {
currentBestOption = searchForBestSplitOption(currentNode.left, currentBestOption, criterion, attIndex);
}
sumTotalLeft += currentNode.leftStatistics.getValue(1);
sumTotalRight -= currentNode.leftStatistics.getValue(1);
sumSqTotalLeft += currentNode.leftStatistics.getValue(2);
sumSqTotalRight -= currentNode.leftStatistics.getValue(2);
countLeftTotal += currentNode.leftStatistics.getValue(0);
countRightTotal -= currentNode.leftStatistics.getValue(0);
double[][] postSplitDists = new double[][] { { countLeftTotal, sumTotalLeft, sumSqTotalLeft },
{ countRightTotal, sumTotalRight, sumSqTotalRight } };
double[] preSplitDist = new double[] { (countLeftTotal + countRightTotal), (sumTotalLeft + sumTotalRight),
(sumSqTotalLeft + sumSqTotalRight) };
double merit = criterion.getMeritOfSplit(preSplitDist, postSplitDists);
if ((currentBestOption == null) || (merit > currentBestOption.merit)) {
currentBestOption = new AttributeSplitSuggestion(
new NumericAttributeBinaryTest(attIndex,
currentNode.cut_point, true), postSplitDists, merit);
}
if (currentNode.right != null) {
currentBestOption = searchForBestSplitOption(currentNode.right, currentBestOption, criterion, attIndex);
}
sumTotalLeft -= currentNode.leftStatistics.getValue(1);
sumTotalRight += currentNode.leftStatistics.getValue(1);
sumSqTotalLeft -= currentNode.leftStatistics.getValue(2);
sumSqTotalRight += currentNode.leftStatistics.getValue(2);
countLeftTotal -= currentNode.leftStatistics.getValue(0);
countRightTotal += currentNode.leftStatistics.getValue(0);
return currentBestOption;
}
/**
* A method to remove all nodes in the E-BST in which it and all it's children represent 'bad' split points
*/
public void removeBadSplits(SplitCriterion criterion, double lastCheckRatio, double lastCheckSDR, double lastCheckE) {
removeBadSplitNodes(criterion, this.root, lastCheckRatio, lastCheckSDR, lastCheckE);
}
/**
* Recursive method that first checks all of a node's children before deciding if it is 'bad' and may be removed
*/
private boolean removeBadSplitNodes(SplitCriterion criterion, Node currentNode, double lastCheckRatio,
double lastCheckSDR, double lastCheckE) {
boolean isBad = false;
if (currentNode == null) {
return true;
}
if (currentNode.left != null) {
isBad = removeBadSplitNodes(criterion, currentNode.left, lastCheckRatio, lastCheckSDR, lastCheckE);
}
if (currentNode.right != null && isBad) {
isBad = removeBadSplitNodes(criterion, currentNode.left, lastCheckRatio, lastCheckSDR, lastCheckE);
}
if (isBad) {
double[][] postSplitDists = new double[][] {
{ currentNode.leftStatistics.getValue(0), currentNode.leftStatistics.getValue(1),
currentNode.leftStatistics.getValue(2) },
{ currentNode.rightStatistics.getValue(0), currentNode.rightStatistics.getValue(1),
currentNode.rightStatistics.getValue(2) } };
double[] preSplitDist = new double[] {
(currentNode.leftStatistics.getValue(0) + currentNode.rightStatistics.getValue(0)),
(currentNode.leftStatistics.getValue(1) + currentNode.rightStatistics.getValue(1)),
(currentNode.leftStatistics.getValue(2) + currentNode.rightStatistics.getValue(2)) };
double merit = criterion.getMeritOfSplit(preSplitDist, postSplitDists);
if ((merit / lastCheckSDR) < (lastCheckRatio - (2 * lastCheckE))) {
currentNode = null;
return true;
}
}
return false;
}
@Override
public void getDescription(StringBuilder sb, int indent) {
// TODO Auto-generated method stub
}
@Override
protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
// TODO Auto-generated method stub
}
}