blob: ea7e53d7c4d6fdfd0002007095d852cd8fa0a901 [file] [log] [blame]
package org.apache.samoa.learners.classifiers.trees;
/*
* #%L
* SAMOA
* %%
* Copyright (C) 2014 - 2015 Apache Software Foundation
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
import com.google.common.collect.ImmutableSet;
import java.util.Set;
import org.apache.samoa.core.Processor;
import org.apache.samoa.instances.Instances;
import org.apache.samoa.learners.AdaptiveLearner;
import org.apache.samoa.learners.ClassificationLearner;
import org.apache.samoa.moa.classifiers.core.attributeclassobservers.AttributeClassObserver;
import org.apache.samoa.moa.classifiers.core.attributeclassobservers.DiscreteAttributeClassObserver;
import org.apache.samoa.moa.classifiers.core.attributeclassobservers.NumericAttributeClassObserver;
import org.apache.samoa.moa.classifiers.core.driftdetection.ChangeDetector;
import org.apache.samoa.moa.classifiers.core.splitcriteria.SplitCriterion;
import org.apache.samoa.topology.Stream;
import org.apache.samoa.topology.TopologyBuilder;
import com.github.javacliparser.ClassOption;
import com.github.javacliparser.Configurable;
import com.github.javacliparser.FlagOption;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
/**
* Vertical Hoeffding Tree.
* <p>
* The Vertical Hoeffding Tree (VHT) classifier is a distributed classifier that utilizes vertical parallelism on top
* of the Very Fast Decision Tree (VFDT) classifier.
*
* @author Arinto Murdopo
*/
public final class VerticalHoeffdingTree implements ClassificationLearner, AdaptiveLearner, Configurable {

  private static final long serialVersionUID = -4937416312929984057L;

  // ---------------------------------------------------------------------------------------------
  // Command-line options. Names, flags, help texts, defaults and bounds are part of the CLI
  // contract of this learner.
  // ---------------------------------------------------------------------------------------------

  /** Option {@code -n}: numeric estimator; default {@code GaussianNumericAttributeClassObserver}. */
  public ClassOption numericEstimatorOption = new ClassOption(
      "numericEstimator", 'n', "Numeric estimator to use.",
      NumericAttributeClassObserver.class, "GaussianNumericAttributeClassObserver");

  /** Option {@code -d}: nominal estimator; default {@code NominalAttributeClassObserver}. */
  public ClassOption nominalEstimatorOption = new ClassOption(
      "nominalEstimator", 'd', "Nominal estimator to use.",
      DiscreteAttributeClassObserver.class, "NominalAttributeClassObserver");

  /** Option {@code -s}: split criterion; default {@code InfoGainSplitCriterion}. */
  public ClassOption splitCriterionOption = new ClassOption(
      "splitCriterion", 's', "Split criterion to use.",
      SplitCriterion.class, "InfoGainSplitCriterion");

  /** Option {@code -c}: split confidence in the range [0.0, 1.0]; default 0.0000001. */
  public FloatOption splitConfidenceOption = new FloatOption(
      "splitConfidence", 'c',
      "The allowable error in split decision, values closer to 0 will take longer to decide.",
      0.0000001, 0.0, 1.0);

  /** Option {@code -t}: tie threshold in the range [0.0, 1.0]; default 0.05. */
  public FloatOption tieThresholdOption = new FloatOption(
      "tieThreshold", 't', "Threshold below which a split will be forced to break ties.",
      0.05, 0.0, 1.0);

  /** Option {@code -g}: grace period in the range [0, {@link Integer#MAX_VALUE}]; default 200. */
  public IntOption gracePeriodOption = new IntOption(
      "gracePeriod", 'g',
      "The number of instances a leaf should observe between split attempts.",
      200, 0, Integer.MAX_VALUE);

  /**
   * Option {@code -p}: parallelism hint in the range [1, {@link Integer#MAX_VALUE}]; default 1. Besides being
   * handed to the model aggregator, this value is the parallelism used when the local statistics processor is
   * added to the topology.
   */
  public IntOption parallelismHintOption = new IntOption(
      "parallelismHint", 'p',
      "The number of local statistics PI to do distributed computation",
      1, 1, Integer.MAX_VALUE);

  /** Option {@code -o}: time-out in the range [1, {@link Integer#MAX_VALUE}]; default 30. */
  public IntOption timeOutOption = new IntOption(
      "timeOut", 'o',
      "The duration to wait all distributed computation results from local statistics PI",
      30, 1, Integer.MAX_VALUE);

  /** Option {@code -b}: when set, the local statistics processor is built with binary splits enabled. */
  public FlagOption binarySplitsOption = new FlagOption(
      "binarySplits", 'b', "Only allow binary splits.");

  // ---------------------------------------------------------------------------------------------
  // State
  // ---------------------------------------------------------------------------------------------

  /** Stream created on the model aggregator in {@link #init} and exposed via {@link #getResultStreams()}. */
  private Stream resultStream;

  /** Filter processor created in {@link #init}; returned as this learner's input processor. */
  private FilterProcessor filterProc;

  /** Change detector passed to the model aggregator when the topology is built. */
  protected ChangeDetector changeDetector;

  /**
   * Builds the learner's sub-topology on the given builder.
   * <p>
   * Three processors are registered, in this order: a filter processor, a model aggregator processor (both with
   * {@code parallelism}), and a local statistics processor (with the value of {@link #parallelismHintOption}).
   * They are wired as follows:
   * <ul>
   * <li>filter output stream to model aggregator, via {@code connectInputShuffleStream};</li>
   * <li>model aggregator attribute stream to local statistics, via {@code connectInputKeyStream};</li>
   * <li>model aggregator control stream to local statistics, via {@code connectInputAllStream};</li>
   * <li>local statistics computation-result stream to model aggregator, via {@code connectInputAllStream}.</li>
   * </ul>
   * The model aggregator's result stream is kept for {@link #getResultStreams()}.
   *
   * @param topologyBuilder builder on which processors and streams are created
   * @param dataset         dataset description handed to the filter and model aggregator builders
   * @param parallelism     parallelism used for the filter and the model aggregator processors
   */
  @Override
  public void init(TopologyBuilder topologyBuilder, Instances dataset, int parallelism) {
    // --- Filter processor: entry point of the learner ---
    FilterProcessor filter = new FilterProcessor.Builder(dataset).build();
    this.filterProc = filter;
    topologyBuilder.addProcessor(filter, parallelism);

    Stream filteredStream = topologyBuilder.createStream(filter);
    filter.setOutputStream(filteredStream);

    // --- Model aggregator: consumes the filtered stream, emits result/attribute/control streams ---
    ModelAggregatorProcessor aggregator = newModelAggregator(dataset);
    topologyBuilder.addProcessor(aggregator, parallelism);
    topologyBuilder.connectInputShuffleStream(filteredStream, aggregator);

    Stream results = topologyBuilder.createStream(aggregator);
    this.resultStream = results;
    aggregator.setResultStream(results);

    Stream attributes = topologyBuilder.createStream(aggregator);
    aggregator.setAttributeStream(attributes);

    Stream control = topologyBuilder.createStream(aggregator);
    aggregator.setControlStream(control);

    // --- Local statistics: fed by the aggregator, reports computation results back to it ---
    LocalStatisticsProcessor localStatistics = newLocalStatistics();
    topologyBuilder.addProcessor(localStatistics, parallelismHintOption.getValue());
    topologyBuilder.connectInputKeyStream(attributes, localStatistics);
    topologyBuilder.connectInputAllStream(control, localStatistics);

    Stream computationResults = topologyBuilder.createStream(localStatistics);
    localStatistics.setComputationResultStream(computationResults);
    topologyBuilder.connectInputAllStream(computationResults, aggregator);
  }

  /** Creates the model aggregator processor from the current option values and change detector. */
  private ModelAggregatorProcessor newModelAggregator(Instances dataset) {
    return new ModelAggregatorProcessor.Builder(dataset)
        .splitCriterion((SplitCriterion) this.splitCriterionOption.getValue())
        .splitConfidence(splitConfidenceOption.getValue())
        .tieThreshold(tieThresholdOption.getValue())
        .gracePeriod(gracePeriodOption.getValue())
        .parallelismHint(parallelismHintOption.getValue())
        .timeOut(timeOutOption.getValue())
        .changeDetector(this.getChangeDetector())
        .build();
  }

  /** Creates the local statistics processor from the current option values. */
  private LocalStatisticsProcessor newLocalStatistics() {
    return new LocalStatisticsProcessor.Builder()
        .splitCriterion((SplitCriterion) this.splitCriterionOption.getValue())
        .binarySplit(binarySplitsOption.isSet())
        .nominalClassObserver((AttributeClassObserver) this.nominalEstimatorOption.getValue())
        .numericClassObserver((AttributeClassObserver) this.numericEstimatorOption.getValue())
        .build();
  }

  /**
   * Returns the filter processor created by {@link #init}.
   *
   * @return the filter processor, or {@code null} if {@link #init} has not assigned it yet
   */
  @Override
  public Processor getInputProcessor() {
    return filterProc;
  }

  /**
   * Returns the model aggregator's result stream wrapped in an immutable single-element set.
   * <p>
   * NOTE(review): the stream is only assigned in {@link #init}; {@code ImmutableSet.of} is expected to reject a
   * {@code null} element, so calling this before {@link #init} presumably fails - confirm against callers.
   *
   * @return an immutable set holding the result stream
   */
  @Override
  public Set<Stream> getResultStreams() {
    Stream onlyResult = this.resultStream;
    return ImmutableSet.of(onlyResult);
  }

  /**
   * Returns the change detector currently configured on this learner.
   *
   * @return the change detector, or {@code null} if none has been set
   */
  @Override
  public ChangeDetector getChangeDetector() {
    return changeDetector;
  }

  /**
   * Sets the change detector that {@link #init} passes to the model aggregator builder.
   *
   * @param cd the change detector to use
   */
  @Override
  public void setChangeDetector(ChangeDetector cd) {
    changeDetector = cd;
  }

  /**
   * Hands out consecutive {@code long} identifiers starting at 0 from a single static counter. Access is
   * serialized through the class-level lock of {@code generate()}.
   */
  static class LearningNodeIdGenerator {

    // TODO: add code to warn the user when the value reaches Long.MAX_VALUE
    private static long id = 0;

    /**
     * Returns the current counter value and advances the counter by one.
     *
     * @return the identifier that was current before the increment
     */
    static synchronized long generate() {
      long current = id;
      id = current + 1;
      return current;
    }
  }
}