blob: 70dc4f89d59c349b81c50e9f855ca0d0323cc76d [file] [log] [blame]
package org.apache.samoa.streams.generators;
/*
* #%L
* SAMOA
* %%
* Copyright (C) 2014 - 2015 Apache Software Foundation
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Random;
import org.apache.samoa.instances.Attribute;
import org.apache.samoa.instances.DenseInstance;
import org.apache.samoa.instances.Instance;
import org.apache.samoa.instances.Instances;
import org.apache.samoa.instances.InstancesHeader;
import org.apache.samoa.moa.core.FastVector;
import org.apache.samoa.moa.core.InstanceExample;
import org.apache.samoa.moa.core.ObjectRepository;
import org.apache.samoa.moa.options.AbstractOptionHandler;
import org.apache.samoa.moa.tasks.TaskMonitor;
import org.apache.samoa.streams.InstanceStream;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
/**
* Stream generator for a stream based on a randomly generated tree.
*
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public class RandomTreeGenerator extends AbstractOptionHandler implements InstanceStream {

  private static final long serialVersionUID = 1L;

  /** Seed of the random generator that builds the tree concept. */
  public IntOption treeRandomSeedOption = new IntOption("treeRandomSeed",
      'r', "Seed for random generation of tree.", 1);

  /** Seed of the random generator that draws attribute values for instances. */
  public IntOption instanceRandomSeedOption = new IntOption(
      "instanceRandomSeed", 'i',
      "Seed for random generation of instances.", 1);

  /** Number of class labels a leaf can be assigned (at least 2). */
  public IntOption numClassesOption = new IntOption("numClasses", 'c',
      "The number of classes to generate.", 2, 2, Integer.MAX_VALUE);

  /** Number of nominal attributes; they occupy the first attribute indices. */
  public IntOption numNominalsOption = new IntOption("numNominals", 'o',
      "The number of nominal attributes to generate.", 5, 0,
      Integer.MAX_VALUE);

  /** Number of numeric attributes; they follow the nominal attributes. */
  public IntOption numNumericsOption = new IntOption("numNumerics", 'u',
      "The number of numeric attributes to generate.", 5, 0,
      Integer.MAX_VALUE);

  /** Number of distinct values of every nominal attribute (at least 2). */
  public IntOption numValsPerNominalOption = new IntOption(
      "numValsPerNominal", 'v',
      "The number of values to generate per nominal attribute.", 5, 2,
      Integer.MAX_VALUE);

  /** Depth at which every node is forced to be a leaf. */
  public IntOption maxTreeDepthOption = new IntOption("maxTreeDepth", 'd',
      "The maximum depth of the tree concept.", 5, 0, Integer.MAX_VALUE);

  /** Shallowest depth at which a node may randomly become a leaf. */
  public IntOption firstLeafLevelOption = new IntOption(
      "firstLeafLevel",
      'l',
      "The first level of the tree above maxTreeDepth that can have leaves.",
      3, 0, Integer.MAX_VALUE);

  /** Probability threshold used to turn a node at or below firstLeafLevel into a leaf. */
  public FloatOption leafFractionOption = new FloatOption("leafFraction",
      'f',
      "The fraction of leaves per level from firstLeafLevel onwards.",
      0.15, 0.0, 1.0);

  /**
   * Node of the generated tree concept. A node whose {@link #children} is
   * {@code null} is a leaf and only its {@link #classLabel} is meaningful;
   * otherwise it is a split on {@link #splitAttIndex}.
   */
  protected static class Node implements Serializable {

    private static final long serialVersionUID = 1L;

    /** Class index returned for instances that reach this leaf. */
    public int classLabel;

    /** Index (in the stream header) of the attribute this node splits on. */
    public int splitAttIndex;

    /** Threshold of a numeric split; values below it go to child 0, others to child 1. */
    public double splitAttValue;

    /** Sub-trees of a split node, or {@code null} for a leaf. */
    public Node[] children;
  }

  /** Root of the tree concept built by {@link #generateRandomTree()}. */
  protected Node treeRoot;

  /** Header describing the generated attributes and the class attribute. */
  protected InstancesHeader streamHeader;

  /** Random source for attribute values; re-seeded by {@link #restart()}. */
  protected Random instanceRandom;

  @Override
  public String getPurposeString() {
    return "Generates a stream based on a randomly generated tree.";
  }

  /**
   * Builds the stream header and the tree concept, then resets the instance
   * random generator.
   */
  @Override
  public void prepareForUseImpl(TaskMonitor monitor,
      ObjectRepository repository) {
    monitor.setCurrentActivity("Preparing random tree...", -1.0);
    generateHeader();
    generateRandomTree();
    restart();
  }

  /** @return always {@code -1} */
  @Override
  public long estimatedRemainingInstances() {
    return -1;
  }

  /** @return always {@code true} */
  @Override
  public boolean isRestartable() {
    return true;
  }

  /**
   * Re-seeds the instance random generator from the instanceRandomSeed option,
   * so the same sequence of attribute values is drawn again.
   */
  @Override
  public void restart() {
    this.instanceRandom = new Random(this.instanceRandomSeedOption.getValue());
  }

  @Override
  public InstancesHeader getHeader() {
    return this.streamHeader;
  }

  /** @return always {@code true} */
  @Override
  public boolean hasMoreInstances() {
    return true;
  }

  /**
   * Draws a random attribute vector and labels it with the tree concept.
   * Nominal attributes receive a random value index, numeric attributes a
   * random double from {@link Random#nextDouble()}.
   *
   * @return the generated instance wrapped in an {@link InstanceExample}
   */
  @Override
  public InstanceExample nextInstance() {
    // Read the options once instead of on every loop iteration.
    int numNominals = this.numNominalsOption.getValue();
    int numValsPerNominal = this.numValsPerNominalOption.getValue();
    double[] attVals = new double[numNominals + this.numNumericsOption.getValue()];

    InstancesHeader header = getHeader();
    Instance inst = new DenseInstance(header.numAttributes());
    for (int i = 0; i < attVals.length; i++) {
      if (i < numNominals) {
        attVals[i] = this.instanceRandom.nextInt(numValsPerNominal);
      } else {
        attVals[i] = this.instanceRandom.nextDouble();
      }
      inst.setValue(i, attVals[i]);
    }
    inst.setDataset(header);
    inst.setClassValue(classifyInstance(this.treeRoot, attVals));
    return new InstanceExample(inst);
  }

  /**
   * Routes an attribute vector down the tree and returns the label of the leaf
   * it reaches.
   *
   * @param node sub-tree to classify with
   * @param attVals attribute values, nominal attributes first
   * @return class index stored in the reached leaf
   */
  protected int classifyInstance(Node node, double[] attVals) {
    if (node.children == null) {
      return node.classLabel;
    }
    double value = attVals[node.splitAttIndex];
    int childIndex;
    if (node.splitAttIndex < this.numNominalsOption.getValue()) {
      // Nominal split: the value is itself the index of the child to follow.
      childIndex = (int) value;
    } else {
      // Numeric split: binary test against the threshold.
      childIndex = value < node.splitAttValue ? 0 : 1;
    }
    return classifyInstance(node.children[childIndex], attVals);
  }

  /**
   * Builds {@link #streamHeader}: the nominal attributes, then the numeric
   * attributes, then the class attribute, which is set as the class index.
   */
  protected void generateHeader() {
    FastVector<Attribute> attributes = new FastVector<>();

    // All nominal attributes share the same list of value names.
    FastVector<String> nominalAttVals = new FastVector<>();
    for (int i = 0; i < this.numValsPerNominalOption.getValue(); i++) {
      nominalAttVals.addElement("value" + (i + 1));
    }
    for (int i = 0; i < this.numNominalsOption.getValue(); i++) {
      attributes.addElement(new Attribute("nominal" + (i + 1),
          nominalAttVals));
    }
    for (int i = 0; i < this.numNumericsOption.getValue(); i++) {
      attributes.addElement(new Attribute("numeric" + (i + 1)));
    }

    FastVector<String> classLabels = new FastVector<>();
    for (int i = 0; i < this.numClassesOption.getValue(); i++) {
      classLabels.addElement("class" + (i + 1));
    }
    attributes.addElement(new Attribute("class", classLabels));

    this.streamHeader = new InstancesHeader(new Instances(
        getCLICreationString(InstanceStream.class), attributes, 0));
    this.streamHeader.setClassIndex(this.streamHeader.numAttributes() - 1);
  }

  /**
   * Builds the tree concept from the treeRandomSeed option and stores it in
   * {@link #treeRoot}. Every nominal attribute starts as a split candidate and
   * every numeric attribute starts with the range [0.0, 1.0].
   */
  protected void generateRandomTree() {
    Random treeRand = new Random(this.treeRandomSeedOption.getValue());

    int numNominals = this.numNominalsOption.getValue();
    ArrayList<Integer> nominalAttCandidates = new ArrayList<>(numNominals);
    for (int i = 0; i < numNominals; i++) {
      nominalAttCandidates.add(i);
    }

    int numNumerics = this.numNumericsOption.getValue();
    double[] minNumericVals = new double[numNumerics];
    double[] maxNumericVals = new double[numNumerics];
    for (int i = 0; i < numNumerics; i++) {
      minNumericVals[i] = 0.0;
      maxNumericVals[i] = 1.0;
    }

    this.treeRoot = generateRandomTreeNode(0, nominalAttCandidates,
        minNumericVals, maxNumericVals, treeRand);
  }

  /**
   * Recursively generates one node of the tree concept.
   *
   * <p>A leaf is produced when {@code currentDepth} has reached maxTreeDepth,
   * when the depth has reached firstLeafLevel and the random leaf test
   * succeeds, or when there is no attribute left to split on. Otherwise a
   * split attribute is drawn uniformly from the remaining nominal candidates
   * and all numeric attributes.
   *
   * @param currentDepth depth of the node being generated (root is 0)
   * @param nominalAttCandidates indices of nominal attributes not yet split on
   *        along the path to this node; not modified
   * @param minNumericVals lower bound per numeric attribute on this path; not modified
   * @param maxNumericVals upper bound per numeric attribute on this path; not modified
   * @param treeRand random source for all tree-building decisions
   * @return the generated leaf or split node
   */
  protected Node generateRandomTreeNode(int currentDepth,
      ArrayList<Integer> nominalAttCandidates, double[] minNumericVals,
      double[] maxNumericVals, Random treeRand) {
    // Short-circuit order matters: nextDouble() is only consumed when the
    // depth tests allow a random leaf, keeping the random sequence stable.
    if ((currentDepth >= this.maxTreeDepthOption.getValue())
        || ((currentDepth >= this.firstLeafLevelOption.getValue())
            && (this.leafFractionOption.getValue() >= (1.0 - treeRand.nextDouble())))) {
      return createLeaf(treeRand);
    }

    int numNominalCandidates = nominalAttCandidates.size();
    int numSplitCandidates = numNominalCandidates + this.numNumericsOption.getValue();
    if (numSplitCandidates == 0) {
      // Nothing left to split on: Random.nextInt(0) would throw
      // IllegalArgumentException, so terminate this branch with a leaf.
      return createLeaf(treeRand);
    }

    Node node = new Node();
    int chosenAtt = treeRand.nextInt(numSplitCandidates);
    if (chosenAtt < numNominalCandidates) {
      // Nominal split: one child per value; the attribute is not reused below.
      node.splitAttIndex = nominalAttCandidates.get(chosenAtt);
      node.children = new Node[this.numValsPerNominalOption.getValue()];
      ArrayList<Integer> newNominalCandidates = new ArrayList<>(
          nominalAttCandidates);
      // Boxed on purpose so remove(Object) is selected rather than remove(int index).
      newNominalCandidates.remove(Integer.valueOf(node.splitAttIndex));
      newNominalCandidates.trimToSize();
      for (int i = 0; i < node.children.length; i++) {
        node.children[i] = generateRandomTreeNode(currentDepth + 1,
            newNominalCandidates, minNumericVals, maxNumericVals,
            treeRand);
      }
    } else {
      // Numeric split: pick a threshold inside the range still reachable here.
      int numericIndex = chosenAtt - numNominalCandidates;
      node.splitAttIndex = this.numNominalsOption.getValue() + numericIndex;
      double minVal = minNumericVals[numericIndex];
      double maxVal = maxNumericVals[numericIndex];
      node.splitAttValue = ((maxVal - minVal) * treeRand.nextDouble()) + minVal;
      node.children = new Node[2];

      // Left child covers [minVal, threshold).
      double[] newMaxVals = maxNumericVals.clone();
      newMaxVals[numericIndex] = node.splitAttValue;
      node.children[0] = generateRandomTreeNode(currentDepth + 1,
          nominalAttCandidates, minNumericVals, newMaxVals, treeRand);

      // Right child covers [threshold, maxVal].
      double[] newMinVals = minNumericVals.clone();
      newMinVals[numericIndex] = node.splitAttValue;
      node.children[1] = generateRandomTreeNode(currentDepth + 1,
          nominalAttCandidates, newMinVals, maxNumericVals, treeRand);
    }
    return node;
  }

  /** Creates a leaf whose class label is drawn uniformly from the configured classes. */
  private Node createLeaf(Random treeRand) {
    Node leaf = new Node();
    leaf.classLabel = treeRand.nextInt(this.numClassesOption.getValue());
    return leaf;
  }

  /** Currently appends nothing to {@code sb}. */
  @Override
  public void getDescription(StringBuilder sb, int indent) {
    // TODO: not implemented; no description is written.
  }
}