| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.samoa.learners.classifiers.ensemble; |
| |
| import com.github.javacliparser.ClassOption; |
| import com.github.javacliparser.Configurable; |
| import com.github.javacliparser.FlagOption; |
| import com.github.javacliparser.FloatOption; |
| import com.github.javacliparser.IntOption; |
| import com.google.common.collect.ImmutableSet; |
| import org.apache.samoa.core.Processor; |
| import org.apache.samoa.instances.Instances; |
| import org.apache.samoa.learners.ClassificationLearner; |
| import org.apache.samoa.learners.Learner; |
| import org.apache.samoa.learners.classifiers.trees.BoostVHTActiveLearningNode.SplittingOption; |
| import org.apache.samoa.learners.classifiers.trees.LocalStatisticsProcessor; |
| import org.apache.samoa.learners.classifiers.trees.VerticalHoeffdingTree; |
| import org.apache.samoa.moa.classifiers.core.attributeclassobservers.AttributeClassObserver; |
| import org.apache.samoa.moa.classifiers.core.attributeclassobservers.DiscreteAttributeClassObserver; |
| import org.apache.samoa.moa.classifiers.core.attributeclassobservers.NumericAttributeClassObserver; |
| import org.apache.samoa.moa.classifiers.core.splitcriteria.SplitCriterion; |
| import org.apache.samoa.topology.Stream; |
| import org.apache.samoa.topology.TopologyBuilder; |
| import org.slf4j.Logger; |
| import org.slf4j.LoggerFactory; |
| |
| import java.util.Set; |
| |
| /** |
| * The Bagging Classifier by Oza and Russell. |
| */ |
| public class BoostVHT implements ClassificationLearner, Configurable { |
| |
| /** The Constant serialVersionUID. */ |
| private static final long serialVersionUID = -7523211543185584536L; |
| |
| private static final Logger logger = LoggerFactory.getLogger(BoostVHT.class); |
| |
| public ClassOption numericEstimatorOption = new ClassOption("numericEstimator", |
| 'n', "Numeric estimator to use.", NumericAttributeClassObserver.class, |
| "GaussianNumericAttributeClassObserver"); |
| |
| public ClassOption nominalEstimatorOption = new ClassOption("nominalEstimator", |
| 'd', "Nominal estimator to use.", DiscreteAttributeClassObserver.class, |
| "NominalAttributeClassObserver"); |
| |
| public ClassOption splitCriterionOption = new ClassOption("splitCriterion", |
| 'r', "Split criterion to use.", SplitCriterion.class, |
| "InfoGainSplitCriterion"); |
| |
| public FloatOption splitConfidenceOption = new FloatOption("splitConfidence", 'c', |
| "The allowable error in split decision, values closer to 0 will take longer to decide.", |
| 0.0000001, 0.0, 1.0); |
| |
| public FloatOption tieThresholdOption = new FloatOption("tieThreshold", |
| 't', "Threshold below which a split will be forced to break ties.", |
| 0.05, 0.0, 1.0); |
| |
| public IntOption gracePeriodOption = new IntOption("gracePeriod", 'g', |
| "The number of instances a leaf should observe between split attempts.", |
| 200, 0, Integer.MAX_VALUE); |
| |
| public IntOption timeOutOption = new IntOption("timeOut", 'o', |
| "The duration to wait all distributed computation results from local statistics PI, in miliseconds", |
| Integer.MAX_VALUE, 1, Integer.MAX_VALUE); |
| |
| public FlagOption binarySplitsOption = new FlagOption("binarySplits", 'b', |
| "Only allow binary splits."); |
| |
| public FlagOption splittingOption = new FlagOption("keepInstanceWhileSplitting",'q', |
| "Keep instances in a buffer while splitting"); |
| |
| public IntOption maxBufferSizeOption = new IntOption("maxBufferSizeWhileSplitting",'z', |
| "Maximum buffer size while splitting, use in conjunction with 'q' option. Size 0 means we don't use buffer while splitting", |
| 0, 0, Integer.MAX_VALUE); |
| |
| /** The ensemble size option. */ |
| public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's', |
| "The number of models in the bag.", 10, 1, Integer.MAX_VALUE); |
| |
| public IntOption seedOption = new IntOption("seed", 'u', |
| "the seed for the rng.", (int) System.currentTimeMillis()); |
| |
| /** The Model Aggregator boosting processor. */ |
| private BoostVHTProcessor boostVHTProcessor; |
| |
| /** The result stream. */ |
| protected Stream resultStream; |
| |
| /** The attribute stream. */ |
| protected Stream attributeStream; |
| |
| /** The control stream. */ |
| protected Stream controlStream; |
| |
| /** The compute stream. */ |
| protected Stream computeStream; |
| |
| /** The dataset. */ |
| private Instances dataset; |
| |
| protected int parallelism; |
| |
| //for SAMMME |
| public IntOption numberOfClassesOption = new IntOption("numberOfClasses", 'k', |
| "The number of classes.", 2, 2, Integer.MAX_VALUE); |
| |
| /** |
| * Sets the layout. |
| */ |
| protected void setLayout() { |
| |
| int ensembleSize = this.ensembleSizeOption.getValue(); |
| |
| // Set parameters for BoostVHT processor, and the BoostMA processors within. |
| try { |
| boostVHTProcessor = new BoostVHTProcessor.Builder(dataset) |
| .ensembleSize(this.ensembleSizeOption.getValue()) |
| .numberOfClasses(this.numberOfClassesOption.getValue()) |
| .splitCriterion( |
| (SplitCriterion) ClassOption.createObject(this.splitCriterionOption.getValueAsCLIString(), |
| this.splitCriterionOption.getRequiredType())) |
| .splitConfidence(this.splitConfidenceOption.getValue()) |
| .tieThreshold(this.tieThresholdOption.getValue()) |
| .gracePeriod(this.gracePeriodOption.getValue()) |
| .parallelismHint(this.ensembleSizeOption.getValue()) |
| .timeOut(this.timeOutOption.getValue()) |
| .splittingOption(this.splittingOption.isSet() ? SplittingOption.KEEP: SplittingOption.THROW_AWAY) |
| .maxBufferSize(this.maxBufferSizeOption.getValue()) |
| .seed(this.seedOption.getValue()) |
| .build(); |
| } catch (Exception e) { |
| e.printStackTrace(); |
| } |
| |
| //add Boosting Model Aggregator Processor to the topology |
| this.topologyBuilder.addProcessor(boostVHTProcessor, 1); |
| |
| |
| // Streams |
| attributeStream = this.topologyBuilder.createStream(boostVHTProcessor); |
| controlStream = this.topologyBuilder.createStream(boostVHTProcessor); |
| |
| //local statistics processor. |
| LocalStatisticsProcessor locStatProcessor = new LocalStatisticsProcessor.Builder() |
| .splitCriterion((SplitCriterion) this.splitCriterionOption.getValue()) |
| .binarySplit(this.binarySplitsOption.isSet()) |
| .nominalClassObserver((AttributeClassObserver) this.nominalEstimatorOption.getValue()) |
| .numericClassObserver((AttributeClassObserver) this.numericEstimatorOption.getValue()) |
| .build(); |
| |
| this.topologyBuilder.addProcessor(locStatProcessor, ensembleSize); |
| |
| this.topologyBuilder.connectInputKeyStream(attributeStream, locStatProcessor); |
| this.topologyBuilder.connectInputAllStream(controlStream, locStatProcessor); |
| |
| |
| //local statistics result stream |
| computeStream = this.topologyBuilder.createStream(locStatProcessor); |
| locStatProcessor.setComputationResultStream(computeStream); |
| this.topologyBuilder.connectInputAllStream(computeStream, boostVHTProcessor); |
| |
| //prediction is computed in boostVHTProcessor |
| resultStream = this.topologyBuilder.createStream(boostVHTProcessor); |
| |
| //set the out streams of the BoostVHTProcessor |
| boostVHTProcessor.setResultStream(resultStream); |
| boostVHTProcessor.setAttributeStream(attributeStream); |
| boostVHTProcessor.setControlStream(controlStream); |
| } |
| |
| /** The topologyBuilder. */ |
| private TopologyBuilder topologyBuilder; |
| |
| /* |
| * (non-Javadoc) |
| * |
| * @see samoa.classifiers.Classifier#init(samoa.engines.Engine, |
| * samoa.core.Stream, weka.core.Instances) |
| */ |
| |
| @Override |
| public void init(TopologyBuilder builder, Instances dataset, int parallelism) { |
| this.topologyBuilder = builder; |
| this.dataset = dataset; |
| this.parallelism = parallelism; |
| this.setLayout(); |
| } |
| |
| @Override |
| public Processor getInputProcessor() { |
| return boostVHTProcessor; |
| } |
| |
| /* |
| * (non-Javadoc) |
| * |
| * @see samoa.learners.Learner#getResultStreams() |
| */ |
| @Override |
| public Set<Stream> getResultStreams() { |
| return ImmutableSet.of(this.resultStream); |
| } |
| } |