blob: fd93db165a7f971b803f1df1aaa7bde973145f83 [file] [log] [blame]
package com.yahoo.labs.samoa.learners.classifiers.trees;
/*
* #%L
* SAMOA
* %%
* Copyright (C) 2013 Yahoo! Inc.
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
import com.yahoo.labs.samoa.moa.classifiers.core.conditionaltests.InstanceConditionalTest;
import com.yahoo.labs.samoa.moa.core.AutoExpandVector;
import com.yahoo.labs.samoa.instances.Instance;
/**
 * SplitNode represents the node that contains one or more questions in the decision tree model,
 * in order to route the instances into the correct leaf.
 *
 * <p>Routing is delegated to the wrapped {@link InstanceConditionalTest}; child nodes are stored
 * by branch index in an {@link AutoExpandVector}.
 *
 * @author Arinto Murdopo
 */
public class SplitNode extends Node {

  private static final long serialVersionUID = -7380795529928485792L;

  /** Child nodes, indexed by the branch number produced by {@link #splitTest}. */
  private final AutoExpandVector<Node> children;

  /** The conditional test that decides which branch an instance is routed to. */
  protected final InstanceConditionalTest splitTest;

  /**
   * Creates a split node with no children attached yet.
   *
   * @param splitTest the conditional test used to route instances to a branch
   * @param classObservation the class distribution observed at this node, handed to the
   *     {@link Node} constructor
   */
  public SplitNode(InstanceConditionalTest splitTest, double[] classObservation) {
    super(classObservation);
    this.children = new AutoExpandVector<>();
    this.splitTest = splitTest;
  }

  /**
   * Routes an instance from this node towards a leaf.
   *
   * <p>There are three possible outcomes:
   *
   * <ul>
   *   <li>the split test yields a non-negative branch that has a child attached: the search
   *       continues in that child, with this node recorded as its parent;
   *   <li>the split test yields a non-negative branch with no child attached: a {@link FoundNode}
   *       with a {@code null} node, this node as parent and that branch index is returned;
   *   <li>the split test yields a negative branch (presumably "branch cannot be determined" —
   *       confirm against the {@link InstanceConditionalTest} implementations): this node itself
   *       is returned together with the caller-supplied {@code parent} and {@code parentBranch}.
   * </ul>
   *
   * @param inst the instance being routed
   * @param parent the parent of this node
   * @param parentBranch the branch index of this node within {@code parent}
   * @return the node where routing stopped, with its parent and branch index
   */
  @Override
  FoundNode filterInstanceToLeaf(Instance inst, SplitNode parent, int parentBranch) {
    int childIndex = instanceChildIndex(inst);
    if (childIndex >= 0) {
      Node child = getChild(childIndex);
      if (child != null) {
        return child.filterInstanceToLeaf(inst, this, childIndex);
      }
      // Branch exists but has no subtree yet: report the empty slot under this node.
      return new FoundNode(null, this, childIndex);
    }
    return new FoundNode(this, parent, parentBranch);
  }

  /**
   * A split node is never a leaf.
   *
   * @return always {@code false}
   */
  @Override
  boolean isLeaf() {
    return false;
  }

  /**
   * Returns the class votes of this node, which are a copy of its observed class distribution.
   * Neither the instance nor the processor is consulted.
   *
   * @param inst ignored
   * @param vht ignored
   * @return a fresh copy of the observed class distribution
   */
  @Override
  double[] getClassVotes(Instance inst, ModelAggregatorProcessor vht) {
    return this.observedClassDistribution.getArrayCopy();
  }

  /**
   * Method to return the number of children of this split node.
   *
   * @return the size of the child vector (slots without an attached child are still counted)
   */
  int numChildren() {
    return this.children.size();
  }

  /**
   * Method to set the children in a specific index of the SplitNode with the appropriate child.
   *
   * @param index index of the child in the SplitNode
   * @param child the child node
   * @throws IndexOutOfBoundsException if the split test declares a non-negative maximum number
   *     of branches and {@code index} is not below it
   */
  void setChild(int index, Node child) {
    int maxBranches = this.splitTest.maxBranches();
    // A negative maxBranches disables the bound check (presumably meaning "no fixed limit" —
    // confirm against the InstanceConditionalTest implementations).
    if (maxBranches >= 0 && index >= maxBranches) {
      throw new IndexOutOfBoundsException(
          "Child index " + index + " is out of range for a split test with "
              + maxBranches + " branches");
    }
    this.children.set(index, child);
  }

  /**
   * Method to get the child node given the index.
   *
   * @param index the child node index
   * @return the child node stored at the given index, as returned by the underlying vector
   */
  Node getChild(int index) {
    return this.children.get(index);
  }

  /**
   * Method to route the instance using this split node.
   *
   * @param inst the routed instance
   * @return the branch index chosen by the split test; negative values are treated by
   *     {@link #filterInstanceToLeaf} as "no branch"
   */
  int instanceChildIndex(Instance inst) {
    return this.splitTest.branchForInstance(inst);
  }
}